[Misc] Add indirection layer for custom ops (#3913)
commit e9da5a40c6 (parent e42df7227d)
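This change routes every custom CUDA/HIP kernel call through a single Python indirection layer, vllm._custom_ops, instead of having each caller import the compiled vllm._C extension directly. A minimal sketch of the migration from a caller's perspective (the tensor shapes are illustrative, and a CUDA build of vllm._C is assumed):

    import torch

    # Before this commit, callers imported the compiled extension directly:
    #     from vllm._C import ops, cache_ops
    # After it, everything goes through the pure-Python wrapper module:
    from vllm import _custom_ops as ops

    # silu_and_mul writes silu(x[..., :d]) * x[..., d:] into `out`,
    # so `out` has half the last dimension of `x`.
    x = torch.randn(4, 2048, device="cuda", dtype=torch.float16)
    out = torch.empty(4, 1024, device="cuda", dtype=torch.float16)
    ops.silu_and_mul(out, x)  # dispatches to vllm._C.ops under the hood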
@@ -5,7 +5,7 @@ from typing import Optional
 
 import torch
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
 
 NUM_BLOCKS = 1024
@@ -7,7 +7,7 @@ from allclose_default import get_default_atol, get_default_rtol
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
-from vllm._C import cache_ops, ops
+from vllm import _custom_ops as ops
 from vllm.utils import get_max_shared_memory_bytes, is_hip
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
@@ -237,14 +237,14 @@ def test_paged_attention(
         dequantized_key_cache = torch.empty(size=key_cache_shape,
                                             dtype=dtype,
                                             device=device)
-        cache_ops.convert_fp8(key_cache, dequantized_key_cache)
+        ops.convert_fp8(key_cache, dequantized_key_cache)
         key_cache = dequantized_key_cache
 
         value_cache_shape = value_cache.shape
         dequantized_value_cache = torch.empty(size=value_cache_shape,
                                               dtype=dtype,
                                               device=device)
-        cache_ops.convert_fp8(value_cache, dequantized_value_cache)
+        ops.convert_fp8(value_cache, dequantized_value_cache)
         value_cache = dequantized_value_cache
 
     ref_output = torch.empty_like(query)
@@ -4,7 +4,7 @@ from typing import Tuple
 import pytest
 import torch
 
-from vllm._C import cache_ops
+from vllm import _custom_ops as ops
 from vllm.utils import is_hip
 
 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
@@ -80,7 +80,7 @@ def test_copy_blocks(
     cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
 
     # Call the copy blocks kernel.
-    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+    ops.copy_blocks(key_caches, value_caches, block_mapping)
 
     # Run the reference implementation.
     for src, dsts in block_mapping.items():
@@ -145,9 +145,9 @@ def test_reshape_and_cache(
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
         cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        cache_ops.convert_fp8(key_cache, cloned_key_cache)
+        ops.convert_fp8(key_cache, cloned_key_cache)
         cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        cache_ops.convert_fp8(value_cache, cloned_value_cache)
+        ops.convert_fp8(value_cache, cloned_value_cache)
     else:
         cloned_key_cache = key_cache.clone()
         cloned_value_cache = value_cache.clone()
@@ -156,14 +156,14 @@ def test_reshape_and_cache(
     kv_scale = 1.0
 
     # Call the reshape_and_cache kernel.
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
-                                slot_mapping, kv_cache_dtype, kv_scale)
+    ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
+                          kv_cache_dtype, kv_scale)
 
     if kv_cache_dtype == "fp8":
         result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        cache_ops.convert_fp8(key_cache, result_key_cache)
+        ops.convert_fp8(key_cache, result_key_cache)
         result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        cache_ops.convert_fp8(value_cache, result_value_cache)
+        ops.convert_fp8(value_cache, result_value_cache)
 
     # Run the reference implementation.
     reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
@@ -251,9 +251,8 @@ def test_swap_blocks(
     src_value_caches_clone = src_value_caches[0].clone()
 
     # Call the swap_blocks kernel.
-    cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
-    cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
-                          block_mapping)
+    ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
+    ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
 
     for src, dst in block_mapping.items():
         assert torch.allclose(src_key_caches_clone[src].cpu(),
@@ -291,9 +290,9 @@ def test_fp8_conversion(
     cache.uniform_(low, high)
 
     cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
-    cache_ops.convert_fp8(cache, cache_fp8)
+    ops.convert_fp8(cache, cache_fp8)
 
     converted_cache = torch.empty_like(cache)
-    cache_ops.convert_fp8(cache_fp8, converted_cache)
+    ops.convert_fp8(cache_fp8, converted_cache)
 
     assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
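The test above exercises the fp8 cache conversion as a round trip. A condensed sketch of the same pattern through the new module (device and dtype here are illustrative assumptions):

    import torch
    from vllm import _custom_ops as ops

    cache = torch.empty(1024, device="cuda", dtype=torch.float16).uniform_(-1, 1)
    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)  # fp8 values stored as raw bytes
    ops.convert_fp8(cache, cache_fp8)        # quantize fp16 -> fp8
    roundtrip = torch.empty_like(cache)
    ops.convert_fp8(cache_fp8, roundtrip)    # dequantize fp8 -> fp16
    # fp8 is coarse, hence the loose tolerances used by the test.
    assert torch.allclose(cache, roundtrip, atol=0.001, rtol=0.1)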
vllm/_custom_ops.py (new file, 193 lines)
@@ -0,0 +1,193 @@
+from typing import Dict, Optional
+
+import torch
+
+try:
+    from vllm._C import cache_ops as vllm_cache_ops
+    from vllm._C import ops as vllm_ops
+except ImportError:
+    pass
+
+
+# activation ops
+def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    vllm_ops.silu_and_mul(out, x)
+
+
+def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    vllm_ops.gelu_and_mul(out, x)
+
+
+def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    vllm_ops.gelu_tanh_and_mul(out, x)
+
+
+def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
+    vllm_ops.gelu_fast(out, x)
+
+
+def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
+    vllm_ops.gelu_new(out, x)
+
+
+# page attention ops
+def paged_attention_v1(
+    out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    num_kv_heads: int,
+    scale: float,
+    block_tables: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_size: int,
+    max_context_len: int,
+    alibi_slopes: Optional[torch.Tensor],
+    kv_cache_dtype: str,
+    kv_scale: float,
+) -> None:
+    vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
+                                num_kv_heads, scale, block_tables,
+                                context_lens, block_size, max_context_len,
+                                alibi_slopes, kv_cache_dtype, kv_scale)
+
+
+def paged_attention_v2(
+    out: torch.Tensor,
+    exp_sum: torch.Tensor,
+    max_logits: torch.Tensor,
+    tmp_out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    num_kv_heads: int,
+    scale: float,
+    block_tables: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_size: int,
+    max_context_len: int,
+    alibi_slopes: Optional[torch.Tensor],
+    kv_cache_dtype: str,
+    kv_scale: float,
+) -> None:
+    vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
+                                key_cache, value_cache, num_kv_heads, scale,
+                                block_tables, context_lens, block_size,
+                                max_context_len, alibi_slopes, kv_cache_dtype,
+                                kv_scale)
+
+
+# pos encoding ops
+def rotary_embedding(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    head_size: int,
+    cos_sin_cache: torch.Tensor,
+    is_neox: bool,
+) -> None:
+    vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
+                              is_neox)
+
+
+def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
+                             key: torch.Tensor, head_size: int,
+                             cos_sin_cache: torch.Tensor, is_neox: bool,
+                             rot_dim: int,
+                             cos_sin_cache_offsets: torch.Tensor) -> None:
+    vllm_ops.batched_rotary_embedding(positions, query, key, head_size,
+                                      cos_sin_cache, is_neox, rot_dim,
+                                      cos_sin_cache_offsets)
+
+
+# layer norm ops
+def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
+             epsilon: float) -> None:
+    vllm_ops.rms_norm(out, input, weight, epsilon)
+
+
+def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
+                       weight: torch.Tensor, epsilon: float) -> None:
+    vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon)
+
+
+# quantization ops
+# awq
+def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
+                   zeros: torch.Tensor, split_k_iters: int, thx: int,
+                   thy: int) -> torch.Tensor:
+    return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx,
+                                   thy)
+
+
+def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
+             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
+    return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
+
+
+# gptq
+def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
+              b_g_idx: torch.Tensor, use_exllama: bool,
+              bit: int) -> torch.Tensor:
+    return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
+                              b_g_idx, use_exllama, bit)
+
+
+def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
+                 bit: int) -> None:
+    vllm_ops.gptq_shuffle(q_weight, q_perm, bit)
+
+
+# squeezellm
+def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
+                    lookup_table: torch.Tensor) -> None:
+    vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table)
+
+
+# marlin
+def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
+                size_n: int, size_k: int) -> torch.Tensor:
+    return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
+                                size_n, size_k)
+
+
+# moe
+def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
+                         block_size: int, sorted_token_ids: torch.Tensor,
+                         experts_ids: torch.Tensor,
+                         num_tokens_post_pad: torch.Tensor) -> None:
+    vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size,
+                                  sorted_token_ids, experts_ids,
+                                  num_tokens_post_pad)
+
+
+def reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    kv_scale: float,
+) -> None:
+    vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
+                                     slot_mapping, kv_cache_dtype, kv_scale)
+
+
+def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
+                block_mapping: torch.Tensor) -> None:
+    vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
+                block_mapping: Dict[int, int]) -> None:
+    vllm_cache_ops.swap_blocks(src, dst, block_mapping)
+
+
+def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None:
+    vllm_cache_ops.convert_fp8(output, input)
+
+
+#TODO: cuda_utils, custom_ar
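Note the ImportError guard at the top of the new file: importing vllm._custom_ops no longer fails outright on a build without the compiled extension. The failure is deferred to the first wrapper call, which raises NameError because vllm_ops/vllm_cache_ops were never bound. A minimal sketch of that behavior, assuming vllm._C is absent (purely illustrative):

    import torch
    from vllm import _custom_ops as ops  # succeeds even without vllm._C

    try:
        x = torch.randn(2, 16)
        out = torch.empty_like(x)  # gelu_fast is elementwise, same shape
        ops.gelu_fast(out, x)
    except NameError:
        # vllm._C failed to import, so the vllm_ops name was never defined.
        print("compiled custom-op extension unavailable")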
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
 
 import torch
 
-from vllm._C import cache_ops, ops
+from vllm import _custom_ops as ops
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
 
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -69,7 +69,7 @@ class PagedAttention:
         kv_cache_dtype: str,
         kv_scale: float,
     ) -> None:
-        cache_ops.reshape_and_cache(
+        ops.reshape_and_cache(
             key,
             value,
             key_cache,
@@ -199,11 +199,11 @@ class PagedAttention:
     ) -> None:
         src_key_cache = src_kv_cache[0]
         dst_key_cache = dst_kv_cache[0]
-        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
 
         src_value_cache = src_kv_cache[1]
         dst_value_cache = dst_kv_cache[1]
-        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
@@ -212,4 +212,4 @@ class PagedAttention:
     ) -> None:
         key_caches = [kv_cache[0] for kv_cache in kv_caches]
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
-        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -8,7 +8,7 @@ import torch
 import triton
 import triton.language as tl
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.utils import is_hip
 
@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 
 
 class RMSNorm(nn.Module):
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -279,10 +279,10 @@ def _generate_random_fp8(
     #-----|-------------|-------------------
     # Inf | N/A         | s.11111.00
     # NaN | s.1111.111  | s.11111.{01,10,11}
-    from vllm._C import cache_ops
+    from vllm import _custom_ops as ops
     tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
     tensor_tmp.uniform_(low, high)
-    cache_ops.convert_fp8(tensor_tmp, tensor)
+    ops.convert_fp8(tensor_tmp, tensor)
     del tensor_tmp