[LoRA] LoRA cuda graph specialization (#25914)

Signed-off-by: Andy Lo <andy@mistral.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>

parent f32bf7582e
commit b63f2143f8
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import vllm
+import vllm.config
 from vllm.lora.request import LoRARequest

 from ..utils import create_new_process_for_each_test, multi_gpu_test

@@ -53,9 +54,10 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]:
 def test_chatglm3_lora(chatglm3_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
-        max_model_len=1024,
+        max_model_len=512,
         enable_lora=True,
-        max_loras=4,
+        max_loras=2,
+        max_num_seqs=16,
         max_lora_rank=64,
         trust_remote_code=True,
     )

@@ -72,13 +74,17 @@ def test_chatglm3_lora(chatglm3_lora_files):
 def test_chatglm3_lora_tp4(chatglm3_lora_files):
     llm = vllm.LLM(
         MODEL_PATH,
-        max_model_len=1024,
+        max_model_len=512,
         enable_lora=True,
-        max_loras=4,
+        max_loras=2,
         max_lora_rank=64,
+        max_num_seqs=16,
         tensor_parallel_size=4,
         trust_remote_code=True,
         fully_sharded_loras=False,
+        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
+            cudagraph_specialize_lora=False,
+        ),
     )

     output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)

@@ -96,14 +102,17 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
     # more GPU memory causing vLLM to OOM
     llm = vllm.LLM(
         MODEL_PATH,
-        max_model_len=1024,
+        max_model_len=512,
         enable_lora=True,
-        max_loras=4,
+        max_loras=2,
         max_lora_rank=64,
         tensor_parallel_size=4,
         trust_remote_code=True,
         fully_sharded_loras=True,
         gpu_memory_utilization=0.85,
+        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
+            cudagraph_specialize_lora=False,
+        ),
     )
     output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
     for i in range(len(EXPECTED_LORA_OUTPUT)):

@@ -3,7 +3,10 @@
 import subprocess
 import sys

+import pytest
+
 import vllm
+import vllm.config
 from vllm import LLM
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

@@ -100,7 +103,8 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None =


 @create_new_process_for_each_test()
-def test_llama_lora(sql_lora_files):
+@pytest.mark.parametrize("cudagraph_specialize_lora", [True, False])
+def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool):
     llm = vllm.LLM(
         MODEL_PATH,
         tokenizer=sql_lora_files,

@@ -108,6 +112,9 @@ def test_llama_lora(sql_lora_files):
         # also test odd max_num_seqs
         max_num_seqs=13,
         max_loras=4,
+        compilation_config=vllm.config.CompilationConfig(
+            cudagraph_specialize_lora=cudagraph_specialize_lora,
+        ),
     )
     generate_and_test(llm, sql_lora_files)


@@ -366,6 +366,14 @@ class CompilationConfig:
     minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
     FULL_AND_PIECEWISE instead.
     """
+    cudagraph_specialize_lora: bool = True
+    """Whether to create separate cuda graphs for cases with and without active
+    LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used
+    for all cases, incurring the overhead of running LoRA ops even when no
+    adapters are active. Setting this to True will remove this overhead at the
+    cost of increased startup time and slightly higher memory usage.
+    When `enable_lora` is False, this option has no effect.
+    """

     use_inductor_graph_partition: bool = False
     """Use inductor graph partition to split the graph at cudagraph_unsafe ops.
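For reference, the new flag lives on CompilationConfig and only matters when LoRA is enabled; turning it off trades a small per-step LoRA overhead for fewer captured graphs. A minimal usage sketch mirroring the test changes above (not part of this commit; the model name is a placeholder):

    import vllm
    import vllm.config

    # Capture a single set of LoRA-enabled graphs instead of specializing
    # on LoRA-active vs. LoRA-inactive batches (saves capture time and memory).
    llm = vllm.LLM(
        "meta-llama/Llama-2-7b-hf",  # placeholder model
        enable_lora=True,
        max_lora_rank=64,
        compilation_config=vllm.config.CompilationConfig(
            cudagraph_specialize_lora=False,
        ),
    )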
@@ -40,13 +40,19 @@ class BatchDescriptor(NamedTuple):
     False can also be used for an uniform decode batch to dispatch to the
     cudagraph supporting non-uniform batches.
     """
+    has_lora: bool = False
+    """
+    Whether this batch has active LoRA adapters.
+    """

     @property
     def non_uniform(self) -> "BatchDescriptor":
         """
         Return a non-uniform version of current batch descriptor.
         """
-        return BatchDescriptor(self.num_tokens, uniform_decode=False)
+        return BatchDescriptor(
+            self.num_tokens, uniform_decode=False, has_lora=self.has_lora
+        )


 def _compute_sp_num_tokens(
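Because BatchDescriptor is a NamedTuple, the new has_lora field becomes part of the cudagraph dispatch key: two otherwise identical batches map to different graphs when one has active adapters. A small illustration using only the fields shown above (not part of this commit):

    from vllm.forward_context import BatchDescriptor

    with_lora = BatchDescriptor(num_tokens=32, uniform_decode=True, has_lora=True)
    no_lora = BatchDescriptor(num_tokens=32, uniform_decode=True, has_lora=False)
    assert with_lora != no_lora                     # distinct cudagraph keys
    assert with_lora.non_uniform.has_lora is True   # non_uniform preserves has_lora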
@@ -169,6 +169,8 @@ def _lora_shrink(
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
     assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1

+    output_tensor.zero_()
+
     (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = (
         _get_lora_a_ptr(lora_a_weights, inputs.device)
     )

@@ -205,15 +205,18 @@ class PunicaWrapperGPU(PunicaWrapperBase):

         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)

-        if buffer is None:
-            r = lora_b_stacked[0].size(-1)
-            # We set the buffer to be float32 by default, refer to:
-            # https://github.com/triton-lang/triton/issues/1387
-            buffer = torch.zeros(  # type: ignore
-                (len(output_slices), x.size(0), r),
-                dtype=torch.float32,
-                device=x.device,
-            )
+        assert buffer is None, (
+            "To minimize overhead, the buffer should be created by "
+            ".add_lora_linear() instead of being passed in."
+        )
+        r = lora_b_stacked[0].size(-1)
+        # We set the buffer to be float32 by default, refer to:
+        # https://github.com/triton-lang/triton/issues/1387
+        # Note: buffer is zeroed inside the shrink op
+        buffer = torch.empty(
+            (len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device
+        )

         self.add_shrink(
             buffer,  # type: ignore
             x,

@@ -260,10 +263,15 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         y = y.view(-1, y.shape[-1])
         x = x.view(-1, x.shape[-1])
         r = lora_b_stacked.size(-1)
-        if buffer is None:
-            # We set the buffer to be float32 by default, refer to:
-            # https://github.com/triton-lang/triton/issues/1387
-            buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
+
+        assert buffer is None, (
+            "To minimize overhead, the buffer should be created by "
+            ".add_lora_linear() instead of being passed in."
+        )
+        # We set the buffer to be float32 by default, refer to:
+        # https://github.com/triton-lang/triton/issues/1387
+        # Note: buffer is zeroed inside the shrink op
+        buffer = torch.empty((x.size(0), r), dtype=torch.float32, device=x.device)

         lora_shrink(
             x,

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from itertools import product

 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor

@@ -67,14 +68,27 @@ class CudagraphDispatcher:
     ):
         # This should be called only after attention backend is initialized.

+        # LoRA activation cases to specialize the cuda graphs on
+        if self.vllm_config.lora_config:
+            if self.compilation_config.cudagraph_specialize_lora:
+                lora_cases = [True, False]
+            else:
+                lora_cases = [True]
+        else:
+            lora_cases = [False]
+
         # Note: we create all valid keys for cudagraph here but do not
         # guarantee all keys would be used. For example, if we allow lazy
         # capturing in future PR, some keys may never be triggered.
         if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
-            for bs in self.compilation_config.cudagraph_capture_sizes:
+            for bs, has_lora in product(
+                self.compilation_config.cudagraph_capture_sizes, lora_cases
+            ):
                 self.add_cudagraph_key(
                     cudagraph_mode.mixed_mode(),
-                    BatchDescriptor(num_tokens=bs, uniform_decode=False),
+                    BatchDescriptor(
+                        num_tokens=bs, uniform_decode=False, has_lora=has_lora
+                    ),
                 )

         # if decode cudagraph mode is FULL, and we don't already have mixed

@@ -92,10 +106,12 @@ class CudagraphDispatcher:
                 for x in self.compilation_config.cudagraph_capture_sizes
                 if x <= max_num_tokens and x >= uniform_decode_query_len
             ]
-            for bs in cudagraph_capture_sizes_for_decode:
+            for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
                 self.add_cudagraph_key(
                     CUDAGraphMode.FULL,
-                    BatchDescriptor(num_tokens=bs, uniform_decode=True),
+                    BatchDescriptor(
+                        num_tokens=bs, uniform_decode=True, has_lora=has_lora
+                    ),
                 )
         self.keys_initialized = True
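The dispatcher changes above enumerate the cross product of capture sizes and LoRA-activation cases. A standalone sketch of that enumeration with hypothetical capture sizes (not part of this commit):

    from itertools import product

    cudagraph_capture_sizes = [1, 2, 4, 8]  # hypothetical values
    lora_enabled = True
    cudagraph_specialize_lora = True

    if lora_enabled:
        lora_cases = [True, False] if cudagraph_specialize_lora else [True]
    else:
        lora_cases = [False]

    keys = list(product(cudagraph_capture_sizes, lora_cases))
    # 8 (size, has_lora) keys with specialization on, 4 with it off;
    # capture time and graph memory scale with the number of keys.
    print(len(keys), keys)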
@@ -8,6 +8,7 @@ from collections import defaultdict
 from collections.abc import Iterator
 from contextlib import contextmanager
 from copy import deepcopy
+from itertools import product
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast

 import numpy as np

@@ -2469,7 +2470,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_scheduled_tokens == self.input_batch.num_reqs * max_query_len
             )
             batch_descriptor = BatchDescriptor(
-                num_tokens=num_input_tokens, uniform_decode=uniform_decode
+                num_tokens=num_input_tokens,
+                uniform_decode=uniform_decode,
+                has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
             )
             cudagraph_runtime_mode, batch_descriptor = (
                 self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn)

@@ -3193,6 +3196,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         is_profile: bool = False,
         create_mixed_batch: bool = False,
         remove_lora: bool = True,
+        activate_lora: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Run a dummy forward pass to warm up/profile run or capture the

@@ -3215,6 +3219,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             create_mixed_batch: If True, create a mixed batch with both decode
                 (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
+            activate_lora: If False, dummy_run is performed without LoRAs.
         """
         assert (
             cudagraph_runtime_mode is None

@@ -3364,7 +3369,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 attn_metadata[layer_name] = attn_metadata_i

         with self.maybe_dummy_run_with_lora(
-            self.lora_config, num_scheduled_tokens, remove_lora
+            self.lora_config, num_scheduled_tokens, activate_lora, remove_lora
         ):
             # Make sure padding doesn't exceed max_num_tokens
             assert num_tokens_after_padding <= self.max_num_tokens

@@ -3411,6 +3416,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 BatchDescriptor(
                     num_tokens=num_tokens_after_padding,
                     uniform_decode=uniform_decode,
+                    has_lora=activate_lora and self.lora_config is not None,
                 )
             )
             if not is_profile

@@ -3769,10 +3775,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
         cudagraph_mode = self.compilation_config.cudagraph_mode
         assert cudagraph_mode is not None
+
+        if self.lora_config:
+            if self.compilation_config.cudagraph_specialize_lora:
+                lora_cases = [True, False]
+            else:
+                lora_cases = [True]
+        else:
+            lora_cases = [False]
+
         if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
             cudagraph_runtime_mode = cudagraph_mode.mixed_mode()

-            compilation_cases = list(reversed(self.cudagraph_batch_sizes))
+            compilation_cases = list(
+                product(reversed(self.cudagraph_batch_sizes), lora_cases)
+            )
             self._capture_cudagraphs(
                 compilation_cases,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,

@@ -3793,7 +3810,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 for x in self.cudagraph_batch_sizes
                 if max_num_tokens >= x >= self.uniform_decode_query_len
             ]
-            compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
+            compilation_cases_decode = list(
+                product(reversed(decode_cudagraph_batch_sizes), lora_cases)
+            )
             self._capture_cudagraphs(
                 compilation_cases=compilation_cases_decode,
                 cudagraph_runtime_mode=CUDAGraphMode.FULL,

@@ -3823,7 +3842,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

     def _capture_cudagraphs(
         self,
-        compilation_cases: list[int],
+        compilation_cases: list[tuple[int, bool]],
         cudagraph_runtime_mode: CUDAGraphMode,
         uniform_decode: bool,
     ):

@@ -3844,7 +3863,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )

         # We skip EPLB here since we don't want to record dummy metrics
-        for num_tokens in compilation_cases:
+        for num_tokens, activate_lora in compilation_cases:
             # We currently only capture ubatched graphs when its a FULL
             # cudagraph, a uniform decode batch, and the number of tokens
             # is above the threshold. Otherwise we just capture a non-ubatched

@@ -3875,6 +3894,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     allow_microbatching=allow_microbatching,
                     skip_eplb=True,
                     remove_lora=False,
+                    activate_lora=activate_lora,
                 )
             self._dummy_run(
                 num_tokens,

@@ -3883,6 +3903,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 allow_microbatching=allow_microbatching,
                 skip_eplb=True,
                 remove_lora=False,
+                activate_lora=activate_lora,
             )
         self.maybe_remove_all_loras(self.lora_config)


@@ -120,7 +120,10 @@ class LoRAModelRunnerMixin:

     @contextmanager
     def maybe_select_dummy_loras(
-        self, lora_config: LoRAConfig | None, num_scheduled_tokens: np.ndarray
+        self,
+        lora_config: LoRAConfig | None,
+        num_scheduled_tokens: np.ndarray,
+        activate_lora: bool = True,
     ):
         if lora_config is None:
             yield

@@ -133,7 +136,12 @@ class LoRAModelRunnerMixin:

         # Make prompt lora mapping
         # Assign LoRA IDs cyclically to simulate a worst-case scenario.
-        prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
+        if activate_lora:
+            prompt_lora_mapping = (
+                np.arange(num_reqs, dtype=np.int32) % num_loras
+            ) + 1
+        else:
+            prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)

         # Make token lora mapping
         token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
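The activate_lora branch above decides whether dummy runs exercise the LoRA path during capture: the cyclic mapping gives every request a non-zero LoRA ID (worst case), while zeros correspond to running without adapters. A small illustration with hypothetical sizes (not part of this commit):

    import numpy as np

    num_reqs, num_loras = 5, 2
    active = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1  # [1, 2, 1, 2, 1]
    inactive = np.zeros(num_reqs, dtype=np.int32)                   # [0, 0, 0, 0, 0]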
@@ -159,11 +167,14 @@ class LoRAModelRunnerMixin:
         self,
         lora_config: LoRAConfig | None,
         num_scheduled_tokens: np.ndarray,
+        activate_lora: bool = True,
         remove_lora: bool = True,
     ):
         with (
             self.maybe_setup_dummy_loras(lora_config, remove_lora),
-            self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens),
+            self.maybe_select_dummy_loras(
+                lora_config, num_scheduled_tokens, activate_lora
+            ),
         ):
             yield