[LoRA] LoRA cuda graph specialization (#25914)

Signed-off-by: Andy Lo <andy@mistral.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Authored by Andy Lo on 2025-10-20 05:21:09 +01:00, committed by GitHub
parent f32bf7582e
commit b63f2143f8
9 changed files with 122 additions and 34 deletions
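
The change keys cuda graph capture and dispatch on whether a batch has active LoRA adapters (a new has_lora field on BatchDescriptor) and exposes the behaviour as CompilationConfig.cudagraph_specialize_lora (default True). A minimal usage sketch following the pattern the updated tests use; the model path is a placeholder:

import vllm
import vllm.config

# Capture only the LoRA-enabled graphs (less startup time and memory, but
# LoRA ops always run, even for batches with no active adapters).
llm = vllm.LLM(
    "meta-llama/Llama-2-7b-hf",  # placeholder base model
    enable_lora=True,
    max_loras=4,
    max_lora_rank=64,
    compilation_config=vllm.config.CompilationConfig(
        cudagraph_specialize_lora=False,
    ),
)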

View File

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import vllm
+import vllm.config
 from vllm.lora.request import LoRARequest
 from ..utils import create_new_process_for_each_test, multi_gpu_test
@@ -53,9 +54,10 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
 def test_chatglm3_lora(chatglm3_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
-        max_model_len=1024,
+        max_model_len=512,
         enable_lora=True,
-        max_loras=4,
+        max_loras=2,
+        max_num_seqs=16,
         max_lora_rank=64,
         trust_remote_code=True,
     )
@@ -72,13 +74,17 @@ def test_chatglm3_lora(chatglm3_lora_files):
 def test_chatglm3_lora_tp4(chatglm3_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
-        max_model_len=1024,
+        max_model_len=512,
         enable_lora=True,
-        max_loras=4,
+        max_loras=2,
         max_lora_rank=64,
+        max_num_seqs=16,
         tensor_parallel_size=4,
         trust_remote_code=True,
         fully_sharded_loras=False,
+        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
+            cudagraph_specialize_lora=False,
+        ),
     )

     output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
@@ -96,14 +102,17 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
     # more GPU memory causing vLLM to OOM
     llm = vllm.LLM(
         MODEL_PATH,
-        max_model_len=1024,
+        max_model_len=512,
         enable_lora=True,
-        max_loras=4,
+        max_loras=2,
         max_lora_rank=64,
         tensor_parallel_size=4,
         trust_remote_code=True,
         fully_sharded_loras=True,
         gpu_memory_utilization=0.85,
+        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
+            cudagraph_specialize_lora=False,
+        ),
     )

     output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
     for i in range(len(EXPECTED_LORA_OUTPUT)):

View File

@@ -3,7 +3,10 @@
 import subprocess
 import sys

+import pytest
 import vllm
+import vllm.config
 from vllm import LLM
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -100,7 +103,8 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None =
 @create_new_process_for_each_test()
-def test_llama_lora(sql_lora_files):
+@pytest.mark.parametrize("cudagraph_specialize_lora", [True, False])
+def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool):
     llm = vllm.LLM(
         MODEL_PATH,
         tokenizer=sql_lora_files,
@@ -108,6 +112,9 @@ def test_llama_lora(sql_lora_files):
         # also test odd max_num_seqs
         max_num_seqs=13,
         max_loras=4,
+        compilation_config=vllm.config.CompilationConfig(
+            cudagraph_specialize_lora=cudagraph_specialize_lora,
+        ),
     )

     generate_and_test(llm, sql_lora_files)

View File

@@ -366,6 +366,14 @@ class CompilationConfig:
     minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
     FULL_AND_PIECEWISE instead.
     """
+    cudagraph_specialize_lora: bool = True
+    """Whether to create separate cuda graphs for cases with and without active
+    LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used
+    for all cases, incurring the overhead of running LoRA ops even when no
+    adapters are active. Setting this to True will remove this overhead at the
+    cost of increased startup time and slightly higher memory usage.
+    When `enable_lora` is False, this option has no effect.
+    """
     use_inductor_graph_partition: bool = False
     """Use inductor graph partition to split the graph at cudagraph_unsafe ops.

View File

@@ -40,13 +40,19 @@ class BatchDescriptor(NamedTuple):
     False can also be used for an uniform decode batch to dispatch to the
     cudagraph supporting non-uniform batches.
     """
+    has_lora: bool = False
+    """
+    Whether this batch has active LoRA adapters.
+    """

     @property
     def non_uniform(self) -> "BatchDescriptor":
         """
         Return a non-uniform version of current batch descriptor.
         """
-        return BatchDescriptor(self.num_tokens, uniform_decode=False)
+        return BatchDescriptor(
+            self.num_tokens, uniform_decode=False, has_lora=self.has_lora
+        )


 def _compute_sp_num_tokens(
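
Since BatchDescriptor is a NamedTuple, the new has_lora field participates in equality and hashing, so two batches that differ only in LoRA activation resolve to different cudagraph keys. A self-contained sketch of that behaviour (simplified field set, defaults assumed from the diff):

from typing import NamedTuple

class BatchDescriptor(NamedTuple):  # simplified sketch, not the real class
    num_tokens: int
    uniform_decode: bool = False
    has_lora: bool = False

keys = {
    BatchDescriptor(num_tokens=8, uniform_decode=True, has_lora=True),
    BatchDescriptor(num_tokens=8, uniform_decode=True, has_lora=False),
}
assert len(keys) == 2  # same shape, different LoRA state -> two distinct keys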

View File

@@ -169,6 +169,8 @@ def _lora_shrink(
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
     assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1

+    output_tensor.zero_()
+
     (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = (
         _get_lora_a_ptr(lora_a_weights, inputs.device)
     )

View File

@@ -205,15 +205,18 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)

-        if buffer is None:
-            r = lora_b_stacked[0].size(-1)
-            # We set the buffer to be float32 by default, refer to:
-            # https://github.com/triton-lang/triton/issues/1387
-            buffer = torch.zeros(  # type: ignore
-                (len(output_slices), x.size(0), r),
-                dtype=torch.float32,
-                device=x.device,
-            )
+        assert buffer is None, (
+            "To minimize overhead, the buffer should be created by "
+            ".add_lora_linear() instead of being passed in."
+        )
+        r = lora_b_stacked[0].size(-1)
+        # We set the buffer to be float32 by default, refer to:
+        # https://github.com/triton-lang/triton/issues/1387
+        # Note: buffer is zeroed inside the shrink op
+        buffer = torch.empty(
+            (len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device
+        )
         self.add_shrink(
             buffer,  # type: ignore
             x,
@@ -260,10 +263,15 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         y = y.view(-1, y.shape[-1])
         x = x.view(-1, x.shape[-1])
         r = lora_b_stacked.size(-1)
-        if buffer is None:
-            # We set the buffer to be float32 by default, refer to:
-            # https://github.com/triton-lang/triton/issues/1387
-            buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
+
+        assert buffer is None, (
+            "To minimize overhead, the buffer should be created by "
+            ".add_lora_linear() instead of being passed in."
+        )
+        # We set the buffer to be float32 by default, refer to:
+        # https://github.com/triton-lang/triton/issues/1387
+        # Note: buffer is zeroed inside the shrink op
+        buffer = torch.empty((x.size(0), r), dtype=torch.float32, device=x.device)
         lora_shrink(
             x,
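
The two files above change together: lora_shrink now zeroes its output inside the op (the output_tensor.zero_() added earlier), so the caller-side buffer can be allocated with torch.empty instead of torch.zeros and no longer needs to arrive pre-zeroed. A rough sketch of the equivalence, assuming the shrink kernel only accumulates into rows of tokens with an active adapter:

import torch

num_slices, num_tokens, rank = 1, 16, 8

# Old pattern: the wrapper allocated a zero-filled buffer up front.
buffer_old = torch.zeros((num_slices, num_tokens, rank), dtype=torch.float32)

# New pattern: allocate uninitialized memory; the shrink op zeroes it before
# accumulating, so rows for tokens without an active LoRA stay zero.
buffer_new = torch.empty((num_slices, num_tokens, rank), dtype=torch.float32)
buffer_new.zero_()  # performed inside lora_shrink in the actual code

assert torch.equal(buffer_old, buffer_new)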

View File

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from itertools import product

 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor
@@ -67,14 +68,27 @@ class CudagraphDispatcher:
     ):
         # This should be called only after attention backend is initialized.

+        # LoRA activation cases to specialize the cuda graphs on
+        if self.vllm_config.lora_config:
+            if self.compilation_config.cudagraph_specialize_lora:
+                lora_cases = [True, False]
+            else:
+                lora_cases = [True]
+        else:
+            lora_cases = [False]
+
         # Note: we create all valid keys for cudagraph here but do not
         # guarantee all keys would be used. For example, if we allow lazy
         # capturing in future PR, some keys may never be triggered.
         if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
-            for bs in self.compilation_config.cudagraph_capture_sizes:
+            for bs, has_lora in product(
+                self.compilation_config.cudagraph_capture_sizes, lora_cases
+            ):
                 self.add_cudagraph_key(
                     cudagraph_mode.mixed_mode(),
-                    BatchDescriptor(num_tokens=bs, uniform_decode=False),
+                    BatchDescriptor(
+                        num_tokens=bs, uniform_decode=False, has_lora=has_lora
+                    ),
                 )

         # if decode cudagraph mode is FULL, and we don't already have mixed
@@ -92,10 +106,12 @@ class CudagraphDispatcher:
                 for x in self.compilation_config.cudagraph_capture_sizes
                 if x <= max_num_tokens and x >= uniform_decode_query_len
             ]
-            for bs in cudagraph_capture_sizes_for_decode:
+            for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
                 self.add_cudagraph_key(
                     CUDAGraphMode.FULL,
-                    BatchDescriptor(num_tokens=bs, uniform_decode=True),
+                    BatchDescriptor(
+                        num_tokens=bs, uniform_decode=True, has_lora=has_lora
+                    ),
                 )
         self.keys_initialized = True
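
With the product over capture sizes and lora_cases, the dispatcher registers one key per (batch size, LoRA state) pair, so the key count doubles when specialization is enabled. A small enumeration sketch with made-up capture sizes:

from itertools import product

cudagraph_capture_sizes = [1, 2, 4, 8]  # example values
lora_cases = [True, False]              # LoRA enabled, cudagraph_specialize_lora=True

keys = [(bs, has_lora) for bs, has_lora in product(cudagraph_capture_sizes, lora_cases)]
assert len(keys) == 8  # 4 sizes x 2 LoRA states; decode keys are enumerated the same way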

View File

@@ -8,6 +8,7 @@ from collections import defaultdict
 from collections.abc import Iterator
 from contextlib import contextmanager
 from copy import deepcopy
+from itertools import product
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast

 import numpy as np
@@ -2469,7 +2470,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_scheduled_tokens == self.input_batch.num_reqs * max_query_len
         )
         batch_descriptor = BatchDescriptor(
-            num_tokens=num_input_tokens, uniform_decode=uniform_decode
+            num_tokens=num_input_tokens,
+            uniform_decode=uniform_decode,
+            has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
         )
         cudagraph_runtime_mode, batch_descriptor = (
             self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn)
@@ -3193,6 +3196,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         is_profile: bool = False,
         create_mixed_batch: bool = False,
         remove_lora: bool = True,
+        activate_lora: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Run a dummy forward pass to warm up/profile run or capture the
@@ -3215,6 +3219,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             create_mixed_batch: If True, create a mixed batch with both decode
                 (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
+            activate_lora: If False, dummy_run is performed without LoRAs.
         """
         assert (
             cudagraph_runtime_mode is None
@@ -3364,7 +3369,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             attn_metadata[layer_name] = attn_metadata_i

         with self.maybe_dummy_run_with_lora(
-            self.lora_config, num_scheduled_tokens, remove_lora
+            self.lora_config, num_scheduled_tokens, activate_lora, remove_lora
         ):
             # Make sure padding doesn't exceed max_num_tokens
             assert num_tokens_after_padding <= self.max_num_tokens
@@ -3411,6 +3416,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     BatchDescriptor(
                         num_tokens=num_tokens_after_padding,
                         uniform_decode=uniform_decode,
+                        has_lora=activate_lora and self.lora_config is not None,
                     )
                 )
                 if not is_profile
@@ -3769,10 +3775,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
         cudagraph_mode = self.compilation_config.cudagraph_mode
         assert cudagraph_mode is not None
+
+        if self.lora_config:
+            if self.compilation_config.cudagraph_specialize_lora:
+                lora_cases = [True, False]
+            else:
+                lora_cases = [True]
+        else:
+            lora_cases = [False]
+
         if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
             cudagraph_runtime_mode = cudagraph_mode.mixed_mode()

-            compilation_cases = list(reversed(self.cudagraph_batch_sizes))
+            compilation_cases = list(
+                product(reversed(self.cudagraph_batch_sizes), lora_cases)
+            )
             self._capture_cudagraphs(
                 compilation_cases,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
@@ -3793,7 +3810,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 for x in self.cudagraph_batch_sizes
                 if max_num_tokens >= x >= self.uniform_decode_query_len
             ]
-            compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
+            compilation_cases_decode = list(
+                product(reversed(decode_cudagraph_batch_sizes), lora_cases)
+            )
             self._capture_cudagraphs(
                 compilation_cases=compilation_cases_decode,
                 cudagraph_runtime_mode=CUDAGraphMode.FULL,
@@ -3823,7 +3842,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def _capture_cudagraphs(
         self,
-        compilation_cases: list[int],
+        compilation_cases: list[tuple[int, bool]],
         cudagraph_runtime_mode: CUDAGraphMode,
         uniform_decode: bool,
     ):
@@ -3844,7 +3863,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )
         # We skip EPLB here since we don't want to record dummy metrics
-        for num_tokens in compilation_cases:
+        for num_tokens, activate_lora in compilation_cases:
             # We currently only capture ubatched graphs when its a FULL
             # cudagraph, a uniform decode batch, and the number of tokens
             # is above the threshold. Otherwise we just capture a non-ubatched
@@ -3875,6 +3894,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     allow_microbatching=allow_microbatching,
                     skip_eplb=True,
                     remove_lora=False,
+                    activate_lora=activate_lora,
                 )
             self._dummy_run(
                 num_tokens,
@@ -3883,6 +3903,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 allow_microbatching=allow_microbatching,
                 skip_eplb=True,
                 remove_lora=False,
+                activate_lora=activate_lora,
             )
         self.maybe_remove_all_loras(self.lora_config)
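
Capture now walks (num_tokens, activate_lora) pairs, running _dummy_run once per pair with dummy LoRAs either selected or not, so every registered dispatch key gets a matching captured graph. A sketch of how the case list is ordered, with illustrative batch sizes:

from itertools import product

cudagraph_batch_sizes = [128, 256, 512]  # illustrative
lora_cases = [True, False]               # specialization enabled

# Largest size first; each size is captured with and without active LoRAs.
compilation_cases = list(product(reversed(cudagraph_batch_sizes), lora_cases))
assert compilation_cases[0] == (512, True)
assert len(compilation_cases) == 6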

View File

@@ -120,7 +120,10 @@ class LoRAModelRunnerMixin:
     @contextmanager
     def maybe_select_dummy_loras(
-        self, lora_config: LoRAConfig | None, num_scheduled_tokens: np.ndarray
+        self,
+        lora_config: LoRAConfig | None,
+        num_scheduled_tokens: np.ndarray,
+        activate_lora: bool = True,
     ):
         if lora_config is None:
             yield
@@ -133,7 +136,12 @@ class LoRAModelRunnerMixin:
             # Make prompt lora mapping
             # Assign LoRA IDs cyclically to simulate a worst-case scenario.
-            prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
+            if activate_lora:
+                prompt_lora_mapping = (
+                    np.arange(num_reqs, dtype=np.int32) % num_loras
+                ) + 1
+            else:
+                prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)

             # Make token lora mapping
             token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
@@ -159,11 +167,14 @@ class LoRAModelRunnerMixin:
         self,
         lora_config: LoRAConfig | None,
         num_scheduled_tokens: np.ndarray,
+        activate_lora: bool = True,
         remove_lora: bool = True,
     ):
         with (
             self.maybe_setup_dummy_loras(lora_config, remove_lora),
-            self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens),
+            self.maybe_select_dummy_loras(
+                lora_config, num_scheduled_tokens, activate_lora
+            ),
         ):
             yield
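
For the inactive case the dummy mapping assigns LoRA id 0 to every request, which the LoRA kernels treat as "no adapter", while the active case keeps the cyclic worst-case assignment. A tiny sketch of the two mappings for 4 requests and 2 dummy LoRAs:

import numpy as np

num_reqs, num_loras = 4, 2

# activate_lora=True: cycle through dummy adapter ids 1..num_loras (worst case).
active = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
assert active.tolist() == [1, 2, 1, 2]

# activate_lora=False: id 0 for every request, i.e. no adapter applied.
inactive = np.zeros(num_reqs, dtype=np.int32)
assert inactive.tolist() == [0, 0, 0, 0]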