Merge 094eaef7b31f5c0bfd6c200a972ad332be374fcf into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
yugong333 2025-12-25 00:06:38 +00:00 committed by GitHub
commit f0f9925b5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 191 additions and 43 deletions

View File

@ -46,6 +46,14 @@ class BatchDescriptor(NamedTuple):
"""
Whether this batch has active LoRA adapters.
"""
num_active_loras: int = 0
"""
Number of distinct active LoRA adapters in this batch.
Separate CUDA graphs are captured for each num_active_loras value,
so that kernels (like fused_moe_lora) whose launch grid depends on
num_active_loras are captured with the correct grid size.
"""
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
@ -53,7 +61,11 @@ class BatchDescriptor(NamedTuple):
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
"""
return BatchDescriptor(
self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
self.num_tokens,
num_reqs=None,
uniform=False,
has_lora=self.has_lora,
num_active_loras=self.num_active_loras,
)

View File

@ -1,17 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
from .utils import supports_pdl
logger = init_logger(__name__)
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
@ -81,6 +85,7 @@ def _fused_moe_lora_kernel(
# Meta-parameters
num_slice_a: tl.constexpr,
num_slice_c: tl.constexpr,
max_loras: tl.constexpr,
top_k: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
@ -104,7 +109,7 @@ def _fused_moe_lora_kernel(
if moe_enabled == 0:
# Early exit for the no moe lora case.
return
max_loras = tl.num_programs(axis=2)
# max_loras = tl.num_programs(axis=2)
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
# calculate pid_m,pid_n
@ -228,6 +233,7 @@ def _fused_moe_lora_shrink(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
) -> None:
w1_lora_a_stacked = lora_a_stacked[0]
@ -251,7 +257,7 @@ def _fused_moe_lora_shrink(
* triton.cdiv(EM, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_a_stacked),
lora_a_stacked[0].shape[0],
num_active_loras,
)
_fused_moe_lora_kernel[grid](
qcurr_hidden_states,
@ -280,6 +286,7 @@ def _fused_moe_lora_shrink(
expert_ids.stride(0),
slice_a_size=qcurr_hidden_states.numel(),
slice_c_size=a_intermediate_cache1.numel() // num_slices,
max_loras=lora_a_stacked[0].shape[0],
num_slice_a=1,
num_slice_c=num_slices,
top_k=1 if mul_routed_weight else top_k_num,
@ -322,6 +329,7 @@ def _fused_moe_lora_expand(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
offset: int = 0,
) -> None:
@ -351,7 +359,7 @@ def _fused_moe_lora_expand(
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_b_stacked),
lora_b_stacked[0].shape[0],
num_active_loras,
)
_fused_moe_lora_kernel[grid](
a_intermediate_cache1,
@ -382,6 +390,7 @@ def _fused_moe_lora_expand(
slice_c_size=b_intermediate_cache1.numel() // num_slices,
num_slice_a=num_slices,
num_slice_c=num_slices,
max_loras=lora_b_stacked[0].shape[0],
top_k=1,
MUL_ROUTED_WEIGHT=mul_routed_weight,
IS_PRIMARY=False,
@ -408,6 +417,7 @@ def _fused_moe_lora(
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
num_active_loras: int,
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
@ -492,6 +502,7 @@ def _fused_moe_lora(
shrink_num_warps,
shrink_num_stages,
shrink_split_k,
num_active_loras,
mul_routed_weight,
)
@ -538,6 +549,7 @@ def _fused_moe_lora(
expand_num_warps,
expand_num_stages,
expand_split_k,
num_active_loras,
mul_routed_weight,
offset,
)
@ -555,6 +567,7 @@ def _fused_moe_lora_fake(
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
num_active_loras: int,
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
@ -601,6 +614,7 @@ def _fused_moe_lora_shrink_fake(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
) -> None:
return
@ -634,6 +648,7 @@ def _fused_moe_lora_expand_fake(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
) -> None:
return

View File

@ -140,6 +140,7 @@ def _lora_expand(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs; sizes the LoRA axis of the launch grid
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
@ -239,7 +240,7 @@ def _lora_expand(
# Each LoRA receives its own set of thread blocks for output
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks simply exit.
MAX_LORAS,
num_active_loras,
)
use_gdc = supports_pdl(inputs.device)
_lora_expand_kernel[grid](
@ -291,6 +292,7 @@ def _lora_expand_fake(
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:

View File

@ -28,6 +28,10 @@ class LoRAKernelMeta:
# to early exit from inside the lora_expand / lora_shrink torch operation.
no_lora_flag_cpu: torch.Tensor
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
# Stored as a Python int to avoid GPU->CPU sync during forward pass
num_active_loras: int = 0
@staticmethod
def make(
max_loras: int, max_num_tokens: int, device: torch.device | str
@ -73,6 +77,7 @@ class LoRAKernelMeta:
self.num_tokens_per_lora.fill_(0)
self.lora_token_start_loc.fill_(0)
self.no_lora_flag_cpu.fill_(False)
self.num_active_loras = 0
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
"""
@ -117,6 +122,9 @@ class LoRAKernelMeta:
self.num_tokens_per_lora[: num_tokens_per_lora.size(0)].copy_(
num_tokens_per_lora, non_blocking=True
)
# Store num_active_loras as Python int (excludes -1 which means no LoRA)
# Valid LoRA IDs are >= 0, so count those
self.num_active_loras = int((lora_ids >= 0).sum().item())
# lora_token_start_loc
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
@ -133,6 +141,7 @@ class LoRAKernelMeta:
torch.Tensor,
torch.Tensor,
torch.Tensor,
int,
]:
"""
This function returns the kernel metadata required for the current
@ -151,4 +160,5 @@ class LoRAKernelMeta:
self.lora_token_start_loc,
self.active_lora_ids,
self.no_lora_flag_cpu,
self.num_active_loras,
)

View File

@ -136,6 +136,7 @@ def _lora_shrink(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs; sizes the LoRA axis of the launch grid
scaling: float,
) -> None:
"""
@ -219,7 +220,7 @@ def _lora_shrink(
# Each LoRA receives its own set of thread blocks for output
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks exit early.
MAX_LORAS,
num_active_loras,
)
use_gdc = supports_pdl(inputs.device)
_lora_shrink_kernel[grid](
@ -269,6 +270,7 @@ def _lora_shrink_fake(
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
scaling: float,
) -> None:
return

View File

@ -333,8 +333,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
(max_loras), dtype=torch.int32, device=topk_ids.device
)
(token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
num_tokens
(token_lora_mapping, _, _, _, lora_ids, _, _) = (
self.token_mapping_meta.meta_args(num_tokens)
)
ops.moe_lora_align_block_size(
@ -378,7 +378,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
"""
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
"""
(_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0))
(_, _, _, _, lora_ids, _, num_active_loras) = self.token_mapping_meta.meta_args(
x.size(0)
)
fused_moe_lora(
y,
x,
@ -391,6 +393,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
max_lora_rank,
top_k_num,
lora_ids,
num_active_loras,
adapter_enabled,
shrink_config.get("BLOCK_SIZE_M", 64),
shrink_config.get("BLOCK_SIZE_N", 64),

View File

@ -57,9 +57,33 @@ class CudagraphDispatcher:
)
self.keys_initialized = False
self.specialize_lora_count = False
def _get_lora_cases(self) -> list[tuple[bool, int]]:
"""
Returns list of (has_lora, num_active_loras) tuples for CUDA graph
capture. This is the single source of truth for LoRA capture cases.
"""
lora_config = self.vllm_config.lora_config
if lora_config is None:
# No LoRA configured - single case with no LoRA
return [(False, 0)]
# LoRA is enabled - capture graphs for different active LoRA counts
# Always include the no-LoRA case (for requests without adapters)
cases: list[tuple[bool, int]] = [(False, 0)]
for n in range(1, lora_config.max_loras + 1):
cases.append((True, n))
return cases
def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool
self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
num_active_loras: int = 0,
) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len
@ -77,6 +101,7 @@ class CudagraphDispatcher:
num_reqs=num_reqs,
uniform=uniform_decode,
has_lora=has_lora,
num_active_loras=num_active_loras,
)
def add_cudagraph_key(
@ -94,26 +119,23 @@ class CudagraphDispatcher:
# get the correct cudagraph mode after backend support is resolved.
self.cudagraph_mode = cudagraph_mode
# LoRA activation cases to specialize the cuda graphs on
if self.vllm_config.lora_config:
if self.compilation_config.cudagraph_specialize_lora:
lora_cases = [True, False]
else:
lora_cases = [True]
else:
lora_cases = [False]
# Track whether we have LoRA config (always specialize on count)
self.has_lora_config = self.vllm_config.lora_config is not None
# Get LoRA cases to capture
lora_cases = self._get_lora_cases()
# Note: we create all valid keys for cudagraph here but do not
# guarantee all keys would be used. For example, if we allow lazy
# capturing in future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
for bs, has_lora in product(
for bs, (has_lora, num_active_loras) in product(
self.compilation_config.cudagraph_capture_sizes, lora_cases
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
self._create_padded_batch_descriptor(
bs, False, has_lora
bs, False, has_lora, num_active_loras
).relax_for_mixed_batch_cudagraphs(),
)
@ -132,10 +154,14 @@ class CudagraphDispatcher:
for x in self.compilation_config.cudagraph_capture_sizes
if x <= max_num_tokens and x >= uniform_decode_query_len
]
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
for bs, (has_lora, num_active_loras) in product(
cudagraph_capture_sizes_for_decode, lora_cases
):
self.add_cudagraph_key(
CUDAGraphMode.FULL,
self._create_padded_batch_descriptor(bs, True, has_lora),
self._create_padded_batch_descriptor(
bs, True, has_lora, num_active_loras
),
)
self.keys_initialized = True
@ -146,12 +172,20 @@ class CudagraphDispatcher:
uniform_decode: bool,
has_lora: bool,
disable_full: bool = False,
num_active_loras: int = 0,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
"""
Given conditions(e.g.,batch descriptor and if using cascade attention),
dispatch to a cudagraph runtime mode and the valid batch descriptor.
A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform).
Args:
num_tokens: Number of tokens in the batch.
uniform_decode: Whether this is a uniform decode batch.
has_lora: Whether this batch has active LoRA adapters.
disable_full: Whether to disable full cudagraph mode.
num_active_loras: Number of distinct active LoRA adapters.
"""
if (
not self.keys_initialized
@ -160,8 +194,10 @@ class CudagraphDispatcher:
):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
effective_num_active_loras = num_active_loras
batch_desc = self._create_padded_batch_descriptor(
num_tokens, uniform_decode, has_lora
num_tokens, uniform_decode, has_lora, effective_num_active_loras
)
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

View File

@ -2814,6 +2814,7 @@ class GPUModelRunner(
# be improved in model runner v2)
force_uniform_decode: bool | None = None,
force_has_lora: bool | None = None,
force_num_active_loras: int | None = None,
num_encoder_reqs: int = 0,
) -> tuple[
CUDAGraphMode,
@ -2835,11 +2836,13 @@ class GPUModelRunner(
self.model_config.is_encoder_decoder and num_encoder_reqs > 0
)
has_lora = (
len(self.input_batch.lora_id_to_lora_request) > 0
if force_has_lora is None
else force_has_lora
# Compute LoRA state for cudagraph dispatch
num_active_loras = (
force_num_active_loras
if force_num_active_loras is not None
else len(self.input_batch.lora_id_to_lora_request)
)
has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
dispatch_cudagraph = (
@ -2848,6 +2851,7 @@ class GPUModelRunner(
has_lora=has_lora,
uniform_decode=uniform_decode,
disable_full=disable_full,
num_active_loras=num_active_loras,
)
if not force_eager
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
@ -4052,6 +4056,7 @@ class GPUModelRunner(
remove_lora: bool = True,
activate_lora: bool = False,
is_graph_capturing: bool = False,
num_active_loras: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up/profile run or capture the
@ -4075,6 +4080,8 @@ class GPUModelRunner(
(1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run
activate_lora: If False, dummy_run is performed without LoRAs.
num_active_loras: Number of distinct active LoRAs to capture for.
Only applied when ``activate_lora`` is True.
"""
if supports_mm_encoder_only(self.model):
# The current dummy run only covers LM execution, so we can skip it.
@ -4156,6 +4163,9 @@ class GPUModelRunner(
# activated later in the context manager, but we need to know the
# LoRA state when determining the batch descriptor for capture
force_has_lora=activate_lora,
# `force_num_active_loras` is used for cudagraph capture; because we
# need to capture graphs for specific num_active_loras counts
force_num_active_loras=num_active_loras if activate_lora else 0,
)
)
@ -4219,6 +4229,7 @@ class GPUModelRunner(
num_sampled_tokens,
activate_lora,
remove_lora,
num_active_loras,
):
# Make sure padding doesn't exceed max_num_tokens
assert num_tokens_padded <= self.max_num_tokens
@ -4618,13 +4629,8 @@ class GPUModelRunner(
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
if self.lora_config:
if self.compilation_config.cudagraph_specialize_lora:
lora_cases = [True, False]
else:
lora_cases = [True]
else:
lora_cases = [False]
# Build LoRA cases: list of (has_lora, num_active_loras) tuples
lora_cases = self.cudagraph_dispatcher._get_lora_cases()
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
@ -4689,10 +4695,23 @@ class GPUModelRunner(
def _capture_cudagraphs(
self,
compilation_cases: list[tuple[int, bool]],
compilation_cases: list[tuple[int, tuple[bool, int]]],
cudagraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool,
):
"""
Capture CUDA graphs for the given compilation cases.
Args:
compilation_cases: List of (num_tokens, (has_lora, num_active_loras))
tuples.
- num_tokens: batch size to capture
- has_lora: whether LoRA is active for this capture
- num_active_loras: number of distinct active LoRAs
(0 if not specializing)
cudagraph_runtime_mode: FULL or PIECEWISE cudagraph mode.
uniform_decode: Whether this is a uniform decode batch.
"""
assert (
cudagraph_runtime_mode != CUDAGraphMode.NONE
and cudagraph_runtime_mode.valid_runtime_modes()
@ -4710,7 +4729,7 @@ class GPUModelRunner(
)
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens, activate_lora in compilation_cases:
for num_tokens, (activate_lora, num_active_loras) in compilation_cases:
# We currently only capture ubatched graphs when its a FULL
# cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched
@ -4742,6 +4761,7 @@ class GPUModelRunner(
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
num_active_loras=num_active_loras,
)
self._dummy_run(
num_tokens,
@ -4751,6 +4771,7 @@ class GPUModelRunner(
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
num_active_loras=num_active_loras,
is_graph_capturing=True,
)
self.maybe_remove_all_loras(self.lora_config)

View File

@ -4,7 +4,10 @@
Define LoRA functionality mixin for model runners.
"""
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, TypeAlias
import numpy as np
import torch
@ -20,7 +23,8 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
InputBatch = TPUInputBatch | GPUInputBatch
if TYPE_CHECKING:
InputBatch: TypeAlias = TPUInputBatch | GPUInputBatch
logger = init_logger(__name__)
@ -129,7 +133,20 @@ class LoRAModelRunnerMixin:
num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray | None = None,
activate_lora: bool = True,
num_active_loras: int = 0,
):
"""
Context manager to select dummy LoRAs for capture/warmup.
Args:
lora_config: LoRA configuration, or None if LoRA is disabled.
num_scheduled_tokens: Array of scheduled token counts per request.
num_sampled_tokens: Array of sampled token counts per request.
activate_lora: Whether to activate LoRAs (False means no LoRA).
num_active_loras: Number of distinct active LoRAs to use.
- 0: Use all max_loras (default behavior for has_lora=True).
- >0: Use exactly this many distinct LoRAs.
"""
if num_sampled_tokens is None:
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
@ -140,13 +157,24 @@ class LoRAModelRunnerMixin:
assert self.lora_manager is not None, "LoRA is not enabled"
num_reqs = len(num_scheduled_tokens)
num_loras = lora_config.max_loras
max_loras = lora_config.max_loras
# Determine how many distinct LoRAs to use
if not activate_lora:
# No LoRA active
effective_num_loras = 0
elif num_active_loras > 0:
# Specific number of active LoRAs requested
effective_num_loras = min(num_active_loras, max_loras)
else:
# Default: use all max_loras
effective_num_loras = max_loras
# Make prompt lora mapping
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
if activate_lora:
if effective_num_loras > 0:
prompt_lora_mapping = (
np.arange(num_reqs, dtype=np.int32) % num_loras
np.arange(num_reqs, dtype=np.int32) % effective_num_loras
) + 1
else:
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
@ -157,16 +185,19 @@ class LoRAModelRunnerMixin:
# Make token lora mapping
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
# Make dummy lora requests
# Make dummy lora requests (only for the active LoRAs)
lora_requests: set[LoRARequest] = {
LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_path="/not/a/real/path",
)
for lora_id in range(1, num_loras + 1)
for lora_id in range(1, effective_num_loras + 1)
}
# Always call _set_active_loras to ensure the mapping is updated.
# This is important when capturing no-LoRA graphs (effective_num_loras=0)
# after capturing LoRA graphs, as we need to clear the previous mapping.
self._set_active_loras(
tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
)
@ -181,11 +212,27 @@ class LoRAModelRunnerMixin:
num_sampled_tokens: np.ndarray,
activate_lora: bool = True,
remove_lora: bool = True,
num_active_loras: int = 0,
):
"""
Context manager for dummy runs with LoRA.
Args:
lora_config: LoRA configuration.
num_scheduled_tokens: Array of scheduled token counts per request.
num_sampled_tokens: Array of sampled token counts per request.
activate_lora: Whether to activate LoRAs.
remove_lora: Whether to remove LoRAs after the context exits.
num_active_loras: Number of distinct active LoRAs to use.
"""
with (
self.maybe_setup_dummy_loras(lora_config, remove_lora),
self.maybe_select_dummy_loras(
lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora
lora_config,
num_scheduled_tokens,
num_sampled_tokens,
activate_lora,
num_active_loras,
),
):
yield