mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-16 12:29:07 +08:00
[Feature] Batch invariant: Lora (#30097)
Signed-off-by: quanliu <18646313696@163.com>
This commit is contained in:
parent
3e10262356
commit
a37328fc5c
@ -11,9 +11,11 @@ import torch
|
|||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
is_batch_invariant = vllm_is_batch_invariant()
|
||||||
|
|
||||||
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
||||||
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
||||||
@ -150,7 +152,8 @@ def _get_lora_b_ptr(
|
|||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
|
def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
|
||||||
user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
|
user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
|
||||||
if user_defined_config_folder is not None:
|
# Avoid optimizing for the batch invariant case. Use default config
|
||||||
|
if user_defined_config_folder is not None and not is_batch_invariant:
|
||||||
gpu_name = torch.cuda.get_device_name()
|
gpu_name = torch.cuda.get_device_name()
|
||||||
gpu_name = gpu_name.replace(" ", "_")
|
gpu_name = gpu_name.replace(" ", "_")
|
||||||
gpu_name = gpu_name.replace("-", "_")
|
gpu_name = gpu_name.replace("-", "_")
|
||||||
@ -203,11 +206,14 @@ def get_lora_op_configs(
|
|||||||
# default config
|
# default config
|
||||||
default = {}
|
default = {}
|
||||||
if op_type == "shrink":
|
if op_type == "shrink":
|
||||||
|
split_k = 64 if batch < 128 else 8
|
||||||
|
if is_batch_invariant:
|
||||||
|
split_k = 1
|
||||||
default = {
|
default = {
|
||||||
"block_m": 32,
|
"block_m": 32,
|
||||||
"block_n": 16,
|
"block_n": 16,
|
||||||
"block_k": 256 if batch < 128 else 32,
|
"block_k": 256 if batch < 128 else 32,
|
||||||
"split_k": 64 if batch < 128 else 8,
|
"split_k": split_k,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_ctas": 1,
|
"num_ctas": 1,
|
||||||
"group_size_m": 8,
|
"group_size_m": 8,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user