mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 03:55:42 +08:00
[Lora]Load tuned multi-lora kernel configs from json files (#26319)
Signed-off-by: li2haipeng <44383182+li2haipeng@users.noreply.github.com> Signed-off-by: Haipeng Li <li2haipeng@gmail.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
db1764e4e0
commit
d4d1a6024f
51
vllm/lora/ops/triton_ops/README_TUNING.md
Normal file
51
vllm/lora/ops/triton_ops/README_TUNING.md
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
# Multi-LoRA Tuning
|
||||||
|
|
||||||
|
**Note**: The LoRA configuration folder should be specified by exporting `VLLM_TUNED_CONFIG_FOLDER=/path/to/configs`. Without this, the shrink/expand kernels will use default configurations.
|
||||||
|
|
||||||
|
## Tuning Process
|
||||||
|
|
||||||
|
Multi-LoRA shrink/expand Triton kernel tuning follows a methodology similar to [Triton MoE tuning](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py).
|
||||||
|
|
||||||
|
**Step 1**
|
||||||
|
Define the search space. An example search space:
|
||||||
|
|
||||||
|
```python
|
||||||
|
block_m_range = [16, 32, 64, 128, 256]
|
||||||
|
block_n_range = [32, 64, 128, 256]
|
||||||
|
block_k_range = [32, 64, 128, 256]
|
||||||
|
num_warps_range = [4, 8]
|
||||||
|
num_stage_range = [2, 3, 4, 5]
|
||||||
|
num_ctas_range = [1]
|
||||||
|
split_k_range = [4, 8, 16, 32, 64]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2**
|
||||||
|
Get all hidden_state sizes and num_slices that the target model uses for a specific TP size.
|
||||||
|
|
||||||
|
For example, we can acquire this information by simply adding the following prints to [add_lora_linear](https://github.com/li2haipeng/vllm/blob/multi_lora_v01011/vllm/lora/punica_wrapper/punica_gpu.py#L192):
|
||||||
|
|
||||||
|
```python
|
||||||
|
print(f"x_shape: {x.view(-1, x.shape[-1]).shape}")
|
||||||
|
print(f"num_slices: {len(output_slices)}")
|
||||||
|
for i in range(len(output_slices)):
|
||||||
|
print(f"a{i} shape: {lora_a_stacked[i].shape}")
|
||||||
|
print(f"b{i} shape: {lora_b_stacked[i].shape}")
|
||||||
|
print("y_shape", y.shape)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3**
|
||||||
|
Benchmark the shrink/expand kernel runtime with different kernel configurations generated from the pre-defined search space by performing a grid search to find the optimal kernel configuration. vLLM's [benchmark_lora.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_lora.py) can be used to search for configurations for different shapes.
|
||||||
|
|
||||||
|
## Config Files
|
||||||
|
|
||||||
|
### File Name
|
||||||
|
|
||||||
|
For `shrink`, the config file is named as `{gpu_name}_SHRINK.json`, e.g. `NVIDIA_H200_SHRINK.json`.
|
||||||
|
|
||||||
|
For `expand`, the config file is named as `{gpu_name}_EXPAND_{add_input}.json`, e.g. `NVIDIA_H200_EXPAND_TRUE.json`.
|
||||||
|
|
||||||
|
The `gpu_name` can be automatically detected by calling `torch.cuda.get_device_name()`; spaces and hyphens in the returned name are replaced with underscores.
|
||||||
|
|
||||||
|
### Json Structure
|
||||||
|
|
||||||
|
Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n]`
|
||||||
@ -10,7 +10,7 @@ https://arxiv.org/abs/2310.18547
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
|
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
|
||||||
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
|
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
@ -201,12 +201,21 @@ def _lora_expand(
|
|||||||
NUM_SLICES = len(lora_b_weights)
|
NUM_SLICES = len(lora_b_weights)
|
||||||
|
|
||||||
# Triton kernel configs.
|
# Triton kernel configs.
|
||||||
BLOCK_M = 64
|
kernel_config = get_lora_op_configs(
|
||||||
BLOCK_N = 128
|
op_type="expand",
|
||||||
BLOCK_K = 16
|
max_loras=MAX_LORAS,
|
||||||
NUM_WARPS = 4
|
batch=M,
|
||||||
NUM_CTAS = 1
|
hidden_size=MAX_N,
|
||||||
NUM_STAGES = 2
|
rank=K,
|
||||||
|
num_slices=NUM_SLICES,
|
||||||
|
add_inputs=add_inputs,
|
||||||
|
)
|
||||||
|
BLOCK_M = kernel_config["block_m"]
|
||||||
|
BLOCK_N = kernel_config["block_n"]
|
||||||
|
BLOCK_K = kernel_config["block_k"]
|
||||||
|
NUM_WARPS = kernel_config["num_warps"]
|
||||||
|
NUM_CTAS = kernel_config["num_ctas"]
|
||||||
|
NUM_STAGES = kernel_config["num_stages"]
|
||||||
|
|
||||||
EVEN_K = K % BLOCK_K == 0 # type: ignore
|
EVEN_K = K % BLOCK_K == 0 # type: ignore
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,7 @@ https://arxiv.org/abs/2310.18547
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
|
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
|
||||||
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
|
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
@ -177,14 +177,21 @@ def _lora_shrink(
|
|||||||
MAX_LORAS = lora_ids.size(0)
|
MAX_LORAS = lora_ids.size(0)
|
||||||
|
|
||||||
# Triton kernel configs
|
# Triton kernel configs
|
||||||
BLOCK_M = 32
|
kernel_config = get_lora_op_configs(
|
||||||
BLOCK_N = 16
|
"shrink",
|
||||||
BLOCK_K = 256 if M < 128 else 32
|
max_loras=MAX_LORAS,
|
||||||
SPLIT_K = 64 if M < 128 else 8
|
batch=M,
|
||||||
NUM_WARPS = 4
|
hidden_size=K,
|
||||||
NUM_CTAS = 1
|
rank=N,
|
||||||
NUM_STAGES = 2
|
num_slices=NUM_SLICES,
|
||||||
|
)
|
||||||
|
BLOCK_M = kernel_config["block_m"]
|
||||||
|
BLOCK_N = kernel_config["block_n"]
|
||||||
|
BLOCK_K = kernel_config["block_k"]
|
||||||
|
SPLIT_K = kernel_config["split_k"]
|
||||||
|
NUM_WARPS = kernel_config["num_warps"]
|
||||||
|
NUM_STAGES = kernel_config["num_stages"]
|
||||||
|
NUM_CTAS = kernel_config["num_ctas"]
|
||||||
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore
|
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore
|
||||||
|
|
||||||
# TODO (varun): This grid formulation maximizes parallelization at the
|
# TODO (varun): This grid formulation maximizes parallelization at the
|
||||||
|
|||||||
@ -1,8 +1,18 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
||||||
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
||||||
|
|
||||||
@ -133,3 +143,108 @@ def _get_lora_b_ptr(
|
|||||||
MAX_N,
|
MAX_N,
|
||||||
)
|
)
|
||||||
return _LORA_B_PTR_DICT.get(key)
|
return _LORA_B_PTR_DICT.get(key)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache
def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
    """Load the tuned LoRA kernel config table for the current GPU, if present.

    The folder is taken from ``envs.VLLM_TUNED_CONFIG_FOLDER``. The file name
    is ``{gpu_name}_SHRINK.json`` for ``op_type == "shrink"`` and
    ``{gpu_name}_EXPAND_{ADD_INPUTS}.json`` for ``op_type == "expand"``, where
    ``gpu_name`` is ``torch.cuda.get_device_name()`` with spaces and hyphens
    replaced by underscores and ``ADD_INPUTS`` is ``str(add_inputs).upper()``.

    The result (including ``None``) is memoized per ``(op_type, add_inputs)``
    by ``functools.lru_cache``, so the file is read at most once per key.

    Args:
        op_type: Either ``"shrink"`` or ``"expand"``.
        add_inputs: Only used to build the file name for ``"expand"``.

    Returns:
        The parsed JSON content, or ``None`` when no config folder is
        configured or the expected file does not exist.
    """
    config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
    # Treat an empty value like an unset one; otherwise the lookup below would
    # silently resolve to "/<file>" at the filesystem root.
    if not config_folder:
        return None

    gpu_name = torch.cuda.get_device_name().replace(" ", "_").replace("-", "_")

    if op_type == "shrink":
        config_fname = f"{gpu_name}_{op_type.upper()}.json"
    else:
        assert op_type == "expand"
        config_fname = (
            f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json"
        )

    config_path = Path(config_folder) / config_fname
    if not config_path.exists():
        logger.warning_once(f"No LoRA kernel configs found in {config_path}")
        return None

    logger.info_once(f"Using tuned LoRA kernel configs from {config_path}.")
    # JSON is UTF-8 by specification; do not depend on the locale's encoding.
    with config_path.open(encoding="utf-8") as f:
        return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def _nearest_entry(table: dict[str, Any], target: int) -> Any:
    """Return ``table[str(target)]``, else the entry with the closest int key."""
    exact = table.get(str(target))
    if exact:
        return exact
    return table[min(table, key=lambda key: abs(int(key) - target))]


@functools.lru_cache
def get_lora_op_configs(
    op_type: str,
    max_loras: int,
    batch: int,
    hidden_size: int,
    rank: int,
    num_slices: int,
    add_inputs: bool | None = None,
) -> dict[str, int | None]:
    """Select the Triton launch config for a LoRA shrink/expand kernel call.

    When ``load_lora_op_config`` yields a tuned table, the table is walked as
    ``config_data[max_loras][num_slices][m][k][n]``. ``max_loras``, ``m``,
    ``k`` and ``n`` fall back to the numerically closest key when there is no
    exact match; ``num_slices`` must match exactly. If no tuned table is
    available, or it has no entry for ``num_slices``, the built-in default
    config for ``op_type`` is returned.

    The result is memoized by ``functools.lru_cache``; the returned dict is
    therefore shared between callers and must be treated as read-only.

    Args:
        op_type: ``"shrink"`` or ``"expand"``.
        max_loras: Used as the first-level lookup key.
        batch: Used as ``m``; also selects ``block_k``/``split_k`` in the
            default shrink config.
        hidden_size: ``k`` for shrink, ``n`` for expand.
        rank: ``n`` for shrink, ``k`` for expand.
        num_slices: Second-level lookup key (exact match only).
        add_inputs: Forwarded to ``load_lora_op_config``.

    Returns:
        A dict of kernel launch parameters. The defaults contain ``block_m``,
        ``block_n``, ``block_k``, ``num_warps``, ``num_ctas``, ``num_stages``
        and ``max_nreg`` (plus ``split_k`` for shrink).
    """
    assert op_type in ["shrink", "expand"]

    default: dict[str, int | None]
    if op_type == "shrink":
        default = {
            "block_m": 32,
            "block_n": 16,
            "block_k": 256 if batch < 128 else 32,
            "split_k": 64 if batch < 128 else 8,
            "num_warps": 4,
            "num_ctas": 1,
            "num_stages": 2,
            "max_nreg": None,
        }
    else:
        default = {
            "block_m": 64,
            "block_n": 128,
            "block_k": 16,
            "num_warps": 4,
            "num_ctas": 1,
            "num_stages": 2,
            "max_nreg": None,
        }

    config_data: Any = load_lora_op_config(op_type, add_inputs)
    if not config_data:
        logger.warning_once("Using default LoRA kernel configs")
        return default

    m = batch
    k, n = (hidden_size, rank) if op_type == "shrink" else (rank, hidden_size)

    # config is structured as config_data[max_loras][num_slices][m][k][n] = {}
    config_data = _nearest_entry(config_data, max_loras)

    # num_slices is matched exactly. A tuned file that was not generated for
    # this slice count should not abort the kernel launch with a bare
    # KeyError, so fall back to the defaults instead.
    config_data = config_data.get(str(num_slices))
    if config_data is None:
        logger.warning_once(
            f"No tuned LoRA kernel configs for num_slices={num_slices}; "
            "using default LoRA kernel configs"
        )
        return default

    config_data = _nearest_entry(config_data, m)
    config_data = _nearest_entry(config_data, k)
    config_data = _nearest_entry(config_data, n)

    assert config_data is not None
    return config_data
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user