mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 03:51:21 +08:00
Construting grid with num of active lora
Signed-off-by: Yu Gong <yu3.gong@gmail.com>
This commit is contained in:
parent
e475bdae39
commit
e95e2145c5
@ -240,7 +240,7 @@ def _lora_expand(
|
|||||||
# Each LoRA receives its own set of thread blocks for output
|
# Each LoRA receives its own set of thread blocks for output
|
||||||
# computation. If some LoRA doesn't have any tokens to process, its
|
# computation. If some LoRA doesn't have any tokens to process, its
|
||||||
# thread blocks simply exit.
|
# thread blocks simply exit.
|
||||||
MAX_LORAS,
|
num_active_loras,
|
||||||
)
|
)
|
||||||
use_gdc = supports_pdl(inputs.device)
|
use_gdc = supports_pdl(inputs.device)
|
||||||
_lora_expand_kernel[grid](
|
_lora_expand_kernel[grid](
|
||||||
|
|||||||
@ -220,7 +220,7 @@ def _lora_shrink(
|
|||||||
# Each LoRA receives its own set of thread blocks for output
|
# Each LoRA receives its own set of thread blocks for output
|
||||||
# computation. If some LoRA doesn't have any tokens to process, its
|
# computation. If some LoRA doesn't have any tokens to process, its
|
||||||
# thread blocks exit early.
|
# thread blocks exit early.
|
||||||
MAX_LORAS,
|
num_active_loras,
|
||||||
)
|
)
|
||||||
use_gdc = supports_pdl(inputs.device)
|
use_gdc = supports_pdl(inputs.device)
|
||||||
_lora_shrink_kernel[grid](
|
_lora_shrink_kernel[grid](
|
||||||
|
|||||||
@ -58,7 +58,25 @@ class CudagraphDispatcher:
|
|||||||
|
|
||||||
self.keys_initialized = False
|
self.keys_initialized = False
|
||||||
self.specialize_lora_count = False
|
self.specialize_lora_count = False
|
||||||
self.specialize_lora_count = False
|
|
||||||
|
def _get_lora_cases(self) -> list[tuple[bool, int]]:
|
||||||
|
"""
|
||||||
|
Returns list of (has_lora, num_active_loras) tuples for CUDA graph
|
||||||
|
capture. This is the single source of truth for LoRA capture cases.
|
||||||
|
"""
|
||||||
|
lora_config = self.vllm_config.lora_config
|
||||||
|
if lora_config is None:
|
||||||
|
# No LoRA configured - single case with no LoRA
|
||||||
|
return [(False, 0)]
|
||||||
|
|
||||||
|
# LoRA is enabled - capture graphs for different active LoRA counts
|
||||||
|
# Always include the no-LoRA case (for requests without adapters)
|
||||||
|
cases: list[tuple[bool, int]] = [(False, 0)]
|
||||||
|
|
||||||
|
for n in range(1, lora_config.max_loras + 1):
|
||||||
|
cases.append((True, n))
|
||||||
|
|
||||||
|
return cases
|
||||||
|
|
||||||
def _create_padded_batch_descriptor(
|
def _create_padded_batch_descriptor(
|
||||||
self,
|
self,
|
||||||
@ -66,11 +84,6 @@ class CudagraphDispatcher:
|
|||||||
uniform_decode: bool,
|
uniform_decode: bool,
|
||||||
has_lora: bool,
|
has_lora: bool,
|
||||||
num_active_loras: int = 0,
|
num_active_loras: int = 0,
|
||||||
self,
|
|
||||||
num_tokens: int,
|
|
||||||
uniform_decode: bool,
|
|
||||||
has_lora: bool,
|
|
||||||
num_active_loras: int = 0,
|
|
||||||
) -> BatchDescriptor:
|
) -> BatchDescriptor:
|
||||||
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
|
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
|
||||||
uniform_decode_query_len = self.uniform_decode_query_len
|
uniform_decode_query_len = self.uniform_decode_query_len
|
||||||
@ -89,7 +102,6 @@ class CudagraphDispatcher:
|
|||||||
uniform=uniform_decode,
|
uniform=uniform_decode,
|
||||||
has_lora=has_lora,
|
has_lora=has_lora,
|
||||||
num_active_loras=num_active_loras,
|
num_active_loras=num_active_loras,
|
||||||
num_active_loras=num_active_loras,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_cudagraph_key(
|
def add_cudagraph_key(
|
||||||
|
|||||||
@ -4598,24 +4598,6 @@ class GPUModelRunner(
|
|||||||
self.encoder_cache.clear()
|
self.encoder_cache.clear()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
def _get_lora_capture_cases(self) -> list[tuple[bool, int]]:
|
|
||||||
"""
|
|
||||||
Returns list of (has_lora, num_active_loras) tuples for CUDA graph capture.
|
|
||||||
|
|
||||||
Returns cases for each num_active_loras from 1 to max_loras.
|
|
||||||
If cudagraph_specialize_lora is True, also includes the no-lora case.
|
|
||||||
"""
|
|
||||||
if not self.lora_config:
|
|
||||||
return [(False, 0)]
|
|
||||||
|
|
||||||
max_loras = self.lora_config.max_loras
|
|
||||||
# Capture for each num_active_loras from 1 to max_loras
|
|
||||||
lora_cases = [(True, n) for n in range(1, max_loras + 1)]
|
|
||||||
# Also capture the no-lora case if cudagraph_specialize_lora is True
|
|
||||||
if self.compilation_config.cudagraph_specialize_lora:
|
|
||||||
lora_cases.append((False, 0))
|
|
||||||
return lora_cases
|
|
||||||
|
|
||||||
def capture_model(self) -> int:
|
def capture_model(self) -> int:
|
||||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user