From e9da5a40c63ce7f8a85438d3c7d919b46e7939f5 Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Thu, 11 Apr 2024 03:26:07 +0000
Subject: [PATCH] [Misc] Add indirection layer for custom ops (#3913)

---
 .../kernels/benchmark_paged_attention.py      |   2 +-
 tests/kernels/test_attention.py               |   6 +-
 tests/kernels/test_cache.py                   |  25 ++-
 vllm/_custom_ops.py                           | 193 ++++++++++++++++++
 vllm/attention/ops/paged_attn.py              |  10 +-
 vllm/model_executor/layers/activation.py      |   2 +-
 .../layers/fused_moe/fused_moe.py             |   2 +-
 vllm/model_executor/layers/layernorm.py       |   2 +-
 .../model_executor/layers/quantization/awq.py |   2 +-
 .../layers/quantization/gptq.py               |   2 +-
 .../layers/quantization/marlin.py             |   2 +-
 .../layers/quantization/squeezellm.py         |   2 +-
 .../model_executor/layers/rotary_embedding.py |   2 +-
 vllm/utils.py                                 |   4 +-
 14 files changed, 224 insertions(+), 32 deletions(-)
 create mode 100644 vllm/_custom_ops.py

diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
index f71d1fcaaef5..5c3650fa72d1 100644
--- a/benchmarks/kernels/benchmark_paged_attention.py
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -5,7 +5,7 @@ from typing import Optional
 
 import torch
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
 
 NUM_BLOCKS = 1024
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 03ea72924921..9b1f3e30b6dc 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -7,7 +7,7 @@ from allclose_default import get_default_atol, get_default_rtol
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
-from vllm._C import cache_ops, ops
+from vllm import _custom_ops as ops
 from vllm.utils import get_max_shared_memory_bytes, is_hip
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
@@ -237,14 +237,14 @@ def test_paged_attention(
         dequantized_key_cache = torch.empty(size=key_cache_shape,
                                             dtype=dtype,
                                             device=device)
-        cache_ops.convert_fp8(key_cache, dequantized_key_cache)
+        ops.convert_fp8(key_cache, dequantized_key_cache)
         key_cache = dequantized_key_cache
 
         value_cache_shape = value_cache.shape
         dequantized_value_cache = torch.empty(size=value_cache_shape,
                                               dtype=dtype,
                                               device=device)
-        cache_ops.convert_fp8(value_cache, dequantized_value_cache)
+        ops.convert_fp8(value_cache, dequantized_value_cache)
         value_cache = dequantized_value_cache
 
     ref_output = torch.empty_like(query)
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 4141aacafd0b..d1051fd7e2f4 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -4,7 +4,7 @@ from typing import Tuple
 import pytest
 import torch
 
-from vllm._C import cache_ops
+from vllm import _custom_ops as ops
 from vllm.utils import is_hip
 
 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
@@ -80,7 +80,7 @@ def test_copy_blocks(
     cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
 
     # Call the copy blocks kernel.
-    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+    ops.copy_blocks(key_caches, value_caches, block_mapping)
 
     # Run the reference implementation.
     for src, dsts in block_mapping.items():
@@ -145,9 +145,9 @@ def test_reshape_and_cache(
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
         cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        cache_ops.convert_fp8(key_cache, cloned_key_cache)
+        ops.convert_fp8(key_cache, cloned_key_cache)
         cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        cache_ops.convert_fp8(value_cache, cloned_value_cache)
+        ops.convert_fp8(value_cache, cloned_value_cache)
     else:
         cloned_key_cache = key_cache.clone()
         cloned_value_cache = value_cache.clone()
@@ -156,14 +156,14 @@ def test_reshape_and_cache(
     kv_scale = 1.0
 
     # Call the reshape_and_cache kernel.
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
-                                slot_mapping, kv_cache_dtype, kv_scale)
+    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
+                          kv_cache_dtype, kv_scale)
 
     if kv_cache_dtype == "fp8":
         result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        cache_ops.convert_fp8(key_cache, result_key_cache)
+        ops.convert_fp8(key_cache, result_key_cache)
         result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        cache_ops.convert_fp8(value_cache, result_value_cache)
+        ops.convert_fp8(value_cache, result_value_cache)
 
     # Run the reference implementation.
     reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
@@ -251,9 +251,8 @@ def test_swap_blocks(
     src_value_caches_clone = src_value_caches[0].clone()
 
     # Call the swap_blocks kernel.
-    cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
-    cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
-                          block_mapping)
+    ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
+    ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
 
     for src, dst in block_mapping.items():
         assert torch.allclose(src_key_caches_clone[src].cpu(),
@@ -291,9 +290,9 @@ def test_fp8_conversion(
     cache.uniform_(low, high)
 
     cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
-    cache_ops.convert_fp8(cache, cache_fp8)
+    ops.convert_fp8(cache, cache_fp8)
 
     converted_cache = torch.empty_like(cache)
-    cache_ops.convert_fp8(cache_fp8, converted_cache)
+    ops.convert_fp8(cache_fp8, converted_cache)
 
     assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
new file mode 100644
index 000000000000..a0837a20875f
--- /dev/null
+++ b/vllm/_custom_ops.py
@@ -0,0 +1,193 @@
+from typing import Dict, Optional
+
+import torch
+
+try:
+    from vllm._C import cache_ops as vllm_cache_ops
+    from vllm._C import ops as vllm_ops
+except ImportError:
+    pass
+
+
+# activation ops
+def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    vllm_ops.silu_and_mul(out, x)
+
+
+def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    vllm_ops.gelu_and_mul(out, x)
+
+
+def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    vllm_ops.gelu_tanh_and_mul(out, x)
+
+
+def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
+    vllm_ops.gelu_fast(out, x)
+
+
+def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
+    vllm_ops.gelu_new(out, x)
+
+
+# page attention ops
+def paged_attention_v1(
+    out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    num_kv_heads: int,
+    scale: float,
+    block_tables: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_size: int,
+    max_context_len: int,
+    alibi_slopes: Optional[torch.Tensor],
+    kv_cache_dtype: str,
+    kv_scale: float,
+) -> None:
+    vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
+                                num_kv_heads, scale, block_tables,
+                                context_lens, block_size, max_context_len,
+                                alibi_slopes, kv_cache_dtype, kv_scale)
+
+
+def paged_attention_v2(
+    out: torch.Tensor,
+    exp_sum: torch.Tensor,
+    max_logits: torch.Tensor,
+    tmp_out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    num_kv_heads: int,
+    scale: float,
+    block_tables: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_size: int,
+    max_context_len: int,
+    alibi_slopes: Optional[torch.Tensor],
+    kv_cache_dtype: str,
+    kv_scale: float,
+) -> None:
+    vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
+                                key_cache, value_cache, num_kv_heads, scale,
+                                block_tables, context_lens, block_size,
+                                max_context_len, alibi_slopes, kv_cache_dtype,
+                                kv_scale)
+
+
+# pos encoding ops
+def rotary_embedding(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    head_size: int,
+    cos_sin_cache: torch.Tensor,
+    is_neox: bool,
+) -> None:
+    vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
+                              is_neox)
+
+
+def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
+                             key: torch.Tensor, head_size: int,
+                             cos_sin_cache: torch.Tensor, is_neox: bool,
+                             rot_dim: int,
+                             cos_sin_cache_offsets: torch.Tensor) -> None:
+    vllm_ops.batched_rotary_embedding(positions, query, key, head_size,
+                                      cos_sin_cache, is_neox, rot_dim,
+                                      cos_sin_cache_offsets)
+
+
+# layer norm ops
+def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
+             epsilon: float) -> None:
+    vllm_ops.rms_norm(out, input, weight, epsilon)
+
+
+def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
+                       weight: torch.Tensor, epsilon: float) -> None:
+    vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon)
+
+
+# quantization ops
+# awq
+def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
+                   zeros: torch.Tensor, split_k_iters: int, thx: int,
+                   thy: int) -> torch.Tensor:
+    return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx,
+                                   thy)
+
+
+def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
+             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
+    return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
+
+
+# gptq
+def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
+              b_g_idx: torch.Tensor, use_exllama: bool,
+              bit: int) -> torch.Tensor:
+    return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
+                              b_g_idx, use_exllama, bit)
+
+
+def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
+                 bit: int) -> None:
+    vllm_ops.gptq_shuffle(q_weight, q_perm, bit)
+
+
+# squeezellm
+def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
+                    lookup_table: torch.Tensor) -> None:
+    vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table)
+
+
+# marlin
+def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
+                size_n: int, size_k: int) -> torch.Tensor:
+    return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
+                                size_n, size_k)
+
+
+# moe
+def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
+                         block_size: int, sorted_token_ids: torch.Tensor,
+                         experts_ids: torch.Tensor,
+                         num_tokens_post_pad: torch.Tensor) -> None:
+    vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size,
+                                  sorted_token_ids, experts_ids,
+                                  num_tokens_post_pad)
+
+
+def reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    kv_scale: float,
+) -> None:
+    vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
+                                     slot_mapping, kv_cache_dtype, kv_scale)
+
+
+def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
+                block_mapping: torch.Tensor) -> None:
+    vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
+                block_mapping: Dict[int, int]) -> None:
+    vllm_cache_ops.swap_blocks(src, dst, block_mapping)
+
+
+def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None:
+    vllm_cache_ops.convert_fp8(output, input)
+
+
+#TODO: cuda_utils, custom_ar
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index 2d918491d657..cd0690a4ba95 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
 
 import torch
 
-from vllm._C import cache_ops, ops
+from vllm import _custom_ops as ops
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -69,7 +69,7 @@ class PagedAttention:
         kv_cache_dtype: str,
         kv_scale: float,
     ) -> None:
-        cache_ops.reshape_and_cache(
+        ops.reshape_and_cache(
             key,
             value,
             key_cache,
@@ -199,11 +199,11 @@ class PagedAttention:
     ) -> None:
         src_key_cache = src_kv_cache[0]
         dst_key_cache = dst_kv_cache[0]
-        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
 
         src_value_cache = src_kv_cache[1]
         dst_value_cache = dst_kv_cache[1]
-        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
@@ -212,4 +212,4 @@ class PagedAttention:
     ) -> None:
         key_caches = [kv_cache[0] for kv_cache in kv_caches]
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
-        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 6786c48e0cab..baf1d4f26618 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.quantization import QuantizationConfig
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 1ec09f0cd4c2..377b6588dbf4 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -8,7 +8,7 @@ import torch
 import triton
 import triton.language as tl
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.utils import is_hip
 
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index cb3cee2bad5a..a6619714b8aa 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 
 
 class RMSNorm(nn.Module):
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index 2caef5f1ebf5..daea5ac73e42 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index 53baf710ed81..757ab1af8392 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py
index 784229878edf..a6482c059cc4 100644
--- a/vllm/model_executor/layers/quantization/marlin.py
+++ b/vllm/model_executor/layers/quantization/marlin.py
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py
index ed25455e6ec1..bb295df2acc3 100644
--- a/vllm/model_executor/layers/quantization/squeezellm.py
+++ b/vllm/model_executor/layers/quantization/squeezellm.py
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index d80e73bbe39e..eb8d5f6dfb2a 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/utils.py b/vllm/utils.py
index 8ba03333d3b6..8ab8927512cc 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -279,10 +279,10 @@ def _generate_random_fp8(
     #-----|-------------|-------------------
     # Inf | N/A         | s.11111.00
     # NaN | s.1111.111  | s.11111.{01,10,11}
-    from vllm._C import cache_ops
+    from vllm import _custom_ops as ops
    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
     tensor_tmp.uniform_(low, high)
-    cache_ops.convert_fp8(tensor_tmp, tensor)
+    ops.convert_fp8(tensor_tmp, tensor)
     del tensor_tmp
 
 
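Note (not part of the patch): a minimal usage sketch of the new indirection layer, showing that callers now import vllm._custom_ops rather than the compiled vllm._C extension directly. The tensor shapes, dtype, and CUDA device below are illustrative assumptions, not something the patch prescribes, and running it requires a build where the vllm._C extension is available; otherwise the `except ImportError: pass` in vllm/_custom_ops.py leaves vllm_ops undefined and the call below raises NameError.

    import torch

    from vllm import _custom_ops as ops  # thin Python wrappers over vllm._C

    # Illustrative shapes (assumed, not specified by this patch): silu_and_mul
    # keeps the out-parameter convention of the underlying kernel; the wrapper
    # simply forwards (out, x) to vllm._C.ops.silu_and_mul.
    num_tokens, d = 16, 1024
    x = torch.randn(num_tokens, 2 * d, dtype=torch.float16, device="cuda")
    out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
    ops.silu_and_mul(out, x)

Since every wrapper only forwards its arguments to vllm._C, existing call sites keep their argument order; the only change for callers in this patch is the import path.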