mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 08:09:37 +08:00
[Feature] Batch invariant: Lora (#30097)
Signed-off-by: quanliu <18646313696@163.com>
This commit is contained in:
parent
3e10262356
commit
a37328fc5c
@ -11,9 +11,11 @@ import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
is_batch_invariant = vllm_is_batch_invariant()
|
||||
|
||||
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
||||
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
||||
@ -150,7 +152,8 @@ def _get_lora_b_ptr(
|
||||
@functools.lru_cache
|
||||
def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
|
||||
user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
|
||||
if user_defined_config_folder is not None:
|
||||
# Avoid optimizing for the batch invariant case. Use default config
|
||||
if user_defined_config_folder is not None and not is_batch_invariant:
|
||||
gpu_name = torch.cuda.get_device_name()
|
||||
gpu_name = gpu_name.replace(" ", "_")
|
||||
gpu_name = gpu_name.replace("-", "_")
|
||||
@ -203,11 +206,14 @@ def get_lora_op_configs(
|
||||
# default config
|
||||
default = {}
|
||||
if op_type == "shrink":
|
||||
split_k = 64 if batch < 128 else 8
|
||||
if is_batch_invariant:
|
||||
split_k = 1
|
||||
default = {
|
||||
"block_m": 32,
|
||||
"block_n": 16,
|
||||
"block_k": 256 if batch < 128 else 32,
|
||||
"split_k": 64 if batch < 128 else 8,
|
||||
"split_k": split_k,
|
||||
"num_warps": 4,
|
||||
"num_ctas": 1,
|
||||
"group_size_m": 8,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user