mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-31 06:57:02 +08:00
Merge 094eaef7b31f5c0bfd6c200a972ad332be374fcf into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
f0f9925b5c
@ -46,6 +46,14 @@ class BatchDescriptor(NamedTuple):
|
||||
"""
|
||||
Whether this batch has active LoRA adapters.
|
||||
"""
|
||||
num_active_loras: int = 0
|
||||
"""
|
||||
Number of distinct active LoRA adapters in this batch.
|
||||
When cudagraph_specialize_lora_count is enabled, separate CUDA graphs
|
||||
are captured for each num_active_loras value. This allows kernels
|
||||
(like fused_moe_lora) whose grid size depends on num_active_loras
|
||||
to be properly captured.
|
||||
"""
|
||||
|
||||
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
|
||||
"""
|
||||
@ -53,7 +61,11 @@ class BatchDescriptor(NamedTuple):
|
||||
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
|
||||
"""
|
||||
return BatchDescriptor(
|
||||
self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
|
||||
self.num_tokens,
|
||||
num_reqs=None,
|
||||
uniform=False,
|
||||
has_lora=self.has_lora,
|
||||
num_active_loras=self.num_active_loras,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,17 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import (
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .utils import supports_pdl
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
|
||||
|
||||
|
||||
@ -81,6 +85,7 @@ def _fused_moe_lora_kernel(
|
||||
# Meta-parameters
|
||||
num_slice_a: tl.constexpr,
|
||||
num_slice_c: tl.constexpr,
|
||||
max_loras: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
@ -104,7 +109,7 @@ def _fused_moe_lora_kernel(
|
||||
if moe_enabled == 0:
|
||||
# Early exit for the no moe lora case.
|
||||
return
|
||||
max_loras = tl.num_programs(axis=2)
|
||||
# max_loras = tl.num_programs(axis=2)
|
||||
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
|
||||
|
||||
# calculate pid_m,pid_n
|
||||
@ -228,6 +233,7 @@ def _fused_moe_lora_shrink(
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
num_active_loras: int,
|
||||
mul_routed_weight: bool = False,
|
||||
) -> None:
|
||||
w1_lora_a_stacked = lora_a_stacked[0]
|
||||
@ -251,7 +257,7 @@ def _fused_moe_lora_shrink(
|
||||
* triton.cdiv(EM, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
len(lora_a_stacked),
|
||||
lora_a_stacked[0].shape[0],
|
||||
num_active_loras,
|
||||
)
|
||||
_fused_moe_lora_kernel[grid](
|
||||
qcurr_hidden_states,
|
||||
@ -280,6 +286,7 @@ def _fused_moe_lora_shrink(
|
||||
expert_ids.stride(0),
|
||||
slice_a_size=qcurr_hidden_states.numel(),
|
||||
slice_c_size=a_intermediate_cache1.numel() // num_slices,
|
||||
max_loras=lora_a_stacked[0].shape[0],
|
||||
num_slice_a=1,
|
||||
num_slice_c=num_slices,
|
||||
top_k=1 if mul_routed_weight else top_k_num,
|
||||
@ -322,6 +329,7 @@ def _fused_moe_lora_expand(
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
num_active_loras: int,
|
||||
mul_routed_weight: bool = False,
|
||||
offset: int = 0,
|
||||
) -> None:
|
||||
@ -351,7 +359,7 @@ def _fused_moe_lora_expand(
|
||||
grid = lambda META: (
|
||||
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
len(lora_b_stacked),
|
||||
lora_b_stacked[0].shape[0],
|
||||
num_active_loras,
|
||||
)
|
||||
_fused_moe_lora_kernel[grid](
|
||||
a_intermediate_cache1,
|
||||
@ -382,6 +390,7 @@ def _fused_moe_lora_expand(
|
||||
slice_c_size=b_intermediate_cache1.numel() // num_slices,
|
||||
num_slice_a=num_slices,
|
||||
num_slice_c=num_slices,
|
||||
max_loras=lora_b_stacked[0].shape[0],
|
||||
top_k=1,
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
IS_PRIMARY=False,
|
||||
@ -408,6 +417,7 @@ def _fused_moe_lora(
|
||||
max_lora_rank: int,
|
||||
top_k_num: int,
|
||||
lora_ids: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
adapter_enabled: torch.Tensor,
|
||||
shrink_block_size_m: int,
|
||||
shrink_block_size_n: int,
|
||||
@ -492,6 +502,7 @@ def _fused_moe_lora(
|
||||
shrink_num_warps,
|
||||
shrink_num_stages,
|
||||
shrink_split_k,
|
||||
num_active_loras,
|
||||
mul_routed_weight,
|
||||
)
|
||||
|
||||
@ -538,6 +549,7 @@ def _fused_moe_lora(
|
||||
expand_num_warps,
|
||||
expand_num_stages,
|
||||
expand_split_k,
|
||||
num_active_loras,
|
||||
mul_routed_weight,
|
||||
offset,
|
||||
)
|
||||
@ -555,6 +567,7 @@ def _fused_moe_lora_fake(
|
||||
max_lora_rank: int,
|
||||
top_k_num: int,
|
||||
lora_ids: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
adapter_enabled: torch.Tensor,
|
||||
shrink_block_size_m: int,
|
||||
shrink_block_size_n: int,
|
||||
@ -601,6 +614,7 @@ def _fused_moe_lora_shrink_fake(
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
num_active_loras: int,
|
||||
mul_routed_weight: bool = False,
|
||||
) -> None:
|
||||
return
|
||||
@ -634,6 +648,7 @@ def _fused_moe_lora_expand_fake(
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
num_active_loras: int,
|
||||
mul_routed_weight: bool = False,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
@ -140,6 +140,7 @@ def _lora_expand(
|
||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
||||
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
@ -239,7 +240,7 @@ def _lora_expand(
|
||||
# Each LoRA receives its own set of thread blocks for output
|
||||
# computation. If some LoRA doesn't have any tokens to process, its
|
||||
# thread blocks simply exit.
|
||||
MAX_LORAS,
|
||||
num_active_loras,
|
||||
)
|
||||
use_gdc = supports_pdl(inputs.device)
|
||||
_lora_expand_kernel[grid](
|
||||
@ -291,6 +292,7 @@ def _lora_expand_fake(
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
|
||||
@ -28,6 +28,10 @@ class LoRAKernelMeta:
|
||||
# to early exit from inside the lora_expand / lora_shrink torch operation.
|
||||
no_lora_flag_cpu: torch.Tensor
|
||||
|
||||
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
|
||||
# Stored as a Python int to avoid GPU->CPU sync during forward pass
|
||||
num_active_loras: int = 0
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
max_loras: int, max_num_tokens: int, device: torch.device | str
|
||||
@ -73,6 +77,7 @@ class LoRAKernelMeta:
|
||||
self.num_tokens_per_lora.fill_(0)
|
||||
self.lora_token_start_loc.fill_(0)
|
||||
self.no_lora_flag_cpu.fill_(False)
|
||||
self.num_active_loras = 0
|
||||
|
||||
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
|
||||
"""
|
||||
@ -117,6 +122,9 @@ class LoRAKernelMeta:
|
||||
self.num_tokens_per_lora[: num_tokens_per_lora.size(0)].copy_(
|
||||
num_tokens_per_lora, non_blocking=True
|
||||
)
|
||||
# Store num_active_loras as Python int (excludes -1 which means no LoRA)
|
||||
# Valid LoRA IDs are >= 0, so count those
|
||||
self.num_active_loras = int((lora_ids >= 0).sum().item())
|
||||
|
||||
# lora_token_start_loc
|
||||
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
|
||||
@ -133,6 +141,7 @@ class LoRAKernelMeta:
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
int,
|
||||
]:
|
||||
"""
|
||||
This function returns the kernel metadata required for the current
|
||||
@ -151,4 +160,5 @@ class LoRAKernelMeta:
|
||||
self.lora_token_start_loc,
|
||||
self.active_lora_ids,
|
||||
self.no_lora_flag_cpu,
|
||||
self.num_active_loras,
|
||||
)
|
||||
|
||||
@ -136,6 +136,7 @@ def _lora_shrink(
|
||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
||||
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
|
||||
scaling: float,
|
||||
) -> None:
|
||||
"""
|
||||
@ -219,7 +220,7 @@ def _lora_shrink(
|
||||
# Each LoRA receives its own set of thread blocks for output
|
||||
# computation. If some LoRA doesn't have any tokens to process, its
|
||||
# thread blocks exit early.
|
||||
MAX_LORAS,
|
||||
num_active_loras,
|
||||
)
|
||||
use_gdc = supports_pdl(inputs.device)
|
||||
_lora_shrink_kernel[grid](
|
||||
@ -269,6 +270,7 @@ def _lora_shrink_fake(
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
num_active_loras: int,
|
||||
scaling: float,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
@ -333,8 +333,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
(max_loras), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
(token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
|
||||
num_tokens
|
||||
(token_lora_mapping, _, _, _, lora_ids, _, _) = (
|
||||
self.token_mapping_meta.meta_args(num_tokens)
|
||||
)
|
||||
|
||||
ops.moe_lora_align_block_size(
|
||||
@ -378,7 +378,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
"""
|
||||
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
|
||||
"""
|
||||
(_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0))
|
||||
(_, _, _, _, lora_ids, _, num_active_loras) = self.token_mapping_meta.meta_args(
|
||||
x.size(0)
|
||||
)
|
||||
fused_moe_lora(
|
||||
y,
|
||||
x,
|
||||
@ -391,6 +393,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
max_lora_rank,
|
||||
top_k_num,
|
||||
lora_ids,
|
||||
num_active_loras,
|
||||
adapter_enabled,
|
||||
shrink_config.get("BLOCK_SIZE_M", 64),
|
||||
shrink_config.get("BLOCK_SIZE_N", 64),
|
||||
|
||||
@ -57,9 +57,33 @@ class CudagraphDispatcher:
|
||||
)
|
||||
|
||||
self.keys_initialized = False
|
||||
self.specialize_lora_count = False
|
||||
|
||||
def _get_lora_cases(self) -> list[tuple[bool, int]]:
|
||||
"""
|
||||
Returns list of (has_lora, num_active_loras) tuples for CUDA graph
|
||||
capture. This is the single source of truth for LoRA capture cases.
|
||||
"""
|
||||
lora_config = self.vllm_config.lora_config
|
||||
if lora_config is None:
|
||||
# No LoRA configured - single case with no LoRA
|
||||
return [(False, 0)]
|
||||
|
||||
# LoRA is enabled - capture graphs for different active LoRA counts
|
||||
# Always include the no-LoRA case (for requests without adapters)
|
||||
cases: list[tuple[bool, int]] = [(False, 0)]
|
||||
|
||||
for n in range(1, lora_config.max_loras + 1):
|
||||
cases.append((True, n))
|
||||
|
||||
return cases
|
||||
|
||||
def _create_padded_batch_descriptor(
|
||||
self, num_tokens: int, uniform_decode: bool, has_lora: bool
|
||||
self,
|
||||
num_tokens: int,
|
||||
uniform_decode: bool,
|
||||
has_lora: bool,
|
||||
num_active_loras: int = 0,
|
||||
) -> BatchDescriptor:
|
||||
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
|
||||
uniform_decode_query_len = self.uniform_decode_query_len
|
||||
@ -77,6 +101,7 @@ class CudagraphDispatcher:
|
||||
num_reqs=num_reqs,
|
||||
uniform=uniform_decode,
|
||||
has_lora=has_lora,
|
||||
num_active_loras=num_active_loras,
|
||||
)
|
||||
|
||||
def add_cudagraph_key(
|
||||
@ -94,26 +119,23 @@ class CudagraphDispatcher:
|
||||
# get the correct cudagraph mode after backend support is resolved.
|
||||
self.cudagraph_mode = cudagraph_mode
|
||||
|
||||
# LoRA activation cases to specialize the cuda graphs on
|
||||
if self.vllm_config.lora_config:
|
||||
if self.compilation_config.cudagraph_specialize_lora:
|
||||
lora_cases = [True, False]
|
||||
else:
|
||||
lora_cases = [True]
|
||||
else:
|
||||
lora_cases = [False]
|
||||
# Track whether we have LoRA config (always specialize on count)
|
||||
self.has_lora_config = self.vllm_config.lora_config is not None
|
||||
|
||||
# Get LoRA cases to capture
|
||||
lora_cases = self._get_lora_cases()
|
||||
|
||||
# Note: we create all valid keys for cudagraph here but do not
|
||||
# guarantee all keys would be used. For example, if we allow lazy
|
||||
# capturing in future PR, some keys may never be triggered.
|
||||
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
|
||||
for bs, has_lora in product(
|
||||
for bs, (has_lora, num_active_loras) in product(
|
||||
self.compilation_config.cudagraph_capture_sizes, lora_cases
|
||||
):
|
||||
self.add_cudagraph_key(
|
||||
cudagraph_mode.mixed_mode(),
|
||||
self._create_padded_batch_descriptor(
|
||||
bs, False, has_lora
|
||||
bs, False, has_lora, num_active_loras
|
||||
).relax_for_mixed_batch_cudagraphs(),
|
||||
)
|
||||
|
||||
@ -132,10 +154,14 @@ class CudagraphDispatcher:
|
||||
for x in self.compilation_config.cudagraph_capture_sizes
|
||||
if x <= max_num_tokens and x >= uniform_decode_query_len
|
||||
]
|
||||
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
|
||||
for bs, (has_lora, num_active_loras) in product(
|
||||
cudagraph_capture_sizes_for_decode, lora_cases
|
||||
):
|
||||
self.add_cudagraph_key(
|
||||
CUDAGraphMode.FULL,
|
||||
self._create_padded_batch_descriptor(bs, True, has_lora),
|
||||
self._create_padded_batch_descriptor(
|
||||
bs, True, has_lora, num_active_loras
|
||||
),
|
||||
)
|
||||
|
||||
self.keys_initialized = True
|
||||
@ -146,12 +172,20 @@ class CudagraphDispatcher:
|
||||
uniform_decode: bool,
|
||||
has_lora: bool,
|
||||
disable_full: bool = False,
|
||||
num_active_loras: int = 0,
|
||||
) -> tuple[CUDAGraphMode, BatchDescriptor]:
|
||||
"""
|
||||
Given conditions(e.g.,batch descriptor and if using cascade attention),
|
||||
dispatch to a cudagraph runtime mode and the valid batch descriptor.
|
||||
A new batch descriptor is returned as we might dispatch a uniform batch
|
||||
to a graph that supports a more general batch (uniform to non-uniform).
|
||||
|
||||
Args:
|
||||
num_tokens: Number of tokens in the batch.
|
||||
uniform_decode: Whether this is a uniform decode batch.
|
||||
has_lora: Whether this batch has active LoRA adapters.
|
||||
disable_full: Whether to disable full cudagraph mode.
|
||||
num_active_loras: Number of distinct active LoRA adapters.
|
||||
"""
|
||||
if (
|
||||
not self.keys_initialized
|
||||
@ -160,8 +194,10 @@ class CudagraphDispatcher:
|
||||
):
|
||||
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
|
||||
|
||||
effective_num_active_loras = num_active_loras
|
||||
|
||||
batch_desc = self._create_padded_batch_descriptor(
|
||||
num_tokens, uniform_decode, has_lora
|
||||
num_tokens, uniform_decode, has_lora, effective_num_active_loras
|
||||
)
|
||||
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
|
||||
|
||||
|
||||
@ -2814,6 +2814,7 @@ class GPUModelRunner(
|
||||
# be improved in model runner v2)
|
||||
force_uniform_decode: bool | None = None,
|
||||
force_has_lora: bool | None = None,
|
||||
force_num_active_loras: int | None = None,
|
||||
num_encoder_reqs: int = 0,
|
||||
) -> tuple[
|
||||
CUDAGraphMode,
|
||||
@ -2835,11 +2836,13 @@ class GPUModelRunner(
|
||||
self.model_config.is_encoder_decoder and num_encoder_reqs > 0
|
||||
)
|
||||
|
||||
has_lora = (
|
||||
len(self.input_batch.lora_id_to_lora_request) > 0
|
||||
if force_has_lora is None
|
||||
else force_has_lora
|
||||
# Compute LoRA state for cudagraph dispatch
|
||||
num_active_loras = (
|
||||
force_num_active_loras
|
||||
if force_num_active_loras is not None
|
||||
else len(self.input_batch.lora_id_to_lora_request)
|
||||
)
|
||||
has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora
|
||||
|
||||
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
|
||||
dispatch_cudagraph = (
|
||||
@ -2848,6 +2851,7 @@ class GPUModelRunner(
|
||||
has_lora=has_lora,
|
||||
uniform_decode=uniform_decode,
|
||||
disable_full=disable_full,
|
||||
num_active_loras=num_active_loras,
|
||||
)
|
||||
if not force_eager
|
||||
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
|
||||
@ -4052,6 +4056,7 @@ class GPUModelRunner(
|
||||
remove_lora: bool = True,
|
||||
activate_lora: bool = False,
|
||||
is_graph_capturing: bool = False,
|
||||
num_active_loras: int = 0,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Run a dummy forward pass to warm up/profile run or capture the
|
||||
@ -4075,6 +4080,8 @@ class GPUModelRunner(
|
||||
(1 token) and prefill (multiple tokens) requests.
|
||||
remove_lora: If False, dummy LoRAs are not destroyed after the run
|
||||
activate_lora: If False, dummy_run is performed without LoRAs.
|
||||
num_active_loras: Number of distinct active LoRAs to capture for.
|
||||
Only used when cudagraph_specialize_lora_count is True.
|
||||
"""
|
||||
if supports_mm_encoder_only(self.model):
|
||||
# The current dummy run only covers LM execution, so we can skip it.
|
||||
@ -4156,6 +4163,9 @@ class GPUModelRunner(
|
||||
# activated later in the context manager, but we need to know the
|
||||
# LoRA state when determining the batch descriptor for capture
|
||||
force_has_lora=activate_lora,
|
||||
# `force_num_active_loras` is used for cudagraph capture; because we
|
||||
# need to capture graphs for specific num_active_loras counts
|
||||
force_num_active_loras=num_active_loras if activate_lora else 0,
|
||||
)
|
||||
)
|
||||
|
||||
@ -4219,6 +4229,7 @@ class GPUModelRunner(
|
||||
num_sampled_tokens,
|
||||
activate_lora,
|
||||
remove_lora,
|
||||
num_active_loras,
|
||||
):
|
||||
# Make sure padding doesn't exceed max_num_tokens
|
||||
assert num_tokens_padded <= self.max_num_tokens
|
||||
@ -4618,13 +4629,8 @@ class GPUModelRunner(
|
||||
cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
assert cudagraph_mode is not None
|
||||
|
||||
if self.lora_config:
|
||||
if self.compilation_config.cudagraph_specialize_lora:
|
||||
lora_cases = [True, False]
|
||||
else:
|
||||
lora_cases = [True]
|
||||
else:
|
||||
lora_cases = [False]
|
||||
# Build LoRA cases: list of (has_lora, num_active_loras) tuples
|
||||
lora_cases = self.cudagraph_dispatcher._get_lora_cases()
|
||||
|
||||
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
|
||||
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
|
||||
@ -4689,10 +4695,23 @@ class GPUModelRunner(
|
||||
|
||||
def _capture_cudagraphs(
|
||||
self,
|
||||
compilation_cases: list[tuple[int, bool]],
|
||||
compilation_cases: list[tuple[int, tuple[bool, int]]],
|
||||
cudagraph_runtime_mode: CUDAGraphMode,
|
||||
uniform_decode: bool,
|
||||
):
|
||||
"""
|
||||
Capture CUDA graphs for the given compilation cases.
|
||||
|
||||
Args:
|
||||
compilation_cases: List of (num_tokens, (has_lora, num_active_loras))
|
||||
tuples.
|
||||
- num_tokens: batch size to capture
|
||||
- has_lora: whether LoRA is active for this capture
|
||||
- num_active_loras: number of distinct active LoRAs
|
||||
(0 if not specializing)
|
||||
cudagraph_runtime_mode: FULL or PIECEWISE cudagraph mode.
|
||||
uniform_decode: Whether this is a uniform decode batch.
|
||||
"""
|
||||
assert (
|
||||
cudagraph_runtime_mode != CUDAGraphMode.NONE
|
||||
and cudagraph_runtime_mode.valid_runtime_modes()
|
||||
@ -4710,7 +4729,7 @@ class GPUModelRunner(
|
||||
)
|
||||
|
||||
# We skip EPLB here since we don't want to record dummy metrics
|
||||
for num_tokens, activate_lora in compilation_cases:
|
||||
for num_tokens, (activate_lora, num_active_loras) in compilation_cases:
|
||||
# We currently only capture ubatched graphs when its a FULL
|
||||
# cudagraph, a uniform decode batch, and the number of tokens
|
||||
# is above the threshold. Otherwise we just capture a non-ubatched
|
||||
@ -4742,6 +4761,7 @@ class GPUModelRunner(
|
||||
skip_eplb=True,
|
||||
remove_lora=False,
|
||||
activate_lora=activate_lora,
|
||||
num_active_loras=num_active_loras,
|
||||
)
|
||||
self._dummy_run(
|
||||
num_tokens,
|
||||
@ -4751,6 +4771,7 @@ class GPUModelRunner(
|
||||
skip_eplb=True,
|
||||
remove_lora=False,
|
||||
activate_lora=activate_lora,
|
||||
num_active_loras=num_active_loras,
|
||||
is_graph_capturing=True,
|
||||
)
|
||||
self.maybe_remove_all_loras(self.lora_config)
|
||||
|
||||
@ -4,7 +4,10 @@
|
||||
Define LoRA functionality mixin for model runners.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -20,7 +23,8 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
|
||||
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
|
||||
|
||||
InputBatch = TPUInputBatch | GPUInputBatch
|
||||
if TYPE_CHECKING:
|
||||
InputBatch: TypeAlias = TPUInputBatch | GPUInputBatch
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -129,7 +133,20 @@ class LoRAModelRunnerMixin:
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
num_sampled_tokens: np.ndarray | None = None,
|
||||
activate_lora: bool = True,
|
||||
num_active_loras: int = 0,
|
||||
):
|
||||
"""
|
||||
Context manager to select dummy LoRAs for capture/warmup.
|
||||
|
||||
Args:
|
||||
lora_config: LoRA configuration, or None if LoRA is disabled.
|
||||
num_scheduled_tokens: Array of scheduled token counts per request.
|
||||
num_sampled_tokens: Array of sampled token counts per request.
|
||||
activate_lora: Whether to activate LoRAs (False means no LoRA).
|
||||
num_active_loras: Number of distinct active LoRAs to use.
|
||||
- 0: Use all max_loras (default behavior for has_lora=True).
|
||||
- >0: Use exactly this many distinct LoRAs.
|
||||
"""
|
||||
if num_sampled_tokens is None:
|
||||
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
|
||||
|
||||
@ -140,13 +157,24 @@ class LoRAModelRunnerMixin:
|
||||
assert self.lora_manager is not None, "LoRA is not enabled"
|
||||
|
||||
num_reqs = len(num_scheduled_tokens)
|
||||
num_loras = lora_config.max_loras
|
||||
max_loras = lora_config.max_loras
|
||||
|
||||
# Determine how many distinct LoRAs to use
|
||||
if not activate_lora:
|
||||
# No LoRA active
|
||||
effective_num_loras = 0
|
||||
elif num_active_loras > 0:
|
||||
# Specific number of active LoRAs requested
|
||||
effective_num_loras = min(num_active_loras, max_loras)
|
||||
else:
|
||||
# Default: use all max_loras
|
||||
effective_num_loras = max_loras
|
||||
|
||||
# Make prompt lora mapping
|
||||
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
|
||||
if activate_lora:
|
||||
if effective_num_loras > 0:
|
||||
prompt_lora_mapping = (
|
||||
np.arange(num_reqs, dtype=np.int32) % num_loras
|
||||
np.arange(num_reqs, dtype=np.int32) % effective_num_loras
|
||||
) + 1
|
||||
else:
|
||||
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
|
||||
@ -157,16 +185,19 @@ class LoRAModelRunnerMixin:
|
||||
# Make token lora mapping
|
||||
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
|
||||
|
||||
# Make dummy lora requests
|
||||
# Make dummy lora requests (only for the active LoRAs)
|
||||
lora_requests: set[LoRARequest] = {
|
||||
LoRARequest(
|
||||
lora_name=f"warmup_{lora_id}",
|
||||
lora_int_id=lora_id,
|
||||
lora_path="/not/a/real/path",
|
||||
)
|
||||
for lora_id in range(1, num_loras + 1)
|
||||
for lora_id in range(1, effective_num_loras + 1)
|
||||
}
|
||||
|
||||
# Always call _set_active_loras to ensure the mapping is updated.
|
||||
# This is important when capturing no-LoRA graphs (effective_num_loras=0)
|
||||
# after capturing LoRA graphs, as we need to clear the previous mapping.
|
||||
self._set_active_loras(
|
||||
tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
|
||||
)
|
||||
@ -181,11 +212,27 @@ class LoRAModelRunnerMixin:
|
||||
num_sampled_tokens: np.ndarray,
|
||||
activate_lora: bool = True,
|
||||
remove_lora: bool = True,
|
||||
num_active_loras: int = 0,
|
||||
):
|
||||
"""
|
||||
Context manager for dummy runs with LoRA.
|
||||
|
||||
Args:
|
||||
lora_config: LoRA configuration.
|
||||
num_scheduled_tokens: Array of scheduled token counts per request.
|
||||
num_sampled_tokens: Array of sampled token counts per request.
|
||||
activate_lora: Whether to activate LoRAs.
|
||||
remove_lora: Whether to remove LoRAs after the context exits.
|
||||
num_active_loras: Number of distinct active LoRAs to use.
|
||||
"""
|
||||
with (
|
||||
self.maybe_setup_dummy_loras(lora_config, remove_lora),
|
||||
self.maybe_select_dummy_loras(
|
||||
lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora
|
||||
lora_config,
|
||||
num_scheduled_tokens,
|
||||
num_sampled_tokens,
|
||||
activate_lora,
|
||||
num_active_loras,
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user