mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 03:07:04 +08:00
Capture multiple cuda graph across various active loras
Signed-off-by: Yu Gong <yu3.gong@gmail.com>
This commit is contained in:
parent
288d67d054
commit
ea1a26d276
@ -46,6 +46,14 @@ class BatchDescriptor(NamedTuple):
|
|||||||
"""
|
"""
|
||||||
Whether this batch has active LoRA adapters.
|
Whether this batch has active LoRA adapters.
|
||||||
"""
|
"""
|
||||||
|
num_active_loras: int = 0
|
||||||
|
"""
|
||||||
|
Number of distinct active LoRA adapters in this batch.
|
||||||
|
When cudagraph_specialize_lora_count is enabled, separate CUDA graphs
|
||||||
|
are captured for each num_active_loras value. This allows kernels
|
||||||
|
(like fused_moe_lora) whose grid size depends on num_active_loras
|
||||||
|
to be properly captured.
|
||||||
|
"""
|
||||||
|
|
||||||
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
|
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
|
||||||
"""
|
"""
|
||||||
@ -53,7 +61,11 @@ class BatchDescriptor(NamedTuple):
|
|||||||
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
|
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
|
||||||
"""
|
"""
|
||||||
return BatchDescriptor(
|
return BatchDescriptor(
|
||||||
self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
|
self.num_tokens,
|
||||||
|
num_reqs=None,
|
||||||
|
uniform=False,
|
||||||
|
has_lora=self.has_lora,
|
||||||
|
num_active_loras=self.num_active_loras,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,17 +1,21 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
|
from vllm.logger import init_logger
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
from .utils import supports_pdl
|
from .utils import supports_pdl
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
|
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
|
||||||
|
|
||||||
|
|
||||||
@ -413,6 +417,7 @@ def _fused_moe_lora(
|
|||||||
max_lora_rank: int,
|
max_lora_rank: int,
|
||||||
top_k_num: int,
|
top_k_num: int,
|
||||||
lora_ids: torch.Tensor,
|
lora_ids: torch.Tensor,
|
||||||
|
num_active_loras: int,
|
||||||
adapter_enabled: torch.Tensor,
|
adapter_enabled: torch.Tensor,
|
||||||
shrink_block_size_m: int,
|
shrink_block_size_m: int,
|
||||||
shrink_block_size_n: int,
|
shrink_block_size_n: int,
|
||||||
@ -562,6 +567,7 @@ def _fused_moe_lora_fake(
|
|||||||
max_lora_rank: int,
|
max_lora_rank: int,
|
||||||
top_k_num: int,
|
top_k_num: int,
|
||||||
lora_ids: torch.Tensor,
|
lora_ids: torch.Tensor,
|
||||||
|
num_active_loras: int,
|
||||||
adapter_enabled: torch.Tensor,
|
adapter_enabled: torch.Tensor,
|
||||||
shrink_block_size_m: int,
|
shrink_block_size_m: int,
|
||||||
shrink_block_size_n: int,
|
shrink_block_size_n: int,
|
||||||
|
|||||||
@ -140,6 +140,7 @@ def _lora_expand(
|
|||||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||||
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
||||||
|
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
|
||||||
offset_start: int = 0,
|
offset_start: int = 0,
|
||||||
add_inputs: bool = False,
|
add_inputs: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -291,6 +292,7 @@ def _lora_expand_fake(
|
|||||||
lora_token_start_loc: torch.Tensor,
|
lora_token_start_loc: torch.Tensor,
|
||||||
lora_ids: torch.Tensor,
|
lora_ids: torch.Tensor,
|
||||||
no_lora_flag_cpu: torch.Tensor,
|
no_lora_flag_cpu: torch.Tensor,
|
||||||
|
num_active_loras: int,
|
||||||
offset_start: int = 0,
|
offset_start: int = 0,
|
||||||
add_inputs: bool = False,
|
add_inputs: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@ -28,6 +28,10 @@ class LoRAKernelMeta:
|
|||||||
# to early exit from inside the lora_expand / lora_shrink torch operation.
|
# to early exit from inside the lora_expand / lora_shrink torch operation.
|
||||||
no_lora_flag_cpu: torch.Tensor
|
no_lora_flag_cpu: torch.Tensor
|
||||||
|
|
||||||
|
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
|
||||||
|
# Stored as a Python int to avoid GPU->CPU sync during forward pass
|
||||||
|
num_active_loras: int = 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def make(
|
def make(
|
||||||
max_loras: int, max_num_tokens: int, device: torch.device | str
|
max_loras: int, max_num_tokens: int, device: torch.device | str
|
||||||
@ -73,6 +77,7 @@ class LoRAKernelMeta:
|
|||||||
self.num_tokens_per_lora.fill_(0)
|
self.num_tokens_per_lora.fill_(0)
|
||||||
self.lora_token_start_loc.fill_(0)
|
self.lora_token_start_loc.fill_(0)
|
||||||
self.no_lora_flag_cpu.fill_(False)
|
self.no_lora_flag_cpu.fill_(False)
|
||||||
|
self.num_active_loras = 0
|
||||||
|
|
||||||
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
|
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
|
||||||
"""
|
"""
|
||||||
@ -117,6 +122,9 @@ class LoRAKernelMeta:
|
|||||||
self.num_tokens_per_lora[: num_tokens_per_lora.size(0)].copy_(
|
self.num_tokens_per_lora[: num_tokens_per_lora.size(0)].copy_(
|
||||||
num_tokens_per_lora, non_blocking=True
|
num_tokens_per_lora, non_blocking=True
|
||||||
)
|
)
|
||||||
|
# Store num_active_loras as Python int (excludes -1 which means no LoRA)
|
||||||
|
# Valid LoRA IDs are >= 0, so count those
|
||||||
|
self.num_active_loras = int((lora_ids >= 0).sum().item())
|
||||||
|
|
||||||
# lora_token_start_loc
|
# lora_token_start_loc
|
||||||
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
|
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
|
||||||
@ -133,6 +141,7 @@ class LoRAKernelMeta:
|
|||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
|
int,
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
This function returns the kernel metadata required for the current
|
This function returns the kernel metadata required for the current
|
||||||
@ -151,4 +160,5 @@ class LoRAKernelMeta:
|
|||||||
self.lora_token_start_loc,
|
self.lora_token_start_loc,
|
||||||
self.active_lora_ids,
|
self.active_lora_ids,
|
||||||
self.no_lora_flag_cpu,
|
self.no_lora_flag_cpu,
|
||||||
|
self.num_active_loras,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -136,6 +136,7 @@ def _lora_shrink(
|
|||||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||||
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
||||||
|
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
|
||||||
scaling: float,
|
scaling: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -269,6 +270,7 @@ def _lora_shrink_fake(
|
|||||||
lora_token_start_loc: torch.Tensor,
|
lora_token_start_loc: torch.Tensor,
|
||||||
lora_ids: torch.Tensor,
|
lora_ids: torch.Tensor,
|
||||||
no_lora_flag_cpu: torch.Tensor,
|
no_lora_flag_cpu: torch.Tensor,
|
||||||
|
num_active_loras: int,
|
||||||
scaling: float,
|
scaling: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
return
|
return
|
||||||
|
|||||||
@ -333,8 +333,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
|||||||
(max_loras), dtype=torch.int32, device=topk_ids.device
|
(max_loras), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
|
|
||||||
(token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
|
(token_lora_mapping, _, _, _, lora_ids, _, _) = (
|
||||||
num_tokens
|
self.token_mapping_meta.meta_args(num_tokens)
|
||||||
)
|
)
|
||||||
|
|
||||||
ops.moe_lora_align_block_size(
|
ops.moe_lora_align_block_size(
|
||||||
@ -378,7 +378,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
|||||||
"""
|
"""
|
||||||
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
|
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
|
||||||
"""
|
"""
|
||||||
(_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0))
|
(_, _, _, _, lora_ids, _, num_active_loras) = self.token_mapping_meta.meta_args(
|
||||||
|
x.size(0)
|
||||||
|
)
|
||||||
fused_moe_lora(
|
fused_moe_lora(
|
||||||
y,
|
y,
|
||||||
x,
|
x,
|
||||||
@ -391,6 +393,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
|||||||
max_lora_rank,
|
max_lora_rank,
|
||||||
top_k_num,
|
top_k_num,
|
||||||
lora_ids,
|
lora_ids,
|
||||||
|
num_active_loras,
|
||||||
adapter_enabled,
|
adapter_enabled,
|
||||||
shrink_config.get("BLOCK_SIZE_M", 64),
|
shrink_config.get("BLOCK_SIZE_M", 64),
|
||||||
shrink_config.get("BLOCK_SIZE_N", 64),
|
shrink_config.get("BLOCK_SIZE_N", 64),
|
||||||
|
|||||||
@ -57,9 +57,14 @@ class CudagraphDispatcher:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.keys_initialized = False
|
self.keys_initialized = False
|
||||||
|
self.specialize_lora_count = False
|
||||||
|
|
||||||
def _create_padded_batch_descriptor(
|
def _create_padded_batch_descriptor(
|
||||||
self, num_tokens: int, uniform_decode: bool, has_lora: bool
|
self,
|
||||||
|
num_tokens: int,
|
||||||
|
uniform_decode: bool,
|
||||||
|
has_lora: bool,
|
||||||
|
num_active_loras: int = 0,
|
||||||
) -> BatchDescriptor:
|
) -> BatchDescriptor:
|
||||||
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
|
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
|
||||||
uniform_decode_query_len = self.uniform_decode_query_len
|
uniform_decode_query_len = self.uniform_decode_query_len
|
||||||
@ -77,6 +82,7 @@ class CudagraphDispatcher:
|
|||||||
num_reqs=num_reqs,
|
num_reqs=num_reqs,
|
||||||
uniform=uniform_decode,
|
uniform=uniform_decode,
|
||||||
has_lora=has_lora,
|
has_lora=has_lora,
|
||||||
|
num_active_loras=num_active_loras,
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_cudagraph_key(
|
def add_cudagraph_key(
|
||||||
@ -87,6 +93,37 @@ class CudagraphDispatcher:
|
|||||||
)
|
)
|
||||||
self.cudagraph_keys[runtime_mode].add(batch_descriptor)
|
self.cudagraph_keys[runtime_mode].add(batch_descriptor)
|
||||||
|
|
||||||
|
def _get_lora_cases(self) -> list[tuple[bool, int]]:
|
||||||
|
"""
|
||||||
|
Returns list of (has_lora, num_active_loras) tuples for graph capture.
|
||||||
|
|
||||||
|
Returns cases for each num_active_loras from 1 to max_loras.
|
||||||
|
If cudagraph_specialize_lora is True, also includes the no-lora case.
|
||||||
|
|
||||||
|
Note: When speculative decoding is enabled, we fall back to capturing
|
||||||
|
only with max_loras to avoid conflicts with torch.compile during
|
||||||
|
CUDA graph capture.
|
||||||
|
"""
|
||||||
|
if not self.vllm_config.lora_config:
|
||||||
|
return [(False, 0)]
|
||||||
|
|
||||||
|
max_loras = self.vllm_config.lora_config.max_loras
|
||||||
|
|
||||||
|
# When speculative decoding is enabled, only capture with max_loras
|
||||||
|
# to avoid torch.compile conflicts during CUDA graph capture
|
||||||
|
if self.vllm_config.speculative_config is not None:
|
||||||
|
lora_cases = [(True, max_loras)]
|
||||||
|
if self.compilation_config.cudagraph_specialize_lora:
|
||||||
|
lora_cases.append((False, 0))
|
||||||
|
return lora_cases
|
||||||
|
|
||||||
|
# Capture for each num_active_loras from 1 to max_loras
|
||||||
|
lora_cases = [(True, n) for n in range(1, max_loras + 1)]
|
||||||
|
# Also capture the no-lora case
|
||||||
|
if self.compilation_config.cudagraph_specialize_lora:
|
||||||
|
lora_cases.append((False, 0))
|
||||||
|
return lora_cases
|
||||||
|
|
||||||
def initialize_cudagraph_keys(
|
def initialize_cudagraph_keys(
|
||||||
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
|
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
|
||||||
):
|
):
|
||||||
@ -94,26 +131,23 @@ class CudagraphDispatcher:
|
|||||||
# get the correct cudagraph mode after backend support is resolved.
|
# get the correct cudagraph mode after backend support is resolved.
|
||||||
self.cudagraph_mode = cudagraph_mode
|
self.cudagraph_mode = cudagraph_mode
|
||||||
|
|
||||||
# LoRA activation cases to specialize the cuda graphs on
|
# Track whether we have LoRA config (always specialize on count)
|
||||||
if self.vllm_config.lora_config:
|
self.has_lora_config = self.vllm_config.lora_config is not None
|
||||||
if self.compilation_config.cudagraph_specialize_lora:
|
|
||||||
lora_cases = [True, False]
|
# Get LoRA cases to capture
|
||||||
else:
|
lora_cases = self._get_lora_cases()
|
||||||
lora_cases = [True]
|
|
||||||
else:
|
|
||||||
lora_cases = [False]
|
|
||||||
|
|
||||||
# Note: we create all valid keys for cudagraph here but do not
|
# Note: we create all valid keys for cudagraph here but do not
|
||||||
# guarantee all keys would be used. For example, if we allow lazy
|
# guarantee all keys would be used. For example, if we allow lazy
|
||||||
# capturing in future PR, some keys may never be triggered.
|
# capturing in future PR, some keys may never be triggered.
|
||||||
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
|
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
|
||||||
for bs, has_lora in product(
|
for bs, (has_lora, num_active_loras) in product(
|
||||||
self.compilation_config.cudagraph_capture_sizes, lora_cases
|
self.compilation_config.cudagraph_capture_sizes, lora_cases
|
||||||
):
|
):
|
||||||
self.add_cudagraph_key(
|
self.add_cudagraph_key(
|
||||||
cudagraph_mode.mixed_mode(),
|
cudagraph_mode.mixed_mode(),
|
||||||
self._create_padded_batch_descriptor(
|
self._create_padded_batch_descriptor(
|
||||||
bs, False, has_lora
|
bs, False, has_lora, num_active_loras
|
||||||
).relax_for_mixed_batch_cudagraphs(),
|
).relax_for_mixed_batch_cudagraphs(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -132,10 +166,14 @@ class CudagraphDispatcher:
|
|||||||
for x in self.compilation_config.cudagraph_capture_sizes
|
for x in self.compilation_config.cudagraph_capture_sizes
|
||||||
if x <= max_num_tokens and x >= uniform_decode_query_len
|
if x <= max_num_tokens and x >= uniform_decode_query_len
|
||||||
]
|
]
|
||||||
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
|
for bs, (has_lora, num_active_loras) in product(
|
||||||
|
cudagraph_capture_sizes_for_decode, lora_cases
|
||||||
|
):
|
||||||
self.add_cudagraph_key(
|
self.add_cudagraph_key(
|
||||||
CUDAGraphMode.FULL,
|
CUDAGraphMode.FULL,
|
||||||
self._create_padded_batch_descriptor(bs, True, has_lora),
|
self._create_padded_batch_descriptor(
|
||||||
|
bs, True, has_lora, num_active_loras
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.keys_initialized = True
|
self.keys_initialized = True
|
||||||
@ -146,12 +184,20 @@ class CudagraphDispatcher:
|
|||||||
uniform_decode: bool,
|
uniform_decode: bool,
|
||||||
has_lora: bool,
|
has_lora: bool,
|
||||||
disable_full: bool = False,
|
disable_full: bool = False,
|
||||||
|
num_active_loras: int = 0,
|
||||||
) -> tuple[CUDAGraphMode, BatchDescriptor]:
|
) -> tuple[CUDAGraphMode, BatchDescriptor]:
|
||||||
"""
|
"""
|
||||||
Given conditions(e.g.,batch descriptor and if using cascade attention),
|
Given conditions(e.g.,batch descriptor and if using cascade attention),
|
||||||
dispatch to a cudagraph runtime mode and the valid batch descriptor.
|
dispatch to a cudagraph runtime mode and the valid batch descriptor.
|
||||||
A new batch descriptor is returned as we might dispatch a uniform batch
|
A new batch descriptor is returned as we might dispatch a uniform batch
|
||||||
to a graph that supports a more general batch (uniform to non-uniform).
|
to a graph that supports a more general batch (uniform to non-uniform).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_tokens: Number of tokens in the batch.
|
||||||
|
uniform_decode: Whether this is a uniform decode batch.
|
||||||
|
has_lora: Whether this batch has active LoRA adapters.
|
||||||
|
disable_full: Whether to disable full cudagraph mode.
|
||||||
|
num_active_loras: Number of distinct active LoRA adapters.
|
||||||
"""
|
"""
|
||||||
if (
|
if (
|
||||||
not self.keys_initialized
|
not self.keys_initialized
|
||||||
@ -160,8 +206,18 @@ class CudagraphDispatcher:
|
|||||||
):
|
):
|
||||||
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
|
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
|
||||||
|
|
||||||
|
# When speculative decoding is enabled, always use max_loras for lookup
|
||||||
|
# since we only capture graphs with max_loras
|
||||||
|
effective_num_active_loras = num_active_loras
|
||||||
|
if (
|
||||||
|
self.vllm_config.speculative_config is not None
|
||||||
|
and self.vllm_config.lora_config is not None
|
||||||
|
and has_lora
|
||||||
|
):
|
||||||
|
effective_num_active_loras = self.vllm_config.lora_config.max_loras
|
||||||
|
|
||||||
batch_desc = self._create_padded_batch_descriptor(
|
batch_desc = self._create_padded_batch_descriptor(
|
||||||
num_tokens, uniform_decode, has_lora
|
num_tokens, uniform_decode, has_lora, effective_num_active_loras
|
||||||
)
|
)
|
||||||
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
|
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
|
||||||
|
|
||||||
|
|||||||
@ -2820,6 +2820,7 @@ class GPUModelRunner(
|
|||||||
# be improved in model runner v2)
|
# be improved in model runner v2)
|
||||||
force_uniform_decode: bool | None = None,
|
force_uniform_decode: bool | None = None,
|
||||||
force_has_lora: bool | None = None,
|
force_has_lora: bool | None = None,
|
||||||
|
force_num_active_loras: int | None = None,
|
||||||
num_encoder_reqs: int = 0,
|
num_encoder_reqs: int = 0,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
CUDAGraphMode,
|
CUDAGraphMode,
|
||||||
@ -2841,11 +2842,13 @@ class GPUModelRunner(
|
|||||||
self.model_config.is_encoder_decoder and num_encoder_reqs > 0
|
self.model_config.is_encoder_decoder and num_encoder_reqs > 0
|
||||||
)
|
)
|
||||||
|
|
||||||
has_lora = (
|
# Compute LoRA state for cudagraph dispatch
|
||||||
len(self.input_batch.lora_id_to_lora_request) > 0
|
num_active_loras = (
|
||||||
if force_has_lora is None
|
force_num_active_loras
|
||||||
else force_has_lora
|
if force_num_active_loras is not None
|
||||||
|
else len(self.input_batch.lora_id_to_lora_request)
|
||||||
)
|
)
|
||||||
|
has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora
|
||||||
|
|
||||||
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
|
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
|
||||||
dispatch_cudagraph = (
|
dispatch_cudagraph = (
|
||||||
@ -2854,6 +2857,7 @@ class GPUModelRunner(
|
|||||||
has_lora=has_lora,
|
has_lora=has_lora,
|
||||||
uniform_decode=uniform_decode,
|
uniform_decode=uniform_decode,
|
||||||
disable_full=disable_full,
|
disable_full=disable_full,
|
||||||
|
num_active_loras=num_active_loras,
|
||||||
)
|
)
|
||||||
if not force_eager
|
if not force_eager
|
||||||
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
|
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
|
||||||
@ -4058,6 +4062,7 @@ class GPUModelRunner(
|
|||||||
remove_lora: bool = True,
|
remove_lora: bool = True,
|
||||||
activate_lora: bool = False,
|
activate_lora: bool = False,
|
||||||
is_graph_capturing: bool = False,
|
is_graph_capturing: bool = False,
|
||||||
|
num_active_loras: int = 0,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Run a dummy forward pass to warm up/profile run or capture the
|
Run a dummy forward pass to warm up/profile run or capture the
|
||||||
@ -4081,6 +4086,8 @@ class GPUModelRunner(
|
|||||||
(1 token) and prefill (multiple tokens) requests.
|
(1 token) and prefill (multiple tokens) requests.
|
||||||
remove_lora: If False, dummy LoRAs are not destroyed after the run
|
remove_lora: If False, dummy LoRAs are not destroyed after the run
|
||||||
activate_lora: If False, dummy_run is performed without LoRAs.
|
activate_lora: If False, dummy_run is performed without LoRAs.
|
||||||
|
num_active_loras: Number of distinct active LoRAs to capture for.
|
||||||
|
Only used when cudagraph_specialize_lora_count is True.
|
||||||
"""
|
"""
|
||||||
if supports_mm_encoder_only(self.model):
|
if supports_mm_encoder_only(self.model):
|
||||||
# The current dummy run only covers LM execution, so we can skip it.
|
# The current dummy run only covers LM execution, so we can skip it.
|
||||||
@ -4162,6 +4169,9 @@ class GPUModelRunner(
|
|||||||
# activated later in the context manager, but we need to know the
|
# activated later in the context manager, but we need to know the
|
||||||
# LoRA state when determining the batch descriptor for capture
|
# LoRA state when determining the batch descriptor for capture
|
||||||
force_has_lora=activate_lora,
|
force_has_lora=activate_lora,
|
||||||
|
# `force_num_active_loras` is used for cudagraph capture; because we
|
||||||
|
# need to capture graphs for specific num_active_loras counts
|
||||||
|
force_num_active_loras=num_active_loras if activate_lora else 0,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -4225,6 +4235,7 @@ class GPUModelRunner(
|
|||||||
num_sampled_tokens,
|
num_sampled_tokens,
|
||||||
activate_lora,
|
activate_lora,
|
||||||
remove_lora,
|
remove_lora,
|
||||||
|
num_active_loras,
|
||||||
):
|
):
|
||||||
# Make sure padding doesn't exceed max_num_tokens
|
# Make sure padding doesn't exceed max_num_tokens
|
||||||
assert num_tokens_padded <= self.max_num_tokens
|
assert num_tokens_padded <= self.max_num_tokens
|
||||||
@ -4587,6 +4598,24 @@ class GPUModelRunner(
|
|||||||
self.encoder_cache.clear()
|
self.encoder_cache.clear()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
def _get_lora_capture_cases(self) -> list[tuple[bool, int]]:
|
||||||
|
"""
|
||||||
|
Returns list of (has_lora, num_active_loras) tuples for CUDA graph capture.
|
||||||
|
|
||||||
|
Returns cases for each num_active_loras from 1 to max_loras.
|
||||||
|
If cudagraph_specialize_lora is True, also includes the no-lora case.
|
||||||
|
"""
|
||||||
|
if not self.lora_config:
|
||||||
|
return [(False, 0)]
|
||||||
|
|
||||||
|
max_loras = self.lora_config.max_loras
|
||||||
|
# Capture for each num_active_loras from 1 to max_loras
|
||||||
|
lora_cases = [(True, n) for n in range(1, max_loras + 1)]
|
||||||
|
# Also capture the no-lora case if cudagraph_specialize_lora is True
|
||||||
|
if self.compilation_config.cudagraph_specialize_lora:
|
||||||
|
lora_cases.append((False, 0))
|
||||||
|
return lora_cases
|
||||||
|
|
||||||
def capture_model(self) -> int:
|
def capture_model(self) -> int:
|
||||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -4624,13 +4653,8 @@ class GPUModelRunner(
|
|||||||
cudagraph_mode = self.compilation_config.cudagraph_mode
|
cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||||
assert cudagraph_mode is not None
|
assert cudagraph_mode is not None
|
||||||
|
|
||||||
if self.lora_config:
|
# Build LoRA cases: list of (has_lora, num_active_loras) tuples
|
||||||
if self.compilation_config.cudagraph_specialize_lora:
|
lora_cases = self._get_lora_capture_cases()
|
||||||
lora_cases = [True, False]
|
|
||||||
else:
|
|
||||||
lora_cases = [True]
|
|
||||||
else:
|
|
||||||
lora_cases = [False]
|
|
||||||
|
|
||||||
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
|
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
|
||||||
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
|
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
|
||||||
@ -4695,10 +4719,23 @@ class GPUModelRunner(
|
|||||||
|
|
||||||
def _capture_cudagraphs(
|
def _capture_cudagraphs(
|
||||||
self,
|
self,
|
||||||
compilation_cases: list[tuple[int, bool]],
|
compilation_cases: list[tuple[int, tuple[bool, int]]],
|
||||||
cudagraph_runtime_mode: CUDAGraphMode,
|
cudagraph_runtime_mode: CUDAGraphMode,
|
||||||
uniform_decode: bool,
|
uniform_decode: bool,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Capture CUDA graphs for the given compilation cases.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
compilation_cases: List of (num_tokens, (has_lora, num_active_loras))
|
||||||
|
tuples.
|
||||||
|
- num_tokens: batch size to capture
|
||||||
|
- has_lora: whether LoRA is active for this capture
|
||||||
|
- num_active_loras: number of distinct active LoRAs
|
||||||
|
(0 if not specializing)
|
||||||
|
cudagraph_runtime_mode: FULL or PIECEWISE cudagraph mode.
|
||||||
|
uniform_decode: Whether this is a uniform decode batch.
|
||||||
|
"""
|
||||||
assert (
|
assert (
|
||||||
cudagraph_runtime_mode != CUDAGraphMode.NONE
|
cudagraph_runtime_mode != CUDAGraphMode.NONE
|
||||||
and cudagraph_runtime_mode.valid_runtime_modes()
|
and cudagraph_runtime_mode.valid_runtime_modes()
|
||||||
@ -4716,7 +4753,7 @@ class GPUModelRunner(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# We skip EPLB here since we don't want to record dummy metrics
|
# We skip EPLB here since we don't want to record dummy metrics
|
||||||
for num_tokens, activate_lora in compilation_cases:
|
for num_tokens, (activate_lora, num_active_loras) in compilation_cases:
|
||||||
# We currently only capture ubatched graphs when its a FULL
|
# We currently only capture ubatched graphs when its a FULL
|
||||||
# cudagraph, a uniform decode batch, and the number of tokens
|
# cudagraph, a uniform decode batch, and the number of tokens
|
||||||
# is above the threshold. Otherwise we just capture a non-ubatched
|
# is above the threshold. Otherwise we just capture a non-ubatched
|
||||||
@ -4748,6 +4785,7 @@ class GPUModelRunner(
|
|||||||
skip_eplb=True,
|
skip_eplb=True,
|
||||||
remove_lora=False,
|
remove_lora=False,
|
||||||
activate_lora=activate_lora,
|
activate_lora=activate_lora,
|
||||||
|
num_active_loras=num_active_loras,
|
||||||
)
|
)
|
||||||
self._dummy_run(
|
self._dummy_run(
|
||||||
num_tokens,
|
num_tokens,
|
||||||
@ -4757,6 +4795,7 @@ class GPUModelRunner(
|
|||||||
skip_eplb=True,
|
skip_eplb=True,
|
||||||
remove_lora=False,
|
remove_lora=False,
|
||||||
activate_lora=activate_lora,
|
activate_lora=activate_lora,
|
||||||
|
num_active_loras=num_active_loras,
|
||||||
is_graph_capturing=True,
|
is_graph_capturing=True,
|
||||||
)
|
)
|
||||||
self.maybe_remove_all_loras(self.lora_config)
|
self.maybe_remove_all_loras(self.lora_config)
|
||||||
|
|||||||
@ -4,7 +4,10 @@
|
|||||||
Define LoRA functionality mixin for model runners.
|
Define LoRA functionality mixin for model runners.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from typing import TYPE_CHECKING, TypeAlias
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -20,7 +23,8 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
|
|||||||
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
|
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
|
||||||
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
|
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
|
||||||
|
|
||||||
InputBatch = TPUInputBatch | GPUInputBatch
|
if TYPE_CHECKING:
|
||||||
|
InputBatch: TypeAlias = TPUInputBatch | GPUInputBatch
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -129,7 +133,20 @@ class LoRAModelRunnerMixin:
|
|||||||
num_scheduled_tokens: np.ndarray,
|
num_scheduled_tokens: np.ndarray,
|
||||||
num_sampled_tokens: np.ndarray | None = None,
|
num_sampled_tokens: np.ndarray | None = None,
|
||||||
activate_lora: bool = True,
|
activate_lora: bool = True,
|
||||||
|
num_active_loras: int = 0,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Context manager to select dummy LoRAs for capture/warmup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lora_config: LoRA configuration, or None if LoRA is disabled.
|
||||||
|
num_scheduled_tokens: Array of scheduled token counts per request.
|
||||||
|
num_sampled_tokens: Array of sampled token counts per request.
|
||||||
|
activate_lora: Whether to activate LoRAs (False means no LoRA).
|
||||||
|
num_active_loras: Number of distinct active LoRAs to use.
|
||||||
|
- 0: Use all max_loras (default behavior for has_lora=True).
|
||||||
|
- >0: Use exactly this many distinct LoRAs.
|
||||||
|
"""
|
||||||
if num_sampled_tokens is None:
|
if num_sampled_tokens is None:
|
||||||
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
|
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
|
||||||
|
|
||||||
@ -140,13 +157,24 @@ class LoRAModelRunnerMixin:
|
|||||||
assert self.lora_manager is not None, "LoRA is not enabled"
|
assert self.lora_manager is not None, "LoRA is not enabled"
|
||||||
|
|
||||||
num_reqs = len(num_scheduled_tokens)
|
num_reqs = len(num_scheduled_tokens)
|
||||||
num_loras = lora_config.max_loras
|
max_loras = lora_config.max_loras
|
||||||
|
|
||||||
|
# Determine how many distinct LoRAs to use
|
||||||
|
if not activate_lora:
|
||||||
|
# No LoRA active
|
||||||
|
effective_num_loras = 0
|
||||||
|
elif num_active_loras > 0:
|
||||||
|
# Specific number of active LoRAs requested
|
||||||
|
effective_num_loras = min(num_active_loras, max_loras)
|
||||||
|
else:
|
||||||
|
# Default: use all max_loras
|
||||||
|
effective_num_loras = max_loras
|
||||||
|
|
||||||
# Make prompt lora mapping
|
# Make prompt lora mapping
|
||||||
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
|
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
|
||||||
if activate_lora:
|
if effective_num_loras > 0:
|
||||||
prompt_lora_mapping = (
|
prompt_lora_mapping = (
|
||||||
np.arange(num_reqs, dtype=np.int32) % num_loras
|
np.arange(num_reqs, dtype=np.int32) % effective_num_loras
|
||||||
) + 1
|
) + 1
|
||||||
else:
|
else:
|
||||||
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
|
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
|
||||||
@ -157,19 +185,20 @@ class LoRAModelRunnerMixin:
|
|||||||
# Make token lora mapping
|
# Make token lora mapping
|
||||||
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
|
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
|
||||||
|
|
||||||
# Make dummy lora requests
|
# Make dummy lora requests (only for the active LoRAs)
|
||||||
lora_requests: set[LoRARequest] = {
|
lora_requests: set[LoRARequest] = {
|
||||||
LoRARequest(
|
LoRARequest(
|
||||||
lora_name=f"warmup_{lora_id}",
|
lora_name=f"warmup_{lora_id}",
|
||||||
lora_int_id=lora_id,
|
lora_int_id=lora_id,
|
||||||
lora_path="/not/a/real/path",
|
lora_path="/not/a/real/path",
|
||||||
)
|
)
|
||||||
for lora_id in range(1, num_loras + 1)
|
for lora_id in range(1, effective_num_loras + 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
self._set_active_loras(
|
if lora_requests:
|
||||||
tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
|
self._set_active_loras(
|
||||||
)
|
tuple(sample_lora_mapping), tuple(token_lora_mapping), lora_requests
|
||||||
|
)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@ -181,11 +210,27 @@ class LoRAModelRunnerMixin:
|
|||||||
num_sampled_tokens: np.ndarray,
|
num_sampled_tokens: np.ndarray,
|
||||||
activate_lora: bool = True,
|
activate_lora: bool = True,
|
||||||
remove_lora: bool = True,
|
remove_lora: bool = True,
|
||||||
|
num_active_loras: int = 0,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Context manager for dummy runs with LoRA.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lora_config: LoRA configuration.
|
||||||
|
num_scheduled_tokens: Array of scheduled token counts per request.
|
||||||
|
num_sampled_tokens: Array of sampled token counts per request.
|
||||||
|
activate_lora: Whether to activate LoRAs.
|
||||||
|
remove_lora: Whether to remove LoRAs after the context exits.
|
||||||
|
num_active_loras: Number of distinct active LoRAs to use.
|
||||||
|
"""
|
||||||
with (
|
with (
|
||||||
self.maybe_setup_dummy_loras(lora_config, remove_lora),
|
self.maybe_setup_dummy_loras(lora_config, remove_lora),
|
||||||
self.maybe_select_dummy_loras(
|
self.maybe_select_dummy_loras(
|
||||||
lora_config, num_scheduled_tokens, num_sampled_tokens, activate_lora
|
lora_config,
|
||||||
|
num_scheduled_tokens,
|
||||||
|
num_sampled_tokens,
|
||||||
|
activate_lora,
|
||||||
|
num_active_loras,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
yield
|
yield
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user