mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-31 17:07:03 +08:00
Capture multiple CUDA graphs across varying numbers of active LoRAs
Signed-off-by: Yu Gong <yu3.gong@gmail.com>
This commit is contained in:
parent
b28be8ea57
commit
e475bdae39
@ -58,6 +58,7 @@ class CudagraphDispatcher:
|
||||
|
||||
self.keys_initialized = False
|
||||
self.specialize_lora_count = False
|
||||
self.specialize_lora_count = False
|
||||
|
||||
def _create_padded_batch_descriptor(
|
||||
self,
|
||||
@ -65,6 +66,11 @@ class CudagraphDispatcher:
|
||||
uniform_decode: bool,
|
||||
has_lora: bool,
|
||||
num_active_loras: int = 0,
|
||||
self,
|
||||
num_tokens: int,
|
||||
uniform_decode: bool,
|
||||
has_lora: bool,
|
||||
num_active_loras: int = 0,
|
||||
) -> BatchDescriptor:
|
||||
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
|
||||
uniform_decode_query_len = self.uniform_decode_query_len
|
||||
@ -83,6 +89,7 @@ class CudagraphDispatcher:
|
||||
uniform=uniform_decode,
|
||||
has_lora=has_lora,
|
||||
num_active_loras=num_active_loras,
|
||||
num_active_loras=num_active_loras,
|
||||
)
|
||||
|
||||
def add_cudagraph_key(
|
||||
@ -93,29 +100,6 @@ class CudagraphDispatcher:
|
||||
)
|
||||
self.cudagraph_keys[runtime_mode].add(batch_descriptor)
|
||||
|
||||
def _get_lora_cases(self) -> list[tuple[bool, int]]:
|
||||
"""
|
||||
Returns list of (has_lora, num_active_loras) tuples for graph capture.
|
||||
|
||||
Returns cases for each num_active_loras from 1 to max_loras.
|
||||
If cudagraph_specialize_lora is True, also includes the no-lora case.
|
||||
|
||||
Note: When speculative decoding is enabled, we fall back to capturing
|
||||
only with max_loras to avoid conflicts with torch.compile during
|
||||
CUDA graph capture.
|
||||
"""
|
||||
if not self.vllm_config.lora_config:
|
||||
return [(False, 0)]
|
||||
|
||||
max_loras = self.vllm_config.lora_config.max_loras
|
||||
|
||||
# Capture for each num_active_loras from 1 to max_loras
|
||||
lora_cases = [(True, n) for n in range(1, max_loras + 1)]
|
||||
# Also capture the no-lora case
|
||||
if self.compilation_config.cudagraph_specialize_lora:
|
||||
lora_cases.append((False, 0))
|
||||
return lora_cases
|
||||
|
||||
def initialize_cudagraph_keys(
|
||||
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
|
||||
):
|
||||
|
||||
@ -4598,6 +4598,24 @@ class GPUModelRunner(
|
||||
self.encoder_cache.clear()
|
||||
gc.collect()
|
||||
|
||||
def _get_lora_capture_cases(self) -> list[tuple[bool, int]]:
|
||||
"""
|
||||
Returns list of (has_lora, num_active_loras) tuples for CUDA graph capture.
|
||||
|
||||
Returns cases for each num_active_loras from 1 to max_loras.
|
||||
If cudagraph_specialize_lora is True, also includes the no-lora case.
|
||||
"""
|
||||
if not self.lora_config:
|
||||
return [(False, 0)]
|
||||
|
||||
max_loras = self.lora_config.max_loras
|
||||
# Capture for each num_active_loras from 1 to max_loras
|
||||
lora_cases = [(True, n) for n in range(1, max_loras + 1)]
|
||||
# Also capture the no-lora case if cudagraph_specialize_lora is True
|
||||
if self.compilation_config.cudagraph_specialize_lora:
|
||||
lora_cases.append((False, 0))
|
||||
return lora_cases
|
||||
|
||||
def capture_model(self) -> int:
|
||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
logger.warning(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user