diff --git a/vllm/lora/ops/triton_ops/README_TUNING.md b/vllm/lora/ops/triton_ops/README_TUNING.md
new file mode 100644
index 0000000000000..fda95ea71891f
--- /dev/null
+++ b/vllm/lora/ops/triton_ops/README_TUNING.md
@@ -0,0 +1,51 @@
+# Multi-LoRA Tuning
+
+**Note**: The LoRA configuration folder should be specified by exporting `VLLM_TUNED_CONFIG_FOLDER=/path/to/configs`. Without it, the shrink/expand kernels fall back to default configurations.
+
+## Tuning Process
+
+Multi-LoRA shrink/expand Triton kernel tuning follows a methodology similar to [Triton MoE tuning](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py).
+
+**Step 1**
+Define the search space. An example search space:
+
+```python
+block_m_range = [16, 32, 64, 128, 256]
+block_n_range = [32, 64, 128, 256]
+block_k_range = [32, 64, 128, 256]
+num_warps_range = [4, 8]
+num_stage_range = [2, 3, 4, 5]
+num_ctas_range = [1]
+split_k_range = [4, 8, 16, 32, 64]
+```
+
+**Step 2**
+Collect all hidden-state sizes and `num_slices` values that the target model uses for a specific TP size.
+
+For example, this information can be acquired by adding a few prints to [add_lora_linear](https://github.com/li2haipeng/vllm/blob/multi_lora_v01011/vllm/lora/punica_wrapper/punica_gpu.py#L192):
+
+```python
+print(f"x_shape: {x.view(-1, x.shape[-1]).shape}")
+print(f"num_slices: {len(output_slices)}")
+for i in range(len(output_slices)):
+    print(f"a{i} shape: {lora_a_stacked[i].shape}")
+    print(f"b{i} shape: {lora_b_stacked[i].shape}")
+print("y_shape", y.shape)
+```
+
+**Step 3**
+Benchmark the shrink/expand kernel runtime under each kernel configuration generated from the pre-defined search space, i.e. perform a grid search for the optimal kernel configuration. vLLM's [benchmark_lora.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_lora.py) can be used to search for configurations for different shapes.
+
+## Config Files
+
+### File Name
+
+For `shrink`, the config file is named `{gpu_name}_SHRINK.json`, e.g. `NVIDIA_H200_SHRINK.json`.
+
+For `expand`, the config file is named `{gpu_name}_EXPAND_{add_input}.json`, e.g. `NVIDIA_H200_EXPAND_TRUE.json`.
+
+The `gpu_name` can be detected automatically by calling `torch.cuda.get_device_name()`.
+
+### JSON Structure
+
+Optimal kernel configurations are saved as JSON files with the nested structure `config_data[max_loras][num_slices][m][k][n]`, where `m` is the number of input tokens and (`k`, `n`) are (`hidden_size`, `rank`) for `shrink` and (`rank`, `hidden_size`) for `expand`.
diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py
index a7a552b9903d5..c8330455985aa 100644
--- a/vllm/lora/ops/triton_ops/lora_expand_op.py
+++ b/vllm/lora/ops/triton_ops/lora_expand_op.py
@@ -10,7 +10,7 @@ https://arxiv.org/abs/2310.18547
 import torch
 
 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
-from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
+from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
 
@@ -201,12 +201,21 @@ def _lora_expand(
     NUM_SLICES = len(lora_b_weights)
 
     # Triton kernel configs.
-    BLOCK_M = 64
-    BLOCK_N = 128
-    BLOCK_K = 16
-    NUM_WARPS = 4
-    NUM_CTAS = 1
-    NUM_STAGES = 2
+    kernel_config = get_lora_op_configs(
+        op_type="expand",
+        max_loras=MAX_LORAS,
+        batch=M,
+        hidden_size=MAX_N,
+        rank=K,
+        num_slices=NUM_SLICES,
+        add_inputs=add_inputs,
+    )
+    BLOCK_M = kernel_config["block_m"]
+    BLOCK_N = kernel_config["block_n"]
+    BLOCK_K = kernel_config["block_k"]
+    NUM_WARPS = kernel_config["num_warps"]
+    NUM_CTAS = kernel_config["num_ctas"]
+    NUM_STAGES = kernel_config["num_stages"]
 
     EVEN_K = K % BLOCK_K == 0  # type: ignore
 
diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py
index 1e7e43e30de78..9cba8f4944486 100644
--- a/vllm/lora/ops/triton_ops/lora_shrink_op.py
+++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py
@@ -10,7 +10,7 @@ https://arxiv.org/abs/2310.18547
 import torch
 
 from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
-from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
+from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
 
@@ -177,14 +177,21 @@ def _lora_shrink(
     MAX_LORAS = lora_ids.size(0)
 
     # Triton kernel configs
-    BLOCK_M = 32
-    BLOCK_N = 16
-    BLOCK_K = 256 if M < 128 else 32
-    SPLIT_K = 64 if M < 128 else 8
-    NUM_WARPS = 4
-    NUM_CTAS = 1
-    NUM_STAGES = 2
-
+    kernel_config = get_lora_op_configs(
+        "shrink",
+        max_loras=MAX_LORAS,
+        batch=M,
+        hidden_size=K,
+        rank=N,
+        num_slices=NUM_SLICES,
+    )
+    BLOCK_M = kernel_config["block_m"]
+    BLOCK_N = kernel_config["block_n"]
+    BLOCK_K = kernel_config["block_k"]
+    SPLIT_K = kernel_config["split_k"]
+    NUM_WARPS = kernel_config["num_warps"]
+    NUM_STAGES = kernel_config["num_stages"]
+    NUM_CTAS = kernel_config["num_ctas"]
     EVEN_K = K % (BLOCK_K * SPLIT_K) == 0  # type: ignore
 
     # TODO (varun): This grid formulation maximizes parallelization at the
diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py
index 3a3e8fc8931e8..9ffb6dc3d85e5 100644
--- a/vllm/lora/ops/triton_ops/utils.py
+++ b/vllm/lora/ops/triton_ops/utils.py
@@ -1,8 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+import json
+from pathlib import Path
+from typing import Any
+
 import torch
 
+from vllm import envs
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
 _LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
 _LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
 
@@ -133,3 +143,108 @@ def _get_lora_b_ptr(
         MAX_N,
     )
     return _LORA_B_PTR_DICT.get(key)
+
+
+@functools.lru_cache
+def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
+    user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
+    if user_defined_config_folder is not None:
+        gpu_name = torch.cuda.get_device_name()
+        gpu_name = gpu_name.replace(" ", "_")
+        gpu_name = gpu_name.replace("-", "_")
+
+        config_fname = None
+        if op_type == "shrink":
+            config_fname = f"{gpu_name}_{op_type.upper()}.json"
+        else:
+            assert op_type == "expand"
+            config_fname = (
+                f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json"
+            )
+
+        config_path = Path(f"{user_defined_config_folder}/{config_fname}")
+        if not config_path.exists():
+            logger.warning_once(f"No LoRA kernel config found at {config_path}")
+            return None
+
+        # Load json
+        logger.info_once(f"Using tuned LoRA kernel configs from {config_path}.")
+        with open(str(config_path)) as f:
+            config_data = json.load(f)
+    else:
+        config_data = None
+
+    return config_data
+
+
+@functools.lru_cache
+def get_lora_op_configs(
+    op_type: str,
+    max_loras: int,
+    batch: int,
+    hidden_size: int,
+    rank: int,
+    num_slices: int,
+    add_inputs: bool | None = None,
+) -> dict[str, int | None]:
+    assert op_type in ["shrink", "expand"]
+
+    # default config
+    default = {}
+    if op_type == "shrink":
+        default = {
+            "block_m": 32,
+            "block_n": 16,
+            "block_k": 256 if batch < 128 else 32,
+            "split_k": 64 if batch < 128 else 8,
+            "num_warps": 4,
+            "num_ctas": 1,
+            "num_stages": 2,
+            "max_nreg": None,
+        }
+    else:
+        default = {
+            "block_m": 64,
+            "block_n": 128,
+            "block_k": 16,
+            "num_warps": 4,
+            "num_ctas": 1,
+            "num_stages": 2,
+            "max_nreg": None,
+        }
+    m = batch
+
+    k, n = (hidden_size, rank) if op_type == "shrink" else (rank, hidden_size)
+
+    config_data: Any
+    config_data = load_lora_op_config(op_type, add_inputs)
+    if not config_data:
+        logger.warning_once("Using default LoRA kernel configs")
+        return default
+
+    # config is structured as config_data[max_loras][num_slices][m][k][n] = {}
+    # slice by max_loras
+    config_data = (
+        config_data.get(str(max_loras))
+        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - max_loras))]
+    )
+    # slice by num_slices
+    config_data = config_data[str(num_slices)]
+    # slice by m
+    config_data = (
+        config_data.get(str(m))
+        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - m))]
+    )
+    # slice by k
+    config_data = (
+        config_data.get(str(k))
+        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - k))]
+    )
+    # slice by n
+    config_data = (
+        config_data.get(str(n))
+        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n))]
+    )
+
+    assert config_data is not None
+    return config_data
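
A note on Steps 1 and 3 of the README's tuning process: the grid search simply times every combination drawn from the ranges above and keeps the fastest. A minimal sketch follows, assuming a hypothetical `benchmark_fn` callback that times the shrink kernel for one shape under a given config; in practice `benchmarks/kernels/benchmark_lora.py` plays that role, so this is an illustration of the loop, not a vLLM API.

```python
import itertools

# Example search space from Step 1 of README_TUNING.md.
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [32, 64, 128, 256]
block_k_range = [32, 64, 128, 256]
num_warps_range = [4, 8]
num_stage_range = [2, 3, 4, 5]
num_ctas_range = [1]
split_k_range = [4, 8, 16, 32, 64]

def tune_shrink(benchmark_fn):
    """Return the config with the lowest measured latency.

    `benchmark_fn` is a hypothetical callback mapping a config dict to a
    runtime in ms; it stands in for benchmark_lora.py, not a real vLLM API.
    """
    best_cfg, best_ms = None, float("inf")
    for bm, bn, bk, warps, stages, ctas, sk in itertools.product(
        block_m_range, block_n_range, block_k_range,
        num_warps_range, num_stage_range, num_ctas_range, split_k_range,
    ):
        cfg = {
            "block_m": bm, "block_n": bn, "block_k": bk, "split_k": sk,
            "num_warps": warps, "num_stages": stages, "num_ctas": ctas,
            "max_nreg": None,
        }
        ms = benchmark_fn(cfg)  # time the shrink kernel under this config
        if ms < best_ms:
            best_cfg, best_ms = cfg, ms
    return best_cfg
```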
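
The `config_data[max_loras][num_slices][m][k][n]` layout from the README's JSON Structure section, shown as a hypothetical single-entry `NVIDIA_H200_SHRINK.json`. The nesting order and leaf keys mirror `get_lora_op_configs` above; the numeric values are placeholders, not tuned results:

```json
{
  "1": {
    "1": {
      "16": {
        "4096": {
          "16": {
            "block_m": 32,
            "block_n": 16,
            "block_k": 256,
            "split_k": 64,
            "num_warps": 4,
            "num_ctas": 1,
            "num_stages": 2,
            "max_nreg": null
          }
        }
      }
    }
  }
}
```

Reading outward-in: the first `"1"` is `max_loras`, the second `"1"` is `num_slices`, `"16"` is `m` (token count), `"4096"` is `k` (hidden size for shrink), and the innermost `"16"` is `n` (rank for shrink). Keys are strings because JSON objects only allow string keys; `get_lora_op_configs` therefore looks up `str(value)` and, at every level except `num_slices`, falls back to the numerically nearest key when there is no exact match.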
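
Finally, a minimal usage sketch of the new lookup path, assuming the folder pointed to by `VLLM_TUNED_CONFIG_FOLDER` contains the sample JSON above and a CUDA device is visible (the loader calls `torch.cuda.get_device_name()`). With no folder set, or no matching file, the call logs a warning and returns the hard-coded defaults:

```python
import os

# Must be set before the lookup runs; vllm.envs reads it at access time.
os.environ["VLLM_TUNED_CONFIG_FOLDER"] = "/path/to/configs"

from vllm.lora.ops.triton_ops.utils import get_lora_op_configs

cfg = get_lora_op_configs(
    op_type="shrink",
    max_loras=1,
    batch=16,          # m: number of input tokens
    hidden_size=4096,  # k for shrink
    rank=16,           # n for shrink
    num_slices=1,
)
print(cfg["block_m"], cfg["split_k"])  # 32 64 with the sample file above
```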