[Feature] Batch invariant: LoRA (#30097)

Signed-off-by: quanliu <18646313696@163.com>
Authored by quanliu on 2025-12-23 10:32:47 +08:00; committed by GitHub.
commit a37328fc5c (parent 3e10262356)


@@ -11,9 +11,11 @@ import torch
 from vllm import envs
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.platforms import current_platform

 logger = init_logger(__name__)
+is_batch_invariant = vllm_is_batch_invariant()

 _LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.Tensor, ...]] = {}
 _LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.Tensor, ...]] = {}
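
For context: vLLM's batch-invariant mode aims to make a request's output bitwise-independent of which other requests it happens to be batched with, and this patch extends that guarantee to the Triton LoRA kernels. The hunk above imports the helper and caches its result in a module-level boolean so the config code below can branch on it cheaply. A minimal sketch of that pattern, assuming the helper is driven by an environment flag (the VLLM_BATCH_INVARIANT name is an assumption here, not taken from this diff):

import os

def vllm_is_batch_invariant() -> bool:
    # Hypothetical stand-in for the real helper in
    # vllm.model_executor.layers.batch_invariant; the env-var name is
    # assumed for illustration only.
    return os.environ.get("VLLM_BATCH_INVARIANT", "0") == "1"

# Evaluated once at import time; every kernel-config decision reads this
# boolean instead of re-querying the environment per call.
is_batch_invariant = vllm_is_batch_invariant()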
@@ -150,7 +152,8 @@ def _get_lora_b_ptr(
 @functools.lru_cache
 def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
     user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
-    if user_defined_config_folder is not None:
+    # Avoid optimizing for the batch-invariant case; use the default config.
+    if user_defined_config_folder is not None and not is_batch_invariant:
         gpu_name = torch.cuda.get_device_name()
         gpu_name = gpu_name.replace(" ", "_")
         gpu_name = gpu_name.replace("-", "_")
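
Why skip tuned configs altogether? Configs under VLLM_TUNED_CONFIG_FOLDER are keyed by GPU and shape, so the selected kernel parameters could change with batch size, which is exactly what batch invariance forbids. With the new guard the loader returns None and every batch falls through to the single default config, which functools.lru_cache then memoizes per (op_type, add_inputs). A minimal sketch of the control flow, with a hypothetical dict standing in for the real JSON lookup:

import functools

is_batch_invariant = True        # module-level flag, as in the patch
tuned_config_folder = "/tuned"   # stands in for envs.VLLM_TUNED_CONFIG_FOLDER

@functools.lru_cache
def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
    # Tuned, shape-dependent configs are skipped in batch-invariant mode
    # so every batch size runs the same default kernel config.
    if tuned_config_folder is not None and not is_batch_invariant:
        # Hypothetical lookup; the real code derives a per-GPU file name,
        # e.g. "NVIDIA A100-SXM4-80GB" -> "NVIDIA_A100_SXM4_80GB".
        return {"block_m": 32, "block_n": 16}
    return None  # caller falls back to the default config

print(load_lora_op_config("shrink", None))  # None: tuned configs disabled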
@@ -203,11 +206,14 @@ def get_lora_op_configs(
     # default config
     default = {}
     if op_type == "shrink":
+        split_k = 64 if batch < 128 else 8
+        if is_batch_invariant:
+            split_k = 1
         default = {
             "block_m": 32,
             "block_n": 16,
             "block_k": 256 if batch < 128 else 32,
-            "split_k": 64 if batch < 128 else 8,
+            "split_k": split_k,
             "num_warps": 4,
             "num_ctas": 1,
             "group_size_m": 8,