[Performance][ROCm] Add skinny gemms for unquantized linear on ROCm (#15830)
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>

parent b9b4746950
commit 188b7f9b8c
CMakeLists.txt
@@ -678,6 +678,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
  #
  set(VLLM_ROCM_EXT_SRC
    "csrc/rocm/torch_bindings.cpp"
+   "csrc/rocm/skinny_gemms.cu"
    "csrc/rocm/attention.cu")

  define_gpu_extension_target(
csrc/rocm/ops.h
@@ -2,6 +2,15 @@

#include <torch/all.h>

+torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
+                    const int64_t rows_per_block);
+
+torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
+                       const int64_t CuCount);
+
+void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
+               at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
+
void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
                     torch::Tensor& max_logits, torch::Tensor& tmp_out,
                     torch::Tensor& query, torch::Tensor& key_cache,
csrc/rocm/skinny_gemms.cu (new file, 1600 lines)
File diff suppressed because it is too large.
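The kernel source itself is suppressed above, so as a rough orientation only: judging from the declarations in csrc/rocm/ops.h and the tests added below, all three ops take the weight matrix as the first operand and a skinny (few-row) activation as the second, and are expected to match a plain x @ W^T GEMM within about 1% relative tolerance. A minimal reference sketch of that contract (inferred from the tests, not from the kernel code):

import torch

def skinny_gemm_reference(weight: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # What LLMM1(weight, x, rows_per_block) and wvSplitK(weight, x, cu_count)
    # are checked against in the tests: x is (n, k) with small n, weight is (m, k).
    return torch.matmul(x, weight.t())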
csrc/rocm/torch_bindings.cpp
@@ -14,6 +14,24 @@
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
  // vLLM custom ops for rocm

+  // Custom gemm op for matrix-vector multiplication
+  rocm_ops.def(
+      "LLMM1(Tensor in_a, Tensor in_b, int rows_per_block) -> "
+      "Tensor");
+  rocm_ops.impl("LLMM1", torch::kCUDA, &LLMM1);
+
+  // Custom gemm op for skinny matrix-matrix multiplication
+  rocm_ops.def(
+      "wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> "
+      "Tensor");
+  rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);
+
+  // wvSplitK for fp8
+  rocm_ops.def(
+      "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, "
+      " Tensor scale_b, int CuCount) -> ()");
+  rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);
+
  // Custom attention op
  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention.
tests/kernels/test_rocm_skinny_gemms.py (new file, 80 lines)
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform

DTYPES = [torch.bfloat16, torch.float16]
M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]  # k % 8 == 0
N = [1, 2, 3, 4]
SEEDS = [0]


@pytest.mark.parametrize("n", [1])  # only test for batch size 1
@pytest.mark.parametrize("k", K)
@pytest.mark.parametrize("m", M)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="only test for rocm")
@torch.inference_mode()
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
    torch.manual_seed(seed)
    A = torch.rand(n, k, dtype=dtype, device="cuda")
    B = torch.rand(m, k, dtype=dtype, device="cuda")

    ref_out = torch.matmul(A, B.t())
    out = ops.LLMM1(B, A, rows_per_block)

    assert torch.allclose(out, ref_out, rtol=0.01)


@pytest.mark.parametrize("n", N)  # only test for batch size <= 4
@pytest.mark.parametrize("k", K + [9216, 10240, 16384])
@pytest.mark.parametrize("m", [8] + M)  # m >= 8
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
    torch.manual_seed(seed)
    cu_count = current_platform.get_cu_count()

    A = torch.rand(n, k, dtype=dtype, device="cuda")
    B = torch.rand(m, k, dtype=dtype, device="cuda")

    ref_out = torch.matmul(A, B.t())
    out = ops.wvSplitK(B, A, cu_count)

    assert torch.allclose(out, ref_out, rtol=0.01)


@pytest.mark.parametrize("n", N)  # only test for batch size <= 4
@pytest.mark.parametrize("k", K[1:] + [14336, 24576, 32768])  # k % 16 == 0
@pytest.mark.parametrize("m", M + [28672])  # m >= 16
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="only test for rocm")
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
    torch.manual_seed(seed)

    A = torch.rand(n, k, device="cuda")
    B = torch.rand(m, k, device="cuda")

    A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
    B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)

    ref_out = torch._scaled_mm(A,
                               B.t(),
                               out_dtype=dtype,
                               scale_a=scale_a,
                               scale_b=scale_b)
    out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
                        current_platform.get_cu_count())

    assert torch.allclose(out, ref_out, rtol=0.01)
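These tests only exercise anything on a ROCm build of vLLM; the skipif markers above skip every case elsewhere. One way to run just this file, assuming such a build is installed:

import pytest

# Equivalent to running `pytest -v tests/kernels/test_rocm_skinny_gemms.py`
# from the repository root.
pytest.main(["-v", "tests/kernels/test_rocm_skinny_gemms.py"])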
vllm/_custom_ops.py
@@ -1196,6 +1196,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                                      ssm_states, pad_slot_id)


+# ROCm skinny gemms
+def LLMM1(a: torch.Tensor, b: torch.Tensor,
+          rows_per_block: int) -> torch.Tensor:
+    return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)
+
+
+def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor:
+    return torch.ops._rocm_C.wvSplitK(a, b, cu_count)
+
+
+def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype,
+              scale_a: torch.Tensor, scale_b: torch.Tensor,
+              cu_count: int) -> torch.Tensor:
+    out = torch.empty((b.shape[0], a.shape[0]),
+                      dtype=out_dtype,
+                      device=b.device)
+    torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count)
+    return out
+
+
# moe
def moe_sum(input: torch.Tensor, output: torch.Tensor):
    torch.ops._moe_C.moe_sum(input, output)
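Note the argument order these wrappers (and the tests above) use: the weight matrix is the first operand and the skinny activation the second, so the result corresponds to activation @ weight.T. A small illustrative sketch of calling them on a ROCm device (the shapes here are made up for the example, not taken from the diff):

import torch
import vllm._custom_ops as ops
from vllm.platforms import current_platform

x = torch.rand(1, 4096, dtype=torch.float16, device="cuda")      # activation, (n=1, k)
w = torch.rand(11008, 4096, dtype=torch.float16, device="cuda")  # weight, (m, k)

y1 = ops.LLMM1(w, x, 4)                                   # ~ x @ w.t(), rows_per_block=4
y2 = ops.wvSplitK(w, x, current_platform.get_cu_count())  # same contract, split-K style kernel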
vllm/envs.py
@@ -78,6 +78,7 @@ if TYPE_CHECKING:
    VLLM_ROCM_USE_AITER_LINEAR: bool = True
    VLLM_ROCM_USE_AITER_MOE: bool = True
    VLLM_ROCM_USE_AITER_RMSNORM: bool = True
+   VLLM_ROCM_USE_SKINNY_GEMM: bool = True
    VLLM_ROCM_FP8_PADDING: bool = True
    VLLM_ROCM_MOE_PADDING: bool = True
    VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
@@ -550,6 +551,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
    lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
             ("true", "1")),

+   # use rocm skinny gemms
+   "VLLM_ROCM_USE_SKINNY_GEMM":
+   lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
+            ("true", "1")),
+
    # Pad the fp8 weights to 256 bytes for ROCm
    "VLLM_ROCM_FP8_PADDING":
    lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
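As the parsing lambda above shows, the skinny-GEMM path is enabled by default and the value is read case-insensitively; anything other than "true"/"1" turns it off. A quick sketch of disabling it from Python (assuming, as the table above suggests, that vllm.envs resolves these lazily from the process environment):

import os

# "0" or "false" (any case) both disable the new path, per the lambda above.
os.environ["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"

import vllm.envs as envs
assert envs.VLLM_ROCM_USE_SKINNY_GEMM is False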
vllm/model_executor/layers/linear.py
@@ -6,7 +6,6 @@ from typing import Any, Literal, Optional, Union

import torch
import torch.nn as nn
-import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -17,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           BlockQuantScaleParameter,
@@ -188,7 +188,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

-        return F.linear(x, layer.weight, bias)
+        return dispatch_unquantized_gemm()(x, layer.weight, bias)


class LinearBase(torch.nn.Module):
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

import torch

from vllm import _custom_ops as ops
+from vllm import envs
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform

@@ -17,6 +18,7 @@ TORCH_DEVICE_IDENTITY = None
# The condition is determined once as the operations
# are time consuming.
USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm()
                               and torch.__version__[0:3] >= "2.7"
                               and current_platform.has_device_capability(94))
@@ -131,6 +133,159 @@ def maybe_create_device_identity():
        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)


def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
                           out_dtype: torch.dtype, scale_a: torch.Tensor,
                           scale_b: torch.Tensor, bias: torch.Tensor,
                           output_shape: List, **kwargs) -> torch.Tensor:

    # Fused GEMM_DQ
    output = ops.cutlass_scaled_mm(qinput,
                                   weight,
                                   out_dtype=out_dtype,
                                   scale_a=scale_a,
                                   scale_b=scale_b,
                                   bias=bias)
    return output.view(*output_shape)


def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                   weight: torch.Tensor,
                                   out_dtype: torch.dtype,
                                   scale_a: torch.Tensor,
                                   scale_b: torch.Tensor, bias: torch.Tensor,
                                   input_2d: torch.Tensor,
                                   output_shape: List) -> torch.Tensor:
    if envs.VLLM_ROCM_USE_SKINNY_GEMM and qinput.shape[
            0] == 1 and qinput.shape[1] % 16 == 0:
        output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
                               current_platform.get_cu_count())
    else:
        output = torch._scaled_mm(qinput,
                                  weight,
                                  out_dtype=out_dtype,
                                  scale_a=scale_a,
                                  scale_b=scale_b,
                                  bias=bias)

    return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)


def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                    weight: torch.Tensor,
                                    out_dtype: torch.dtype,
                                    scale_a: torch.Tensor,
                                    scale_b: torch.Tensor, bias: torch.Tensor,
                                    input_2d: torch.Tensor,
                                    output_shape: List) -> torch.Tensor:
    output = torch._scaled_mm(qinput,
                              weight,
                              out_dtype=out_dtype,
                              scale_a=scale_a,
                              scale_b=scale_b,
                              bias=bias)
    # A fix for discrepancy in scaled_mm which returns tuple
    # for torch < 2.5 and a single value in torch >= 2.5
    if type(output) is tuple and len(output) == 2:
        output = output[0]

    return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)


def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                   weight: torch.Tensor,
                                   out_dtype: torch.dtype,
                                   scale_a: torch.Tensor,
                                   scale_b: torch.Tensor, bias: torch.Tensor,
                                   input_2d: torch.Tensor,
                                   output_shape: List) -> torch.Tensor:
    # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
    # when using it.
    # For now it has only been validated on ROCm platform.
    # fp8 rowwise scaling in torch._scaled_mm is introduced in
    # https://github.com/pytorch/pytorch/pull/144432 using
    # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
    #
    # For CUDA platform please validate if the torch._scaled_mm supports
    # rowwise scaled GEMM before using it

    # Fused GEMM_DQ Rowwise GEMM
    output = torch._scaled_mm(qinput,
                              weight,
                              out_dtype=out_dtype,
                              scale_a=scale_a,
                              scale_b=scale_b.t(),
                              bias=bias)

    output = torch.narrow(output, 0, 0, input_2d.shape[0])
    output = output.view(*output_shape)
    return output


def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                     weight: torch.Tensor,
                                     out_dtype: torch.dtype,
                                     scale_a: torch.Tensor,
                                     scale_b: torch.Tensor, bias: torch.Tensor,
                                     input_2d: torch.Tensor,
                                     output_shape: List,
                                     **kwargs) -> torch.Tensor:
    # Use unfused DQ due to limitations with scaled_mm

    # Symmetric quantized GEMM by definition computes the following:
    # C = (s_x * X) (s_w * W) + bias
    # This is equivalent to dequantizing the weights and activations
    # before applying a GEMM.
    #
    # In order to compute quantized operands, a quantized kernel
    # will rewrite the above like so:
    # C = s_w * s_x * (X * W) + bias
    #
    # For the scaled_mm fallback case, we break this down, since it
    # does not support s_w being a vector.

    # GEMM
    # This computes C = (X * W).
    # Output in fp32 to allow subsequent ops to happen in-place
    output = torch._scaled_mm(qinput,
                              weight,
                              scale_a=TORCH_DEVICE_IDENTITY,
                              scale_b=TORCH_DEVICE_IDENTITY,
                              out_dtype=torch.float32)
    # A fix for discrepancy in scaled_mm which returns tuple
    # for torch < 2.5 and a single value in torch >= 2.5
    if type(output) is tuple and len(output) == 2:
        output = output[0]
    # Unpad (undo num_token_padding)
    output = torch.narrow(output, 0, 0, input_2d.shape[0])
    x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])

    # DQ
    # C = sw * sx * (X * W) + bias
    output = output * x_scale * scale_b.t()
    if bias is not None:
        output = output + bias
    return output.to(out_dtype).view(*output_shape)


def dispatch_w8a8_scaled_mm(
        cutlass_fp8_supported: bool, per_tensor_weights: bool,
        per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool]
) -> Callable[..., torch.Tensor]:

    if cutlass_fp8_supported:
        return cutlass_w8a8_scaled_mm
    if per_tensor_weights and per_tensor_activations:
        if current_platform.is_rocm():
            return rocm_per_tensor_w8a8_scaled_mm
        return torch_per_tensor_w8a8_scaled_mm
    # torch.scaled_mm supports per tensor weights + activations only
    # so fallback to naive if per channel or per token
    if (use_per_token_if_dynamic and not per_tensor_weights
            and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM):
        return torch_per_token_w8a8_scaled_mm
    return torch_channelwise_w8a8_scaled_mm


# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
class Fp8LinearOp:
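The dispatch helper above replaces the in-line branching that is deleted from Fp8LinearOp.apply below: the scaled-mm implementation is resolved once from the scale shapes and the platform, then invoked with keyword arguments. A hedged sketch of how it resolves, based only on the logic shown above (module path assumed to be vLLM's usual w8a8_utils location):

from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    dispatch_w8a8_scaled_mm, rocm_per_tensor_w8a8_scaled_mm,
    torch_channelwise_w8a8_scaled_mm)
from vllm.platforms import current_platform

# Per-tensor weight and activation scales without CUTLASS: on ROCm this picks
# the wrapper that can route single-token inputs to ops.wvSplitKQ.
func = dispatch_w8a8_scaled_mm(cutlass_fp8_supported=False,
                               per_tensor_weights=True,
                               per_tensor_activations=True,
                               use_per_token_if_dynamic=None)
assert (func is rocm_per_tensor_w8a8_scaled_mm) == current_platform.is_rocm()

# Channelwise scales without rowwise torch._scaled_mm support fall back to the
# unfused dequant path.
func = dispatch_w8a8_scaled_mm(False, False, False, use_per_token_if_dynamic=False)
assert func is torch_channelwise_w8a8_scaled_mm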
@@ -156,7 +311,8 @@
        if pad_output is None:
            config = get_current_vllm_config().compilation_config
            pad_output = config.level < CompilationLevel.PIECEWISE
-        self.output_padding = 17 if pad_output else None
+        self.output_padding = 17 if (
+            pad_output and not current_platform.is_rocm()) else None

    def apply(
        self,
@@ -195,18 +351,6 @@
                input_scale,
                scale_ub=input_scale_ub,
                use_per_token_if_dynamic=use_per_token_if_dynamic)

-            # Fused GEMM_DQ
-            output = ops.cutlass_scaled_mm(qinput,
-                                           weight,
-                                           out_dtype=out_dtype,
-                                           scale_a=x_scale,
-                                           scale_b=weight_scale,
-                                           bias=bias)
-            return output.view(*output_shape)
-
-        # torch.scaled_mm supports per tensor weights + activations only
-        # so fallback to naive if per channel or per token
        else:
            if input.dtype != current_platform.fp8_dtype():
                # Maybe apply padding to output, see comment in __init__
@@ -218,84 +362,21 @@
            else:
                qinput, x_scale = input_2d, input_scale

-            per_tensor_weights = (weight_scale.numel() == 1)
-            per_tensor_activations = (x_scale.numel() == 1)
-
-            if per_tensor_weights and per_tensor_activations:
-                # Fused GEMM_DQ
-                output = torch._scaled_mm(qinput,
-                                          weight,
-                                          out_dtype=out_dtype,
-                                          scale_a=x_scale,
-                                          scale_b=weight_scale,
-                                          bias=bias)
-                # A fix for discrepancy in scaled_mm which returns tuple
-                # for torch < 2.5 and a single value in torch >= 2.5
-                if type(output) is tuple and len(output) == 2:
-                    output = output[0]
-
-                return torch.narrow(output, 0, 0,
-                                    input_2d.shape[0]).view(*output_shape)
-
-            elif (use_per_token_if_dynamic and not per_tensor_weights
-                  and not per_tensor_activations
-                  and USE_ROWWISE_TORCH_SCALED_MM):
-                # For now validated on ROCm platform
-                # fp8 rowwise scaling in torch._scaled_mm is introduced in
-                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
-                # and ROCm 6.3, which only exists in torch 2.7 and above.
-                # For CUDA platform please validate if the
-                # torch._scaled_mm support rowwise scaled GEMM
-                # Fused GEMM_DQ Rowwise GEMM
-                output = torch._scaled_mm(qinput,
-                                          weight,
-                                          out_dtype=out_dtype,
-                                          scale_a=x_scale,
-                                          scale_b=weight_scale.t(),
-                                          bias=bias)
-
-                output = torch.narrow(output, 0, 0, input_2d.shape[0])
-                output = output.view(*output_shape)
-                return output
-
-            else:
-                # Fallback for channelwise case, where we use unfused DQ
-                # due to limitations with scaled_mm
-
-                # Symmetric quantized GEMM by definition computes the following:
-                # C = (s_x * X) (s_w * W) + bias
-                # This is equivalent to dequantizing the weights and activations
-                # before applying a GEMM.
-                #
-                # In order to compute quantized operands, a quantized kernel
-                # will rewrite the above like so:
-                # C = s_w * s_x * (X * W) + bias
-                #
-                # For the scaled_mm fallback case, we break this down, since it
-                # does not support s_w being a vector.
-
-                # GEMM
-                # This computes C = (X * W).
-                # Output in fp32 to allow subsequent ops to happen in-place
-                output = torch._scaled_mm(qinput,
-                                          weight,
-                                          scale_a=TORCH_DEVICE_IDENTITY,
-                                          scale_b=TORCH_DEVICE_IDENTITY,
-                                          out_dtype=torch.float32)
-                # A fix for discrepancy in scaled_mm which returns tuple
-                # for torch < 2.5 and a single value in torch >= 2.5
-                if type(output) is tuple and len(output) == 2:
-                    output = output[0]
-                # Unpad (undo num_token_padding)
-                output = torch.narrow(output, 0, 0, input_2d.shape[0])
-                x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
-
-                # DQ
-                # C = sw * sx * (X * W) + bias
-                output = output * x_scale * weight_scale.t()
-                if bias is not None:
-                    output = output + bias
-                return output.to(dtype=input.dtype).view(*output_shape)
+        per_tensor_weights = (weight_scale.numel() == 1)
+        per_tensor_activations = (x_scale.numel() == 1)
+
+        w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
+            self.cutlass_fp8_supported, per_tensor_weights,
+            per_tensor_activations, use_per_token_if_dynamic)
+
+        return w8a8_scaled_mm_func(qinput=qinput,
+                                   weight=weight,
+                                   out_dtype=input.dtype,
+                                   scale_a=x_scale,
+                                   scale_b=weight_scale,
+                                   bias=bias,
+                                   input_2d=input_2d,
+                                   output_shape=output_shape)


def normalize_e4m3fn_to_e4m3fnuz(
vllm/model_executor/layers/utils.py
@@ -1,9 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
"""Utility methods for model layers."""
-from typing import Tuple
+from typing import Callable, Optional, Tuple

import torch

+from vllm import _custom_ops as ops
+from vllm import envs
+from vllm.platforms import current_platform


def get_token_bin_counts_and_mask(
        tokens: torch.Tensor,
@@ -61,3 +65,34 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits


+def rocm_unquantized_gemm(x: torch.Tensor,
+                          weight: torch.Tensor,
+                          bias: Optional[torch.Tensor] = None):
+    k = weight.shape[1]
+    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and \
+                  x.dtype in [torch.float16, torch.bfloat16] \
+                  and k % 8 == 0 and bias is None)
+
+    if use_skinny is not True:
+        return torch.nn.functional.linear(x, weight, bias)
+
+    x_view = x.view(-1, x.size(-1))
+    n = x_view.shape[0]
+    m = weight.shape[0]
+    cu_count = current_platform.get_cu_count()
+
+    if m > 8 and n < 4:
+        out = ops.wvSplitK(weight, x_view, cu_count)
+        return out.view(*x.shape[:-1], weight.shape[0])
+    elif m % 4 == 0 and n == 1 and k <= 8192:
+        out = ops.LLMM1(weight, x_view, 4)
+        return out.view(*x.shape[:-1], weight.shape[0])
+    return torch.nn.functional.linear(x, weight, bias)
+
+
+def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
+    if current_platform.is_rocm():
+        return rocm_unquantized_gemm
+    return torch.nn.functional.linear
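This is the hook UnquantizedLinearMethod.apply now calls (see the linear.py hunk above): dispatch_unquantized_gemm() returns rocm_unquantized_gemm on ROCm and plain F.linear elsewhere, and the ROCm path itself falls back to F.linear whenever the skinny-GEMM preconditions (fp16/bf16, k % 8 == 0, no bias, small batch) do not hold. A small illustrative sketch, assuming vllm with this change is importable (the shapes below are made up for the example):

import torch
from vllm.model_executor.layers.utils import (dispatch_unquantized_gemm,
                                              rocm_unquantized_gemm)
from vllm.platforms import current_platform

gemm = dispatch_unquantized_gemm()
if current_platform.is_rocm():
    assert gemm is rocm_unquantized_gemm
else:
    assert gemm is torch.nn.functional.linear

# On ROCm, a decode-style call (batch of 1, k divisible by 8, no bias) is the
# shape regime the new kernels target; anything else silently uses F.linear.
if current_platform.is_rocm() and torch.cuda.is_available():
    x = torch.rand(1, 4096, dtype=torch.float16, device="cuda")
    w = torch.rand(11008, 4096, dtype=torch.float16, device="cuda")
    y = gemm(x, w, None)
    assert y.shape == (1, 11008)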
vllm/platforms/interface.py
@@ -413,6 +413,13 @@ class Platform:
                           self.device_name, key)
        return None

+    @classmethod
+    def get_cu_count(cls, device_id: int = 0) -> int:
+        """
+        Returns the total number of compute units (CU) on a single GPU.
+        """
+        raise NotImplementedError
+

class UnspecifiedPlatform(Platform):
    _enum = PlatformEnum.UNSPECIFIED
vllm/platforms/rocm.py
@@ -312,3 +312,8 @@ class RocmPlatform(Platform):
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ['gfx94']
        return any(gfx in gcn_arch for gfx in supported_archs)

+    @classmethod
+    def get_cu_count(cls, device_id: int = 0) -> int:
+        return torch.cuda.get_device_properties(
+            device_id).multi_processor_count