[LoRA] LoRA cuda graph specialization (#25914)

Signed-off-by: Andy Lo <andy@mistral.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Andy Lo 2025-10-20 05:21:09 +01:00 committed by GitHub
parent f32bf7582e
commit b63f2143f8
9 changed files with 122 additions and 34 deletions

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import vllm
import vllm.config
from vllm.lora.request import LoRARequest
from ..utils import create_new_process_for_each_test, multi_gpu_test
@@ -53,9 +54,10 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
max_model_len=512,
enable_lora=True,
max_loras=4,
max_loras=2,
max_num_seqs=16,
max_lora_rank=64,
trust_remote_code=True,
)
@@ -72,13 +74,17 @@ def test_chatglm3_lora(chatglm3_lora_files):
def test_chatglm3_lora_tp4(chatglm3_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
max_model_len=512,
enable_lora=True,
max_loras=4,
max_loras=2,
max_lora_rank=64,
max_num_seqs=16,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=False,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)
output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
@@ -96,14 +102,17 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
# more GPU memory causing vLLM to OOM
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
max_model_len=512,
enable_lora=True,
max_loras=4,
max_loras=2,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=True,
gpu_memory_utilization=0.85,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)
output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):

View File

@@ -3,7 +3,10 @@
import subprocess
import sys
import pytest
import vllm
import vllm.config
from vllm import LLM
from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -100,7 +103,8 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None =
@create_new_process_for_each_test()
def test_llama_lora(sql_lora_files):
@pytest.mark.parametrize("cudagraph_specialize_lora", [True, False])
def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool):
llm = vllm.LLM(
MODEL_PATH,
tokenizer=sql_lora_files,
@@ -108,6 +112,9 @@ def test_llama_lora(sql_lora_files):
# also test odd max_num_seqs
max_num_seqs=13,
max_loras=4,
compilation_config=vllm.config.CompilationConfig(
cudagraph_specialize_lora=cudagraph_specialize_lora,
),
)
generate_and_test(llm, sql_lora_files)

View File

@@ -366,6 +366,14 @@ class CompilationConfig:
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
FULL_AND_PIECEWISE instead.
"""
cudagraph_specialize_lora: bool = True
"""Whether to create separate cuda graphs for cases with and without active
LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used
for all cases, incurring the overhead of running LoRA ops even when no
adapters are active. Setting this to True will remove this overhead at the
cost of increased startup time and slightly higher memory usage.
When `enable_lora` is False, this option has no effect.
"""
use_inductor_graph_partition: bool = False
"""Use inductor graph partition to split the graph at cudagraph_unsafe ops.

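As a usage reference for the cudagraph_specialize_lora option documented above, here is a minimal sketch of passing it through CompilationConfig, mirroring the test changes earlier in this commit; the model path and LoRA settings are illustrative placeholders, not part of the diff.

import vllm
import vllm.config

# Disable LoRA cudagraph specialization: a single LoRA-enabled graph per size
# is captured, trading a small steady-state overhead for less capture time
# and memory.
llm = vllm.LLM(
    "path/to/model",  # placeholder
    enable_lora=True,
    max_loras=2,
    max_lora_rank=64,
    compilation_config=vllm.config.CompilationConfig(
        cudagraph_specialize_lora=False,  # defaults to True
    ),
)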
View File

@@ -40,13 +40,19 @@ class BatchDescriptor(NamedTuple):
False can also be used for a uniform decode batch to dispatch to the
cudagraph supporting non-uniform batches.
"""
has_lora: bool = False
"""
Whether this batch has active LoRA adapters.
"""
@property
def non_uniform(self) -> "BatchDescriptor":
"""
Return a non-uniform version of current batch descriptor.
"""
return BatchDescriptor(self.num_tokens, uniform_decode=False)
return BatchDescriptor(
self.num_tokens, uniform_decode=False, has_lora=self.has_lora
)
def _compute_sp_num_tokens(

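A minimal standalone sketch of the extended BatchDescriptor, re-declared here for illustration under the shape shown in the diff: has_lora is part of the NamedTuple key, so batches that differ only in LoRA activation dispatch to different cuda graphs, and non_uniform preserves the flag.

from typing import NamedTuple


class BatchDescriptor(NamedTuple):
    num_tokens: int
    uniform_decode: bool = False
    has_lora: bool = False

    @property
    def non_uniform(self) -> "BatchDescriptor":
        # Relax uniform_decode but keep the LoRA specialization.
        return BatchDescriptor(
            self.num_tokens, uniform_decode=False, has_lora=self.has_lora
        )


with_lora = BatchDescriptor(num_tokens=64, uniform_decode=True, has_lora=True)
without_lora = BatchDescriptor(num_tokens=64, uniform_decode=True, has_lora=False)
assert with_lora != without_lora       # distinct cudagraph keys
assert with_lora.non_uniform.has_lora  # flag survives the relaxation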
View File

@@ -169,6 +169,8 @@ def _lora_shrink(
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
output_tensor.zero_()
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = (
_get_lora_a_ptr(lora_a_weights, inputs.device)
)

View File

@@ -205,15 +205,18 @@ class PunicaWrapperGPU(PunicaWrapperBase):
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if buffer is None:
assert buffer is None, (
"To minimize overhead, the buffer should be created by "
".add_lora_linear() instead of being passed in."
)
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default, refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros( # type: ignore
(len(output_slices), x.size(0), r),
dtype=torch.float32,
device=x.device,
# Note: buffer is zeroed inside the shrink op
buffer = torch.empty(
(len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device
)
self.add_shrink(
buffer, # type: ignore
x,
@@ -260,10 +263,15 @@ class PunicaWrapperGPU(PunicaWrapperBase):
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if buffer is None:
assert buffer is None, (
"To minimize overhead, the buffer should be created by "
".add_lora_linear() instead of being passed in."
)
# We set the buffer to be float32 by default, refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
# Note: buffer is zeroed inside the shrink op
buffer = torch.empty((x.size(0), r), dtype=torch.float32, device=x.device)
lora_shrink(
x,

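The switch from torch.zeros to torch.empty above relies on the shrink op zeroing its own output (the output_tensor.zero_() call in the _lora_shrink hunk), so the wrapper no longer pays for a redundant fill. A toy sketch of that contract, with toy_shrink standing in for the real Triton kernel:

import torch


def toy_shrink(x: torch.Tensor, lora_a: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in for lora_shrink: the op zeroes its output before accumulating,
    # so callers may hand it uninitialized memory.
    out.zero_()
    out += x @ lora_a


x = torch.randn(8, 16)
lora_a = torch.randn(16, 4)
buffer = torch.empty((8, 4), dtype=torch.float32)  # no torch.zeros needed
toy_shrink(x, lora_a, buffer)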
View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from itertools import product
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
@@ -67,14 +68,27 @@ class CudagraphDispatcher:
):
# This should be called only after attention backend is initialized.
# LoRA activation cases to specialize the cuda graphs on
if self.vllm_config.lora_config:
if self.compilation_config.cudagraph_specialize_lora:
lora_cases = [True, False]
else:
lora_cases = [True]
else:
lora_cases = [False]
# Note: we create all valid cudagraph keys here, but we do not guarantee
# that all of them will be used. For example, if lazy capturing is allowed
# in a future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
for bs in self.compilation_config.cudagraph_capture_sizes:
for bs, has_lora in product(
self.compilation_config.cudagraph_capture_sizes, lora_cases
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
BatchDescriptor(num_tokens=bs, uniform_decode=False),
BatchDescriptor(
num_tokens=bs, uniform_decode=False, has_lora=has_lora
),
)
# if decode cudagraph mode is FULL, and we don't already have mixed
@@ -92,10 +106,12 @@ class CudagraphDispatcher:
for x in self.compilation_config.cudagraph_capture_sizes
if x <= max_num_tokens and x >= uniform_decode_query_len
]
for bs in cudagraph_capture_sizes_for_decode:
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
self.add_cudagraph_key(
CUDAGraphMode.FULL,
BatchDescriptor(num_tokens=bs, uniform_decode=True),
BatchDescriptor(
num_tokens=bs, uniform_decode=True, has_lora=has_lora
),
)
self.keys_initialized = True

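A standalone sketch of the key enumeration added above, with illustrative capture sizes: product pairs every capture size with each LoRA case, so the set of cudagraph keys doubles only when LoRA is configured and specialization is enabled.

from itertools import product

cudagraph_capture_sizes = [1, 2, 4, 8]  # illustrative values
lora_enabled = True
cudagraph_specialize_lora = True

if lora_enabled:
    lora_cases = [True, False] if cudagraph_specialize_lora else [True]
else:
    lora_cases = [False]

keys = list(product(cudagraph_capture_sizes, lora_cases))
# Specialized: 8 (bs, has_lora) keys; unspecialized: 4 keys, all has_lora=True.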
View File

@@ -8,6 +8,7 @@ from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from itertools import product
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
import numpy as np
@@ -2469,7 +2470,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_scheduled_tokens == self.input_batch.num_reqs * max_query_len
)
batch_descriptor = BatchDescriptor(
num_tokens=num_input_tokens, uniform_decode=uniform_decode
num_tokens=num_input_tokens,
uniform_decode=uniform_decode,
has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
)
cudagraph_runtime_mode, batch_descriptor = (
self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn)
@@ -3193,6 +3196,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
is_profile: bool = False,
create_mixed_batch: bool = False,
remove_lora: bool = True,
activate_lora: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up/profile run or capture the
@@ -3215,6 +3219,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
create_mixed_batch: If True, create a mixed batch with both decode
(1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run.
activate_lora: If False, the dummy run is performed without active LoRAs.
"""
assert (
cudagraph_runtime_mode is None
@@ -3364,7 +3369,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata[layer_name] = attn_metadata_i
with self.maybe_dummy_run_with_lora(
self.lora_config, num_scheduled_tokens, remove_lora
self.lora_config, num_scheduled_tokens, activate_lora, remove_lora
):
# Make sure padding doesn't exceed max_num_tokens
assert num_tokens_after_padding <= self.max_num_tokens
@@ -3411,6 +3416,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
BatchDescriptor(
num_tokens=num_tokens_after_padding,
uniform_decode=uniform_decode,
has_lora=activate_lora and self.lora_config is not None,
)
)
if not is_profile
@@ -3769,10 +3775,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
if self.lora_config:
if self.compilation_config.cudagraph_specialize_lora:
lora_cases = [True, False]
else:
lora_cases = [True]
else:
lora_cases = [False]
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
compilation_cases = list(reversed(self.cudagraph_batch_sizes))
compilation_cases = list(
product(reversed(self.cudagraph_batch_sizes), lora_cases)
)
self._capture_cudagraphs(
compilation_cases,
cudagraph_runtime_mode=cudagraph_runtime_mode,
@@ -3793,7 +3810,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for x in self.cudagraph_batch_sizes
if max_num_tokens >= x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
compilation_cases_decode = list(
product(reversed(decode_cudagraph_batch_sizes), lora_cases)
)
self._capture_cudagraphs(
compilation_cases=compilation_cases_decode,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
@@ -3823,7 +3842,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _capture_cudagraphs(
self,
compilation_cases: list[int],
compilation_cases: list[tuple[int, bool]],
cudagraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool,
):
@@ -3844,7 +3863,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens in compilation_cases:
for num_tokens, activate_lora in compilation_cases:
# We currently only capture ubatched graphs when it's a FULL
# cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched
@@ -3875,6 +3894,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
allow_microbatching=allow_microbatching,
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
)
self._dummy_run(
num_tokens,
@@ -3883,6 +3903,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
allow_microbatching=allow_microbatching,
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
)
self.maybe_remove_all_loras(self.lora_config)

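At runtime the captured graphs are looked up by the full descriptor, with has_lora derived from whether any LoRA requests are active in the batch (len(self.input_batch.lora_id_to_lora_request) > 0 above). A simplified, self-contained dispatch sketch, not the actual vLLM dispatcher:

from typing import Callable

captured_graphs: dict[tuple[int, bool, bool], Callable] = {}


def dispatch(num_tokens: int, uniform_decode: bool, num_active_loras: int):
    # Key mirrors (num_tokens, uniform_decode, has_lora); a miss means the
    # batch runs without a captured cudagraph.
    key = (num_tokens, uniform_decode, num_active_loras > 0)
    return captured_graphs.get(key)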
View File

@@ -120,7 +120,10 @@ class LoRAModelRunnerMixin:
@contextmanager
def maybe_select_dummy_loras(
self, lora_config: LoRAConfig | None, num_scheduled_tokens: np.ndarray
self,
lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray,
activate_lora: bool = True,
):
if lora_config is None:
yield
@@ -133,7 +136,12 @@ class LoRAModelRunnerMixin:
# Make prompt lora mapping
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
if activate_lora:
prompt_lora_mapping = (
np.arange(num_reqs, dtype=np.int32) % num_loras
) + 1
else:
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
# Make token lora mapping
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
@@ -159,11 +167,14 @@
self,
lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray,
activate_lora: bool = True,
remove_lora: bool = True,
):
with (
self.maybe_setup_dummy_loras(lora_config, remove_lora),
self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens),
self.maybe_select_dummy_loras(
lora_config, num_scheduled_tokens, activate_lora
),
):
yield
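For the LoRA-specialized captures to exercise a realistic worst case, the dummy mapping above assigns adapter IDs cyclically when activate_lora is True, and maps every request to ID 0 (no adapter) otherwise. A small numeric sketch with assumed request and token counts:

import numpy as np

num_reqs, num_loras = 5, 2
num_scheduled_tokens = np.array([3, 1, 2, 1, 4])  # assumed per-request tokens

activate_lora = True
if activate_lora:
    # Cycle through adapter IDs 1..num_loras to simulate a worst case.
    prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
else:
    # ID 0 means "no adapter" for every request.
    prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)

token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
print(prompt_lora_mapping)  # [1 2 1 2 1]
print(token_lora_mapping)   # [1 1 1 2 1 1 2 1 1 1 1]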