[LoRA] LoRA cuda graph specialization (#25914)
Signed-off-by: Andy Lo <andy@mistral.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
commit b63f2143f8
parent f32bf7582e
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import vllm
+import vllm.config
 from vllm.lora.request import LoRARequest

 from ..utils import create_new_process_for_each_test, multi_gpu_test
@@ -53,9 +54,10 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
 def test_chatglm3_lora(chatglm3_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
-        max_model_len=1024,
+        max_model_len=512,
         enable_lora=True,
-        max_loras=4,
+        max_loras=2,
+        max_num_seqs=16,
         max_lora_rank=64,
         trust_remote_code=True,
     )
@@ -72,13 +74,17 @@ def test_chatglm3_lora(chatglm3_lora_files):
 def test_chatglm3_lora_tp4(chatglm3_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
-        max_model_len=1024,
+        max_model_len=512,
         enable_lora=True,
-        max_loras=4,
+        max_loras=2,
         max_lora_rank=64,
+        max_num_seqs=16,
         tensor_parallel_size=4,
         trust_remote_code=True,
         fully_sharded_loras=False,
+        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
+            cudagraph_specialize_lora=False,
+        ),
     )

     output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
@@ -96,14 +102,17 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
     # more GPU memory causing vLLM to OOM
     llm = vllm.LLM(
         MODEL_PATH,
-        max_model_len=1024,
+        max_model_len=512,
         enable_lora=True,
-        max_loras=4,
+        max_loras=2,
         max_lora_rank=64,
         tensor_parallel_size=4,
         trust_remote_code=True,
         fully_sharded_loras=True,
         gpu_memory_utilization=0.85,
+        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
+            cudagraph_specialize_lora=False,
+        ),
     )
     output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
     for i in range(len(EXPECTED_LORA_OUTPUT)):
@@ -3,7 +3,10 @@
 import subprocess
 import sys

+import pytest
+
 import vllm
+import vllm.config
 from vllm import LLM
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -100,7 +103,8 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None =


 @create_new_process_for_each_test()
-def test_llama_lora(sql_lora_files):
+@pytest.mark.parametrize("cudagraph_specialize_lora", [True, False])
+def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool):
     llm = vllm.LLM(
         MODEL_PATH,
         tokenizer=sql_lora_files,
@@ -108,6 +112,9 @@ def test_llama_lora(sql_lora_files):
         # also test odd max_num_seqs
         max_num_seqs=13,
         max_loras=4,
+        compilation_config=vllm.config.CompilationConfig(
+            cudagraph_specialize_lora=cudagraph_specialize_lora,
+        ),
     )
     generate_and_test(llm, sql_lora_files)

@@ -366,6 +366,14 @@ class CompilationConfig:
     minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
     FULL_AND_PIECEWISE instead.
     """
+    cudagraph_specialize_lora: bool = True
+    """Whether to create separate cuda graphs for cases with and without active
+    LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used
+    for all cases, incurring the overhead of running LoRA ops even when no
+    adapters are active. Setting this to True will remove this overhead at the
+    cost of increased startup time and slightly higher memory usage.
+    When `enable_lora` is False, this option has no effect.
+    """

     use_inductor_graph_partition: bool = False
     """Use inductor graph partition to split the graph at cudagraph_unsafe ops.
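Usage sketch (not part of the diff): the tests in this commit pass the new flag through CompilationConfig when constructing an LLM with LoRA enabled. The model name and LoRA settings below are placeholders for illustration only.

import vllm
import vllm.config

# Hypothetical model; any LoRA-enabled setup works the same way.
llm = vllm.LLM(
    "meta-llama/Llama-2-7b-hf",
    enable_lora=True,
    max_loras=2,
    max_lora_rank=64,
    compilation_config=vllm.config.CompilationConfig(
        # Default is True: capture separate cuda graphs for batches with and
        # without active adapters. Set False to trade some steady-state speed
        # for faster startup and lower memory during graph capture.
        cudagraph_specialize_lora=False,
    ),
)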
@@ -40,13 +40,19 @@ class BatchDescriptor(NamedTuple):
     False can also be used for an uniform decode batch to dispatch to the
     cudagraph supporting non-uniform batches.
     """
+    has_lora: bool = False
+    """
+    Whether this batch has active LoRA adapters.
+    """

     @property
     def non_uniform(self) -> "BatchDescriptor":
         """
         Return a non-uniform version of current batch descriptor.
         """
-        return BatchDescriptor(self.num_tokens, uniform_decode=False)
+        return BatchDescriptor(
+            self.num_tokens, uniform_decode=False, has_lora=self.has_lora
+        )


 def _compute_sp_num_tokens(
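A small sketch of why the new field matters (assuming the import path and fields shown above): BatchDescriptor is a NamedTuple used as a cudagraph key, so two otherwise identical batches with and without active adapters select different captured graphs, and non_uniform preserves the LoRA flag while relaxing uniform_decode.

from vllm.forward_context import BatchDescriptor

with_lora = BatchDescriptor(num_tokens=8, uniform_decode=True, has_lora=True)
without_lora = BatchDescriptor(num_tokens=8, uniform_decode=True, has_lora=False)

# NamedTuples hash by value, so these are distinct dispatch keys.
assert with_lora != without_lora

# Relaxing uniform_decode keeps the LoRA specialization intact.
assert with_lora.non_uniform == BatchDescriptor(
    num_tokens=8, uniform_decode=False, has_lora=True
)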
@@ -169,6 +169,8 @@ def _lora_shrink(
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
     assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1

+    output_tensor.zero_()
+
     (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = (
         _get_lora_a_ptr(lora_a_weights, inputs.device)
     )
@@ -205,15 +205,18 @@ class PunicaWrapperGPU(PunicaWrapperBase):

         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)

-        if buffer is None:
+        assert buffer is None, (
+            "To minimize overhead, the buffer should be created by "
+            ".add_lora_linear() instead of being passed in."
+        )
         r = lora_b_stacked[0].size(-1)
         # We set the buffer to be float32 by default, refer to:
         # https://github.com/triton-lang/triton/issues/1387
-        buffer = torch.zeros(  # type: ignore
-            (len(output_slices), x.size(0), r),
-            dtype=torch.float32,
-            device=x.device,
+        # Note: buffer is zeroed inside the shrink op
+        buffer = torch.empty(
+            (len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device
         )

         self.add_shrink(
             buffer,  # type: ignore
             x,
@@ -260,10 +263,15 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         y = y.view(-1, y.shape[-1])
         x = x.view(-1, x.shape[-1])
         r = lora_b_stacked.size(-1)
-        if buffer is None:
+
+        assert buffer is None, (
+            "To minimize overhead, the buffer should be created by "
+            ".add_lora_linear() instead of being passed in."
+        )
         # We set the buffer to be float32 by default, refer to:
         # https://github.com/triton-lang/triton/issues/1387
-        buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
+        # Note: buffer is zeroed inside the shrink op
+        buffer = torch.empty((x.size(0), r), dtype=torch.float32, device=x.device)

         lora_shrink(
             x,
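The buffer changes above follow one pattern: the shrink op now calls output_tensor.zero_() itself, so callers may allocate the intermediate buffer with torch.empty instead of torch.zeros. A generic sketch of that contract, with hypothetical names (not the vLLM Triton kernels):

import torch


def shrink_like(output: torch.Tensor, x: torch.Tensor, a: torch.Tensor) -> None:
    # The op owns the zero-initialization, so callers may pass
    # uninitialized (torch.empty) storage.
    output.zero_()
    output += x.float() @ a.float()


x = torch.randn(4, 16)
a = torch.randn(16, 8)
buffer = torch.empty((4, 8), dtype=torch.float32)  # no redundant zero-fill here
shrink_like(buffer, x, a)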
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from itertools import product

 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor
@@ -67,14 +68,27 @@ class CudagraphDispatcher:
     ):
         # This should be called only after attention backend is initialized.

+        # LoRA activation cases to specialize the cuda graphs on
+        if self.vllm_config.lora_config:
+            if self.compilation_config.cudagraph_specialize_lora:
+                lora_cases = [True, False]
+            else:
+                lora_cases = [True]
+        else:
+            lora_cases = [False]
+
         # Note: we create all valid keys for cudagraph here but do not
         # guarantee all keys would be used. For example, if we allow lazy
         # capturing in future PR, some keys may never be triggered.
         if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
-            for bs in self.compilation_config.cudagraph_capture_sizes:
+            for bs, has_lora in product(
+                self.compilation_config.cudagraph_capture_sizes, lora_cases
+            ):
                 self.add_cudagraph_key(
                     cudagraph_mode.mixed_mode(),
-                    BatchDescriptor(num_tokens=bs, uniform_decode=False),
+                    BatchDescriptor(
+                        num_tokens=bs, uniform_decode=False, has_lora=has_lora
+                    ),
                 )

         # if decode cudagraph mode is FULL, and we don't already have mixed
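A rough illustration of the capture cost (a sketch re-deriving the branch above, not code from the diff): with N capture sizes, enabling cudagraph_specialize_lora doubles the number of mixed-mode keys whenever a LoRA config is present. The capture sizes below are hypothetical.

from itertools import product


def lora_cases(has_lora_config: bool, specialize: bool) -> list[bool]:
    # Mirrors the branch in initialize_cudagraph_keys above.
    if has_lora_config:
        return [True, False] if specialize else [True]
    return [False]


capture_sizes = [1, 2, 4, 8]  # hypothetical cudagraph_capture_sizes
keys = list(product(capture_sizes, lora_cases(has_lora_config=True, specialize=True)))
assert len(keys) == 8  # twice as many graphs to capture as without specialization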
@@ -92,10 +106,12 @@ class CudagraphDispatcher:
                 for x in self.compilation_config.cudagraph_capture_sizes
                 if x <= max_num_tokens and x >= uniform_decode_query_len
             ]
-            for bs in cudagraph_capture_sizes_for_decode:
+            for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
                 self.add_cudagraph_key(
                     CUDAGraphMode.FULL,
-                    BatchDescriptor(num_tokens=bs, uniform_decode=True),
+                    BatchDescriptor(
+                        num_tokens=bs, uniform_decode=True, has_lora=has_lora
+                    ),
                 )
         self.keys_initialized = True

@@ -8,6 +8,7 @@ from collections import defaultdict
 from collections.abc import Iterator
 from contextlib import contextmanager
 from copy import deepcopy
+from itertools import product
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast

 import numpy as np
@@ -2469,7 +2470,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_scheduled_tokens == self.input_batch.num_reqs * max_query_len
         )
         batch_descriptor = BatchDescriptor(
-            num_tokens=num_input_tokens, uniform_decode=uniform_decode
+            num_tokens=num_input_tokens,
+            uniform_decode=uniform_decode,
+            has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
         )
         cudagraph_runtime_mode, batch_descriptor = (
             self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn)
@@ -3193,6 +3196,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         is_profile: bool = False,
         create_mixed_batch: bool = False,
         remove_lora: bool = True,
+        activate_lora: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Run a dummy forward pass to warm up/profile run or capture the
@@ -3215,6 +3219,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             create_mixed_batch: If True, create a mixed batch with both decode
                 (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
+            activate_lora: If False, dummy_run is performed without LoRAs.
         """
         assert (
             cudagraph_runtime_mode is None
@@ -3364,7 +3369,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             attn_metadata[layer_name] = attn_metadata_i

         with self.maybe_dummy_run_with_lora(
-            self.lora_config, num_scheduled_tokens, remove_lora
+            self.lora_config, num_scheduled_tokens, activate_lora, remove_lora
         ):
             # Make sure padding doesn't exceed max_num_tokens
             assert num_tokens_after_padding <= self.max_num_tokens
@@ -3411,6 +3416,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 BatchDescriptor(
                     num_tokens=num_tokens_after_padding,
                     uniform_decode=uniform_decode,
+                    has_lora=activate_lora and self.lora_config is not None,
                 )
             )
             if not is_profile
@@ -3769,10 +3775,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
         cudagraph_mode = self.compilation_config.cudagraph_mode
         assert cudagraph_mode is not None

+        if self.lora_config:
+            if self.compilation_config.cudagraph_specialize_lora:
+                lora_cases = [True, False]
+            else:
+                lora_cases = [True]
+        else:
+            lora_cases = [False]
+
         if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
             cudagraph_runtime_mode = cudagraph_mode.mixed_mode()

-            compilation_cases = list(reversed(self.cudagraph_batch_sizes))
+            compilation_cases = list(
+                product(reversed(self.cudagraph_batch_sizes), lora_cases)
+            )
             self._capture_cudagraphs(
                 compilation_cases,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
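A quick sketch of what the capture loop now iterates over (hypothetical batch sizes): each case is a (num_tokens, activate_lora) pair, so every padded size is captured once per LoRA case.

from itertools import product

cudagraph_batch_sizes = [1, 2, 4]  # hypothetical
lora_cases = [True, False]         # LoRA enabled with cudagraph_specialize_lora=True

compilation_cases = list(product(reversed(cudagraph_batch_sizes), lora_cases))
# [(4, True), (4, False), (2, True), (2, False), (1, True), (1, False)]
print(compilation_cases)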
@@ -3793,7 +3810,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 for x in self.cudagraph_batch_sizes
                 if max_num_tokens >= x >= self.uniform_decode_query_len
             ]
-            compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
+            compilation_cases_decode = list(
+                product(reversed(decode_cudagraph_batch_sizes), lora_cases)
+            )
             self._capture_cudagraphs(
                 compilation_cases=compilation_cases_decode,
                 cudagraph_runtime_mode=CUDAGraphMode.FULL,
@@ -3823,7 +3842,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

     def _capture_cudagraphs(
         self,
-        compilation_cases: list[int],
+        compilation_cases: list[tuple[int, bool]],
         cudagraph_runtime_mode: CUDAGraphMode,
         uniform_decode: bool,
     ):
@@ -3844,7 +3863,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )

         # We skip EPLB here since we don't want to record dummy metrics
-        for num_tokens in compilation_cases:
+        for num_tokens, activate_lora in compilation_cases:
             # We currently only capture ubatched graphs when its a FULL
             # cudagraph, a uniform decode batch, and the number of tokens
             # is above the threshold. Otherwise we just capture a non-ubatched
@@ -3875,6 +3894,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     allow_microbatching=allow_microbatching,
                     skip_eplb=True,
                     remove_lora=False,
+                    activate_lora=activate_lora,
                 )
             self._dummy_run(
                 num_tokens,
@@ -3883,6 +3903,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 allow_microbatching=allow_microbatching,
                 skip_eplb=True,
                 remove_lora=False,
+                activate_lora=activate_lora,
             )
         self.maybe_remove_all_loras(self.lora_config)

@@ -120,7 +120,10 @@ class LoRAModelRunnerMixin:

     @contextmanager
     def maybe_select_dummy_loras(
-        self, lora_config: LoRAConfig | None, num_scheduled_tokens: np.ndarray
+        self,
+        lora_config: LoRAConfig | None,
+        num_scheduled_tokens: np.ndarray,
+        activate_lora: bool = True,
     ):
         if lora_config is None:
             yield
@@ -133,7 +136,12 @@ class LoRAModelRunnerMixin:

         # Make prompt lora mapping
         # Assign LoRA IDs cyclically to simulate a worst-case scenario.
-        prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
+        if activate_lora:
+            prompt_lora_mapping = (
+                np.arange(num_reqs, dtype=np.int32) % num_loras
+            ) + 1
+        else:
+            prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)

         # Make token lora mapping
         token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
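A tiny sketch of the mapping above (hypothetical request counts): with adapters active, request i is assigned LoRA ID (i % num_loras) + 1; with activate_lora=False, every request maps to ID 0, i.e. no adapter.

import numpy as np

num_reqs, num_loras = 5, 2
active = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
inactive = np.zeros(num_reqs, dtype=np.int32)
print(active.tolist())    # [1, 2, 1, 2, 1]
print(inactive.tolist())  # [0, 0, 0, 0, 0]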
@@ -159,11 +167,14 @@ class LoRAModelRunnerMixin:
         self,
         lora_config: LoRAConfig | None,
         num_scheduled_tokens: np.ndarray,
+        activate_lora: bool = True,
         remove_lora: bool = True,
     ):
         with (
             self.maybe_setup_dummy_loras(lora_config, remove_lora),
-            self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens),
+            self.maybe_select_dummy_loras(
+                lora_config, num_scheduled_tokens, activate_lora
+            ),
         ):
             yield