mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-04 19:17:04 +08:00
Merge d3fc0729f78dc982ab7006ceabd2e36c935cb193 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
b2120877bc
@ -26,14 +26,13 @@ from vllm.distributed.parallel_state import (
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
GroupShape,
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
|
||||
from ...utils import has_module_attribute, multi_gpu_test
|
||||
from ...utils import TestFP8Layer, has_module_attribute, multi_gpu_test
|
||||
from ..backend import TestBackend
|
||||
|
||||
|
||||
@ -75,25 +74,32 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
self.w = [
|
||||
self.input_scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
self.weight = [
|
||||
torch.rand(hidden_size, hidden_size)
|
||||
.to(dtype=current_platform.fp8_dtype())
|
||||
.t()
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
)
|
||||
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
self.fp8_linear_layers = [
|
||||
TestFP8Layer(
|
||||
self.quant_key,
|
||||
self.quant_key,
|
||||
self.weight[i],
|
||||
self.wscale[i],
|
||||
input_scale=self.input_scale[i],
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
@ -101,23 +107,18 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
z2 = self.fp8_linear.apply(
|
||||
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
|
||||
)
|
||||
z2 = self.fp8_linear_layers[0](y)
|
||||
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
z3 = self.fp8_linear.apply(
|
||||
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||
)
|
||||
z3 = self.fp8_linear_layers[1](y2)
|
||||
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
z4 = self.fp8_linear.apply(
|
||||
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
|
||||
)
|
||||
z4 = self.fp8_linear_layers[2](y3)
|
||||
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
@ -129,7 +130,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
return [
|
||||
torch.ops.vllm.all_reduce.default,
|
||||
torch.ops._C.static_scaled_fp8_quant.default
|
||||
if self.fp8_linear.quant_fp8.enabled()
|
||||
if self.fp8_linear_layers[0].is_quant_fp8_enabled()
|
||||
else torch.ops.aten.reciprocal.default,
|
||||
]
|
||||
|
||||
|
||||
@ -27,12 +27,13 @@ from vllm.distributed.parallel_state import (
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from ...utils import TestFP8Layer, multi_gpu_test
|
||||
from ..backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@ -93,6 +94,8 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, hidden_size=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
@ -106,37 +109,32 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
.t()
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
)
|
||||
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
|
||||
self.fp8_linear_layers = [
|
||||
TestFP8Layer(
|
||||
self.quant_key, self.quant_key, self.w[i], self.wscale[i], self.scale[i]
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
z = torch.relu(hidden_states)
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
z2 = self.fp8_linear.apply(
|
||||
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
|
||||
)
|
||||
z2 = self.fp8_linear_layers[0](y)
|
||||
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
z3 = self.fp8_linear.apply(
|
||||
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||
)
|
||||
z3 = self.fp8_linear_layers[1](y2)
|
||||
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
z4 = self.fp8_linear.apply(
|
||||
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
|
||||
)
|
||||
z4 = self.fp8_linear_layers[2](y3)
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
@ -159,7 +157,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
return [
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
]
|
||||
elif self.fp8_linear.quant_fp8.enabled():
|
||||
elif any(layer.is_quant_fp8_enabled() for layer in self.fp8_linear_layers):
|
||||
return [
|
||||
torch.ops._C.static_scaled_fp8_quant.default,
|
||||
]
|
||||
|
||||
@ -20,11 +20,13 @@ from vllm.config import (
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import TestFP8Layer
|
||||
from .backend import TestBackend
|
||||
|
||||
TEST_FP8 = current_platform.supports_fp8()
|
||||
@ -32,24 +34,27 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class TestSiluMul(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, hidden_size: int = 128):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
self.weight_scale = torch.rand(1, dtype=torch.float32)
|
||||
self.input_scale = torch.rand(1, dtype=torch.float32)
|
||||
if TEST_FP8:
|
||||
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
self.fp8_linear = TestFP8Layer(
|
||||
self.quant_key,
|
||||
self.quant_key,
|
||||
self.weight,
|
||||
self.weight_scale,
|
||||
self.input_scale,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
if TEST_FP8:
|
||||
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
|
||||
return x2
|
||||
return self.fp8_linear(y)
|
||||
else:
|
||||
return y
|
||||
|
||||
@ -67,6 +72,8 @@ class TestSiluMul(torch.nn.Module):
|
||||
|
||||
|
||||
class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -81,11 +88,18 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
if TEST_FP8:
|
||||
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
|
||||
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
self.weight = (
|
||||
torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||
)
|
||||
self.weight_scale = torch.rand(1, dtype=torch.float32)
|
||||
self.input_scale = torch.rand(1, dtype=torch.float32)
|
||||
self.fp8_linear = TestFP8Layer(
|
||||
self.quant_key,
|
||||
self.quant_key,
|
||||
self.weight,
|
||||
self.weight_scale,
|
||||
self.input_scale,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
# Reshape input
|
||||
@ -99,13 +113,9 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
norm_output, residual_output = self.norm(mm, residual)
|
||||
|
||||
if TEST_FP8:
|
||||
self.input_scale = self.input_scale.to(norm_output.device)
|
||||
# scaled_mm with static input quantization
|
||||
fp8_linear_result = self.fp8_linear.apply(
|
||||
norm_output,
|
||||
self.w,
|
||||
self.wscale,
|
||||
input_scale=self.scale.to(norm_output.device),
|
||||
)
|
||||
fp8_linear_result = self.fp8_linear(norm_output)
|
||||
|
||||
return fp8_linear_result, residual_output
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.config
|
||||
import vllm.plugins
|
||||
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
|
||||
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
|
||||
@ -20,8 +21,22 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
|
||||
CutlassFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
|
||||
FlashInferScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
|
||||
ChannelWiseTorchScaledMMLinearKernel,
|
||||
PerTensorTorchScaledMMLinearKernel,
|
||||
RowWiseTorchScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
|
||||
ROCmScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
FP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
@ -29,15 +44,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
ScaleDesc,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
cutlass_block_fp8_supported,
|
||||
cutlass_fp8_supported,
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||
from vllm.utils.deep_gemm import (
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
|
||||
from ..utils import override_cutlass_fp8_supported
|
||||
from ..utils import TestBlockFP8Layer, TestFP8Layer
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@ -45,157 +59,260 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
# Kernel and group_shape combinations: (kernel, group_shape)
|
||||
# CUDA kernels
|
||||
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
|
||||
# FlashInferScaledMMLinearKernel supports both per-tensor and per-token
|
||||
(FlashInferScaledMMLinearKernel, GroupShape.PER_TOKEN),
|
||||
(FlashInferScaledMMLinearKernel, GroupShape.PER_TENSOR),
|
||||
# CutlassFP8ScaledMMLinearKernel supports both per-tensor and per-token
|
||||
(CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
|
||||
(CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
|
||||
# PerTensorTorchScaledMMLinearKernel only supports per-tensor
|
||||
(PerTensorTorchScaledMMLinearKernel, GroupShape.PER_TENSOR),
|
||||
# ChannelWiseTorchScaledMMLinearKernel only supports per-token
|
||||
(ChannelWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN),
|
||||
# Blockwise group shapes (no kernel abstraction)
|
||||
(None, GroupShape(1, 128)),
|
||||
(None, GroupShape(1, 64)),
|
||||
]
|
||||
|
||||
# ROCm kernels
|
||||
ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
|
||||
# ROCmScaledMMLinearKernel supports both per-tensor and per-token
|
||||
(ROCmScaledMMLinearKernel, GroupShape.PER_TOKEN),
|
||||
(ROCmScaledMMLinearKernel, GroupShape.PER_TENSOR),
|
||||
# RowWiseTorchScaledMMLinearKernel only supports per-token
|
||||
(RowWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN),
|
||||
# ChannelWiseTorchScaledMMLinearKernel only supports per-token
|
||||
(ChannelWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN),
|
||||
# Blockwise group shapes (no kernel abstraction)
|
||||
(None, GroupShape(1, 128)),
|
||||
(None, GroupShape(1, 64)),
|
||||
]
|
||||
|
||||
KERNEL_GROUPSHAPE_COMBINATIONS = (
|
||||
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
|
||||
if current_platform.is_cuda()
|
||||
else ROCM_KERNEL_GROUPSHAPE_COMBINATIONS
|
||||
)
|
||||
|
||||
# For Aiter tests we toggle use_aiter_quant_op
|
||||
AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
|
||||
# Per-token with ROCmScaledMMLinearKernel
|
||||
(ROCmScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
|
||||
(ROCmScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
|
||||
# Per-token with RowWiseTorchScaledMMLinearKernel
|
||||
(RowWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
|
||||
(RowWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
|
||||
# Per-token with ChannelWiseTorchScaledMMLinearKernel
|
||||
(ChannelWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
|
||||
(ChannelWiseTorchScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
|
||||
# Blockwise (no kernel abstraction)
|
||||
(None, GroupShape(1, 128), True),
|
||||
]
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float,
|
||||
force_kernel: FP8ScaledMMLinearKernel | None,
|
||||
group_shape: GroupShape,
|
||||
use_aiter: bool = False,
|
||||
cuda_force_torch: bool = False,
|
||||
use_aiter_quant_op: bool = True,
|
||||
use_aiter_fusion: bool = False,
|
||||
use_aiter_quant: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.use_aiter = use_aiter
|
||||
self.use_aiter_quant_op = use_aiter_quant_op
|
||||
self.cuda_force_torch = cuda_force_torch
|
||||
self.fp8_linear_layers: list[torch.nn.Module]
|
||||
self.group_shape = group_shape
|
||||
self.enable_quant_fp8_custom_op = None # Will be set later if applicable
|
||||
|
||||
self.use_aiter_quant_op = use_aiter_quant
|
||||
self.use_aiter_fusion = use_aiter_fusion
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
|
||||
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
||||
|
||||
# Setup quantization scale descriptor
|
||||
static = group_shape == GroupShape.PER_TENSOR and not use_aiter
|
||||
quant_scale = ScaleDesc(torch.float32, static, group_shape)
|
||||
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
|
||||
# Determine if blockwise based on group_shape
|
||||
is_blockwise = group_shape.is_per_group()
|
||||
|
||||
# Setup scales
|
||||
if static:
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
if is_blockwise:
|
||||
self._init_blockwise(
|
||||
hidden_size, group_shape, use_aiter_fusion, use_aiter_quant
|
||||
)
|
||||
else:
|
||||
self.scale = [None for _ in range(3)]
|
||||
self._init_nonblockwise(
|
||||
hidden_size, group_shape, force_kernel, use_aiter_quant
|
||||
)
|
||||
|
||||
# Setup weights
|
||||
def _init_nonblockwise(
|
||||
self,
|
||||
hidden_size: int,
|
||||
group_shape: GroupShape,
|
||||
force_kernel: FP8ScaledMMLinearKernel | None,
|
||||
use_aiter_quant: bool,
|
||||
):
|
||||
"""Initialize non-blockwise (per-tensor/per-token) FP8 layers."""
|
||||
is_static = group_shape == GroupShape.PER_TENSOR
|
||||
act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
|
||||
w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
|
||||
self.activation_quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
|
||||
)
|
||||
self.weight_quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
|
||||
)
|
||||
|
||||
# Setup weight scales
|
||||
wscale_shape = (1,) if group_shape.is_per_tensor() else (hidden_size, 1)
|
||||
self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)]
|
||||
|
||||
self.act_scale = (
|
||||
[torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
if is_static
|
||||
else [None for _ in range(3)]
|
||||
)
|
||||
|
||||
# Initialize weights (transposed for non-blockwise)
|
||||
self.w = [
|
||||
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
# Setup FP8 linear layers with kernel abstraction
|
||||
self.fp8_linear_layers = [
|
||||
TestFP8Layer(
|
||||
self.activation_quant_key,
|
||||
self.weight_quant_key,
|
||||
self.w[i],
|
||||
self.wscale[i],
|
||||
input_scale=self.act_scale[i],
|
||||
force_kernel=force_kernel,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
# Enable aiter quantization if requested
|
||||
for layer in self.fp8_linear_layers:
|
||||
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
|
||||
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
|
||||
0
|
||||
].is_quant_fp8_enabled()
|
||||
|
||||
def _init_blockwise(
|
||||
self,
|
||||
hidden_size: int,
|
||||
group_shape: GroupShape,
|
||||
use_aiter_fusion: bool,
|
||||
use_aiter_quant: bool,
|
||||
):
|
||||
"""Initialize blockwise FP8 layers."""
|
||||
act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
|
||||
self.activation_quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
|
||||
)
|
||||
|
||||
# Setup weight scales (for blockwise quantization)
|
||||
# Use aiter block size if aiter fusion is enabled
|
||||
scale_size = (
|
||||
(hidden_size + 128 - 1) // 128
|
||||
if use_aiter_fusion
|
||||
else hidden_size // group_shape[1]
|
||||
)
|
||||
wscale_shape = (scale_size, scale_size)
|
||||
self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)]
|
||||
|
||||
# Initialize weights (transposed if using aiter, otherwise not)
|
||||
self.w = [
|
||||
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
|
||||
]
|
||||
if not group_shape.is_per_group() or use_aiter:
|
||||
self.w = [self.w[0].t() for _ in range(3)]
|
||||
if use_aiter_fusion:
|
||||
self.w = [w.t() for w in self.w]
|
||||
|
||||
# Setup weight scales
|
||||
if group_shape.is_per_group():
|
||||
scale_size = (
|
||||
(hidden_size + 128 - 1) // 128
|
||||
if use_aiter
|
||||
else hidden_size // group_shape[1]
|
||||
)
|
||||
wscale_shape: tuple[int, ...] = (scale_size, scale_size)
|
||||
else:
|
||||
wscale_shape = (1,)
|
||||
self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)]
|
||||
|
||||
# Setup FP8 linear operation
|
||||
is_per_group = group_shape.is_per_group()
|
||||
if is_per_group and use_aiter:
|
||||
self.fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(128, 128),
|
||||
act_quant_group_shape=group_shape,
|
||||
use_aiter_and_is_supported=use_aiter_quant_op,
|
||||
)
|
||||
# AITER blockwise doesn't use enable_quant_fp8_custom_op
|
||||
elif is_per_group:
|
||||
self.fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
|
||||
act_quant_group_shape=group_shape,
|
||||
self.fp8_linear_layers = [
|
||||
TestBlockFP8Layer(
|
||||
group_shape=group_shape,
|
||||
weight=self.w[i],
|
||||
weight_scale=self.wscale[i],
|
||||
input_scale=None, # Dynamic quantization for blockwise
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
|
||||
use_aiter_and_is_supported=False,
|
||||
use_aiter_and_is_supported=use_aiter_quant,
|
||||
)
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
|
||||
elif use_aiter:
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=False,
|
||||
act_quant_group_shape=group_shape,
|
||||
)
|
||||
self.fp8_linear.quant_fp8.use_aiter = use_aiter_quant_op
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
||||
else:
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=static,
|
||||
act_quant_group_shape=group_shape,
|
||||
)
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
||||
self.enable_quant_fp8_custom_op = (
|
||||
False
|
||||
if use_aiter_quant
|
||||
else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
x = resid = torch.relu(x)
|
||||
y = self.norm[0](x)
|
||||
|
||||
x2 = self.fp8_linear.apply(
|
||||
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
|
||||
)
|
||||
x2 = self.fp8_linear_layers[0](y)
|
||||
# make sure resid is used for replacement to work
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
x3 = self.fp8_linear.apply(
|
||||
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||
)
|
||||
x3 = self.fp8_linear_layers[1](y2)
|
||||
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
x4 = self.fp8_linear.apply(
|
||||
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
|
||||
)
|
||||
x4 = self.fp8_linear_layers[2](y3)
|
||||
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
|
||||
def ops_in_model_before(self):
|
||||
if (
|
||||
self.use_aiter
|
||||
and self.group_shape.is_per_group()
|
||||
and current_platform.is_fp8_fnuz()
|
||||
):
|
||||
return [rocm_aiter_ops.get_group_quant_op()]
|
||||
if self.use_aiter and self.group_shape.is_per_group():
|
||||
return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
|
||||
if self.use_aiter and self.use_aiter_quant_op:
|
||||
return [rocm_aiter_ops.get_per_token_quant_op()]
|
||||
if self.use_aiter:
|
||||
return [QUANT_OPS[self.quant_key]]
|
||||
if self.enable_quant_fp8_custom_op:
|
||||
return [QUANT_OPS[self.quant_key]]
|
||||
return [torch.ops.aten.reciprocal]
|
||||
if self.group_shape.is_per_group():
|
||||
# Blockwise path
|
||||
if self.use_aiter_fusion and self.use_aiter_quant_op:
|
||||
return [rocm_aiter_ops.get_group_quant_op()]
|
||||
if self.use_aiter_fusion:
|
||||
return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
|
||||
else:
|
||||
if self.use_aiter_quant_op:
|
||||
return [rocm_aiter_ops.get_per_token_quant_op()]
|
||||
|
||||
# Common path
|
||||
return (
|
||||
[QUANT_OPS[self.activation_quant_key]]
|
||||
if self.enable_quant_fp8_custom_op
|
||||
else [torch.ops.aten.reciprocal]
|
||||
)
|
||||
|
||||
def ops_in_model_after(self):
|
||||
if self.use_aiter and self.group_shape.is_per_group():
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
AiterFusedAddRMSFp8GroupQuantPattern,
|
||||
AiterRMSFp8GroupQuantPattern,
|
||||
)
|
||||
if self.use_aiter_fusion:
|
||||
if self.group_shape.is_per_group():
|
||||
# Blockwise aiter fusion
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
AiterFusedAddRMSFp8GroupQuantPattern,
|
||||
AiterRMSFp8GroupQuantPattern,
|
||||
)
|
||||
|
||||
return [
|
||||
AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
|
||||
AiterRMSFp8GroupQuantPattern.FUSED_OP,
|
||||
]
|
||||
if self.use_aiter:
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
AiterFusedAddRMSNormDynamicQuantPattern,
|
||||
AiterRMSNormDynamicQuantPattern,
|
||||
)
|
||||
return [
|
||||
AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
|
||||
AiterRMSFp8GroupQuantPattern.FUSED_OP,
|
||||
]
|
||||
else:
|
||||
# Per-token aiter fusion
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
AiterFusedAddRMSNormDynamicQuantPattern,
|
||||
AiterRMSNormDynamicQuantPattern,
|
||||
)
|
||||
|
||||
return [
|
||||
AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
|
||||
AiterRMSNormDynamicQuantPattern.FUSED_OP,
|
||||
]
|
||||
return [
|
||||
AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
|
||||
AiterRMSNormDynamicQuantPattern.FUSED_OP,
|
||||
]
|
||||
|
||||
# Regular fusion
|
||||
return [
|
||||
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
|
||||
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
|
||||
FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)],
|
||||
FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)],
|
||||
]
|
||||
|
||||
def ops_in_model_before_partial(self):
|
||||
@ -206,14 +323,6 @@ class TestModel(torch.nn.Module):
|
||||
)
|
||||
|
||||
|
||||
GROUP_SHAPES = [
|
||||
GroupShape.PER_TOKEN,
|
||||
GroupShape.PER_TENSOR,
|
||||
GroupShape(1, 128),
|
||||
GroupShape(1, 64),
|
||||
]
|
||||
|
||||
|
||||
def _run_fusion_test(
|
||||
model,
|
||||
fusion_pass,
|
||||
@ -259,14 +368,9 @@ def _run_fusion_test(
|
||||
@pytest.mark.parametrize("hidden_size", [256])
|
||||
@pytest.mark.parametrize("num_tokens", [257])
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
|
||||
@pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS)
|
||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
|
||||
# cuda_force_torch used to test torch code path on platforms that
|
||||
# cutlass_fp8_supported() == True.
|
||||
@pytest.mark.parametrize(
|
||||
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
|
||||
)
|
||||
@ -275,11 +379,12 @@ def test_fusion_rmsnorm_quant(
|
||||
hidden_size,
|
||||
num_tokens,
|
||||
eps,
|
||||
group_shape,
|
||||
kernel_groupshape,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
cuda_force_torch,
|
||||
):
|
||||
force_kernel, group_shape = kernel_groupshape
|
||||
|
||||
if not enable_quant_fp8_custom_op and group_shape.is_per_group():
|
||||
pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")
|
||||
|
||||
@ -310,15 +415,16 @@ def test_fusion_rmsnorm_quant(
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(1)
|
||||
maybe_create_device_identity()
|
||||
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
|
||||
model = TestModel(
|
||||
hidden_size=hidden_size,
|
||||
eps=eps,
|
||||
force_kernel=force_kernel,
|
||||
group_shape=group_shape,
|
||||
use_aiter=False,
|
||||
cuda_force_torch=cuda_force_torch,
|
||||
use_aiter_fusion=False,
|
||||
use_aiter_quant=False,
|
||||
)
|
||||
|
||||
backend, _ = _run_fusion_test(
|
||||
@ -339,19 +445,12 @@ def test_fusion_rmsnorm_quant(
|
||||
assert n_add_nodes(backend.graph_post_pass) == 2
|
||||
|
||||
|
||||
GROUP_SHAPE_QUANT_OPS_MATCHS = [
|
||||
(GroupShape.PER_TOKEN, True),
|
||||
(GroupShape.PER_TOKEN, False),
|
||||
(GroupShape(1, 128), True),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("hidden_size", [256])
|
||||
@pytest.mark.parametrize("num_tokens", [257])
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize(
|
||||
"group_shape, use_aiter_quant_op", GROUP_SHAPE_QUANT_OPS_MATCHS
|
||||
"kernel_groupshape_quant", AITER_KERNEL_GROUPSHAPE_COMBINATIONS
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
(not current_platform.is_rocm() or not IS_AITER_FOUND),
|
||||
@ -362,10 +461,10 @@ def test_aiter_fusion_rmsnorm_quant(
|
||||
hidden_size: int,
|
||||
num_tokens: int,
|
||||
eps: float,
|
||||
group_shape: GroupShape,
|
||||
use_aiter_quant_op: bool,
|
||||
kernel_groupshape_quant: tuple,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
force_kernel, group_shape, use_aiter_quant_op = kernel_groupshape_quant
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
compilation_config=CompilationConfig(
|
||||
@ -379,20 +478,22 @@ def test_aiter_fusion_rmsnorm_quant(
|
||||
from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass
|
||||
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
rocm_aiter_ops.refresh_env_variables()
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(1)
|
||||
maybe_create_device_identity()
|
||||
|
||||
fusion_pass = RocmAiterRMSNormFusionPass(vllm_config)
|
||||
|
||||
model = TestModel(
|
||||
hidden_size=hidden_size,
|
||||
eps=eps,
|
||||
force_kernel=force_kernel,
|
||||
group_shape=group_shape,
|
||||
use_aiter=True,
|
||||
use_aiter_quant_op=use_aiter_quant_op,
|
||||
use_aiter_fusion=True, # Always use aiter fusion ops in aiter test
|
||||
use_aiter_quant=use_aiter_quant_op, # Toggle aiter quantization
|
||||
)
|
||||
|
||||
_run_fusion_test(
|
||||
|
||||
@ -34,11 +34,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from ..utils import TestFP8Layer
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
@ -171,11 +172,6 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.quant_key.scale.static,
|
||||
act_quant_group_shape=self.quant_key.scale.group_shape,
|
||||
)
|
||||
|
||||
hidden_size = self.num_qo_heads * self.head_size
|
||||
self.w = kwargs.get(
|
||||
"w",
|
||||
@ -187,16 +183,18 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
||||
"scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||
},
|
||||
)
|
||||
self.fp8_linear = TestFP8Layer(
|
||||
self.quant_key,
|
||||
self.quant_key,
|
||||
self.w["weight"],
|
||||
self.w["wscale"],
|
||||
self.w["scale"],
|
||||
)
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
"""Forward pass that creates the pattern to be fused."""
|
||||
attn_output = self.attn(q, k, v)
|
||||
return self.fp8_linear.apply(
|
||||
input=attn_output,
|
||||
weight=self.w["weight"],
|
||||
weight_scale=self.w["wscale"],
|
||||
input_scale=self.w["scale"],
|
||||
)
|
||||
return self.fp8_linear(attn_output)
|
||||
|
||||
|
||||
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
||||
|
||||
@ -31,13 +31,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import override_cutlass_fp8_supported
|
||||
from ..utils import TestFP8Layer, override_cutlass_fp8_supported
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@ -49,25 +45,30 @@ def is_nvfp4_supported():
|
||||
|
||||
|
||||
class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
self.weight_scale = torch.rand(1, dtype=torch.float32)
|
||||
self.input_scale = torch.rand(1, dtype=torch.float32)
|
||||
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
self.fp8_linear = TestFP8Layer(
|
||||
self.quant_key,
|
||||
self.quant_key,
|
||||
self.weight,
|
||||
self.weight_scale,
|
||||
self.input_scale,
|
||||
)
|
||||
|
||||
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled()
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
|
||||
x2 = self.fp8_linear(y)
|
||||
return x2
|
||||
|
||||
def ops_in_model_before(self):
|
||||
@ -198,7 +199,6 @@ def test_fusion_silu_and_mul_quant(
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
maybe_create_device_identity()
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size * 2)
|
||||
|
||||
|
||||
104
tests/utils.py
104
tests/utils.py
@ -42,6 +42,17 @@ from vllm.distributed import (
|
||||
)
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.entrypoints.cli.serve import ServeSubcommand
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
FP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
)
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
@ -1311,3 +1322,96 @@ def flat_product(*iterables: Iterable[Any]):
|
||||
for element in itertools.product(*iterables):
|
||||
normalized = (e if isinstance(e, tuple) else (e,) for e in element)
|
||||
yield tuple(itertools.chain(*normalized))
|
||||
|
||||
|
||||
class TestFP8Layer(torch.nn.Module):
|
||||
"""
|
||||
Test helper class for evaluating FP8 linear operations with quantization.
|
||||
|
||||
It supports configurable activation and weight quantization parameters,
|
||||
and provides a forward method that applies the FP8 linear transformation
|
||||
with optional bias.
|
||||
|
||||
Args:
|
||||
activation_quant_key (QuantKey): Key for activation quantization configuration.
|
||||
weight_quant_key (QuantKey): Key for weight quantization configuration.
|
||||
weight (torch.Tensor): Weight tensor for linear transformation.
|
||||
weight_scale (torch.Tensor): Per-tensor or per-group scale for weights.
|
||||
input_scale (torch.Tensor, optional): Scale tensor for input quantization.
|
||||
Defaults to None.
|
||||
out_dtype (torch.dtype, optional): Output tensor data type.
|
||||
Defaults to torch.get_default_dtype().
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation_quant_key: QuantKey,
|
||||
weight_quant_key: QuantKey,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor | None = None,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
force_kernel: FP8ScaledMMLinearKernel | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.weight_scale = weight_scale
|
||||
self.weight = weight
|
||||
self.input_scale = input_scale
|
||||
self.input_scale_ub = None
|
||||
out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype
|
||||
self.kernel = init_fp8_linear_kernel(
|
||||
activation_quant_key=activation_quant_key,
|
||||
weight_quant_key=weight_quant_key,
|
||||
out_dtype=out_dtype,
|
||||
force_kernel=force_kernel,
|
||||
)
|
||||
|
||||
def is_quant_fp8_enabled(self) -> bool:
|
||||
return self.kernel.quant_fp8.enabled()
|
||||
|
||||
def forward(
|
||||
self, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(self, y, bias)
|
||||
|
||||
|
||||
class TestBlockFP8Layer:
|
||||
"""
|
||||
Test wrapper for W8A8BlockFp8LinearOp to match TestFP8Layer interface.
|
||||
|
||||
This is a workaround until W8A8BlockFp8LinearOp implements
|
||||
ScaledMMLinearKernel (i.e., a kernel abstraction for blockwise quantization).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group_shape: GroupShape,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor | None = None,
|
||||
cutlass_block_fp8_supported: bool = False,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
):
|
||||
self.linear_op = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
|
||||
act_quant_group_shape=group_shape,
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=use_aiter_and_is_supported,
|
||||
)
|
||||
self.weight = weight
|
||||
self.weight_scale = weight_scale
|
||||
self.input_scale = input_scale
|
||||
|
||||
def __call__(
|
||||
self, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return self.linear_op.apply(
|
||||
input=y,
|
||||
weight=self.weight,
|
||||
weight_scale=self.weight_scale,
|
||||
input_scale=self.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def is_quant_fp8_enabled(self) -> bool:
|
||||
return self.linear_op.input_quant_op.enabled()
|
||||
|
||||
@ -8,9 +8,13 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrate
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
create_fp8_input_scale,
|
||||
@ -22,11 +26,14 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
process_fp8_weight_tensor_strategy,
|
||||
validate_fp8_block_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kFp8StaticTokenSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
cutlass_block_fp8_supported,
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
@ -42,6 +49,18 @@ strategy_to_parameter_type = {
|
||||
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
|
||||
}
|
||||
|
||||
STATIC_QUANT = True
|
||||
DYNAMIC_QUANT = False
|
||||
activation_quant_key_mapping = {
|
||||
STATIC_QUANT: kFp8StaticTensorSym,
|
||||
DYNAMIC_QUANT: kFp8DynamicTokenSym,
|
||||
}
|
||||
weight_quant_key_mapping = {
|
||||
QuantizationStrategy.CHANNEL: kFp8StaticTokenSym,
|
||||
QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
|
||||
}
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
|
||||
@ -49,22 +68,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
self.strategy = weight_quant.strategy
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
|
||||
self.weight_block_size = self.weight_quant.block_structure
|
||||
if self.weight_block_size is not None:
|
||||
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
|
||||
else:
|
||||
self.act_q_group_shape = (
|
||||
GroupShape.PER_TENSOR
|
||||
if is_static_input_scheme
|
||||
else GroupShape.PER_TOKEN
|
||||
)
|
||||
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
|
||||
if self.weight_block_size is not None:
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
|
||||
assert not self.is_static_input_scheme
|
||||
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
@ -72,9 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
else:
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.is_static_input_scheme,
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
activation_quant_key = activation_quant_key_mapping[is_static_input_scheme]
|
||||
weight_quant_key = weight_quant_key_mapping[self.strategy]
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=activation_quant_key,
|
||||
weight_quant_key=weight_quant_key,
|
||||
out_dtype=self.out_dtype,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -93,8 +107,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
maybe_create_device_identity()
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.weight_block_size = None
|
||||
@ -134,6 +146,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
layer.input_scale_ub = None
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
|
||||
@ -190,11 +204,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return self.fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
@ -11,8 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
ScaledMMLinearLayerConfig,
|
||||
choose_scaled_mm_linear_kernel,
|
||||
init_int8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
@ -25,8 +24,6 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(
|
||||
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
|
||||
):
|
||||
@ -50,18 +47,13 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
):
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
|
||||
self.kernel = init_int8_linear_kernel(
|
||||
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
|
||||
is_static_input_scheme=self.is_static_input_scheme,
|
||||
input_symmetric=self.input_symmetric,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
@ -90,12 +82,12 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
input_zero_point = None
|
||||
input_scale = None
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||
)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
if not self.input_symmetric:
|
||||
# Note: compressed-tensors stores the zp using the same dtype
|
||||
# as the weights
|
||||
@ -103,16 +95,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
input_zero_point = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
||||
)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
|
||||
self.kernel = kernel_type(
|
||||
c=scaled_mm_linear_kernel_config,
|
||||
w_q_param_name="weight",
|
||||
w_s_param_name="weight_scale",
|
||||
i_s_param_name="input_scale",
|
||||
i_zp_param_name="input_zero_point",
|
||||
azp_adj_param_name="azp_adj",
|
||||
)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
if not hasattr(layer, "azp_adj"):
|
||||
layer.register_parameter("azp_adj", None)
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
|
||||
@ -18,17 +18,19 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
prepare_fp8_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
is_layer_skipped,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTokenSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
maybe_create_device_identity,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
@ -91,10 +93,13 @@ class FBGEMMFp8Config(QuantizationConfig):
|
||||
class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: FBGEMMFp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN
|
||||
)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8StaticTokenSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@ -106,7 +111,6 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
maybe_create_device_identity()
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
del input_size, output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
@ -184,12 +188,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return self.fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=None,
|
||||
input_scale_ub=layer.input_scale_ub,
|
||||
bias=bias,
|
||||
)
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
@ -45,6 +45,9 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
@ -78,13 +81,14 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
is_layer_skipped,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
all_close_1d,
|
||||
cutlass_block_fp8_supported,
|
||||
cutlass_fp8_supported,
|
||||
maybe_create_device_identity,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
per_tensor_dequantize,
|
||||
)
|
||||
@ -431,8 +435,13 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
# Use per-token quantization for better perf if dynamic and cutlass
|
||||
if not self.act_q_static and cutlass_fp8_supported():
|
||||
self.act_q_group_shape = GroupShape.PER_TOKEN
|
||||
self.activation_quant_key = kFp8DynamicTokenSym
|
||||
elif self.act_q_static:
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR
|
||||
self.activation_quant_key = kFp8StaticTensorSym
|
||||
else:
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR
|
||||
self.activation_quant_key = kFp8DynamicTensorSym
|
||||
|
||||
if self.block_quant:
|
||||
assert not self.act_q_static
|
||||
@ -444,9 +453,11 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
else:
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.act_q_static,
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=kFp8StaticTensorSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
@ -459,8 +470,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
maybe_create_device_identity()
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
@ -525,6 +534,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
weight_loader=patched_weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
layer.input_scale_ub = None
|
||||
|
||||
# If checkpoint is serialized fp8, load them.
|
||||
# Otherwise, wait until process_weights_after_loading.
|
||||
@ -699,14 +709,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return self.fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
@ -2,48 +2,73 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScaledMMLinearLayerConfig:
|
||||
is_channelwise: bool
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
# TODO: Chnage to QuantKey like FP8ScaledMMLinearLayerConfig
|
||||
is_static_input_scheme: bool
|
||||
is_channelwise: bool
|
||||
input_symmetric: bool
|
||||
|
||||
|
||||
class ScaledMMLinearKernel(ABC):
|
||||
@dataclass
|
||||
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
weight_quant_key: QuantKey
|
||||
activation_quant_key: QuantKey
|
||||
out_dtype: torch.dtype | None
|
||||
|
||||
|
||||
_FP8ParamsT = tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
torch.Tensor | None, # input_scale_ub,
|
||||
]
|
||||
_Int8ParamsT = tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
torch.Tensor | None, # input_zp
|
||||
torch.Tensor | None, # azp_adj
|
||||
]
|
||||
|
||||
_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT)
|
||||
_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig)
|
||||
|
||||
|
||||
class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
def is_platform_supported(cls) -> tuple[bool, str | None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
def can_implement(cls, c: _ConfigT) -> tuple[bool, str | None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c: ScaledMMLinearLayerConfig,
|
||||
w_q_param_name: str,
|
||||
w_s_param_name: str,
|
||||
i_s_param_name: str,
|
||||
i_zp_param_name: str,
|
||||
azp_adj_param_name: str,
|
||||
) -> None:
|
||||
def __init__(self, c: _ConfigT, layer_param_names: Sequence[str]) -> None:
|
||||
assert self.can_implement(c)
|
||||
assert self.is_supported()
|
||||
assert self.is_platform_supported()
|
||||
self.config = c
|
||||
self.w_q_name = w_q_param_name
|
||||
self.w_s_name = w_s_param_name
|
||||
self.i_s_name = i_s_param_name
|
||||
self.i_zp_name = i_zp_param_name
|
||||
self.azp_adj_name = azp_adj_param_name
|
||||
self.layer_param_names = layer_param_names
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
@ -58,19 +83,109 @@ class ScaledMMLinearKernel(ABC):
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_weight_params(
|
||||
self, layer: torch.nn.Module
|
||||
) -> tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
torch.Tensor | None, # input_zp
|
||||
torch.Tensor | None, # azp_adj
|
||||
]:
|
||||
return (
|
||||
getattr(layer, self.w_q_name),
|
||||
getattr(layer, self.w_s_name),
|
||||
getattr(layer, self.i_s_name),
|
||||
getattr(layer, self.i_zp_name),
|
||||
getattr(layer, self.azp_adj_name),
|
||||
# return a covariant type in the subclass
|
||||
@abstractmethod
|
||||
def _get_layer_params(self, layer) -> _ParamsT:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FP8ScaledMMLinearKernel(
|
||||
ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, _FP8ParamsT], ABC
|
||||
):
|
||||
def __init__(
|
||||
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
|
||||
) -> None:
|
||||
act_scale_descriptor = c.activation_quant_key.scale
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=act_scale_descriptor.static,
|
||||
group_shape=act_scale_descriptor.group_shape,
|
||||
num_token_padding=self.get_ouput_padding(),
|
||||
)
|
||||
self.fp8_dtype = current_platform.fp8_dtype()
|
||||
super().__init__(c, layer_param_names)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def _get_layer_params(self, layer) -> _FP8ParamsT:
|
||||
w, w_s, x_s, x_s_ub = self.layer_param_names
|
||||
return (
|
||||
getattr(layer, w),
|
||||
getattr(layer, w_s),
|
||||
getattr(layer, x_s),
|
||||
getattr(layer, x_s_ub),
|
||||
)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
scaled_mm_func = self.get_scaled_mm_func()
|
||||
quant_fp8 = self.quant_fp8
|
||||
fp8_dtype = self.fp8_dtype
|
||||
maybe_out_dtype = self.config.out_dtype
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_s computed from x.
|
||||
# If static, layer.input_scale is scalar and x_s is input_scale.
|
||||
# View input as 2D matrix for fp8 methods
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
output_shape = [*x.shape[:-1], w.shape[1]]
|
||||
out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
|
||||
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
x_2d_q = x_2d
|
||||
if x.dtype != fp8_dtype:
|
||||
x_2d_q, x_s = quant_fp8(
|
||||
x_2d,
|
||||
x_s,
|
||||
x_s_ub,
|
||||
)
|
||||
return scaled_mm_func(
|
||||
A=x_2d_q,
|
||||
B=w,
|
||||
out_dtype=out_dtype,
|
||||
As=x_s,
|
||||
Bs=w_s,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Int8ScaledMMLinearKernel(
|
||||
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC
|
||||
):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
def _get_layer_params(self, layer) -> _Int8ParamsT:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
|
||||
return (
|
||||
getattr(layer, w_q),
|
||||
getattr(layer, w_s),
|
||||
getattr(layer, i_s),
|
||||
getattr(layer, i_zp),
|
||||
getattr(layer, azp_adj),
|
||||
)
|
||||
|
||||
@ -2,7 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
|
||||
AiterScaledMMLinearKernel,
|
||||
)
|
||||
@ -10,9 +14,25 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
|
||||
CPUScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
|
||||
CutlassFP8ScaledMMLinearKernel,
|
||||
CutlassScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
|
||||
FlashInferScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
|
||||
ChannelWiseTorchScaledMMLinearKernel,
|
||||
PerTensorTorchScaledMMLinearKernel,
|
||||
RowWiseTorchScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
|
||||
ROCmScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig,
|
||||
)
|
||||
@ -22,60 +42,206 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
|
||||
XLAScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
from vllm.platforms import PlatformEnum, current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
|
||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
|
||||
}
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CUDA: [
|
||||
FlashInferScaledMMLinearKernel,
|
||||
CutlassFP8ScaledMMLinearKernel,
|
||||
PerTensorTorchScaledMMLinearKernel,
|
||||
ChannelWiseTorchScaledMMLinearKernel,
|
||||
],
|
||||
PlatformEnum.ROCM: [
|
||||
ROCmScaledMMLinearKernel,
|
||||
PerTensorTorchScaledMMLinearKernel,
|
||||
RowWiseTorchScaledMMLinearKernel,
|
||||
ChannelWiseTorchScaledMMLinearKernel,
|
||||
],
|
||||
PlatformEnum.CPU: [
|
||||
PerTensorTorchScaledMMLinearKernel,
|
||||
ChannelWiseTorchScaledMMLinearKernel,
|
||||
],
|
||||
}
|
||||
|
||||
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
|
||||
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)
|
||||
|
||||
|
||||
def is_supported_and_can_implement_kernel(
|
||||
kernel: type[_KernelT], config: _KernelConfigT, compute_capability: int | None
|
||||
) -> tuple[bool, str]:
|
||||
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
|
||||
return False, f" {kernel.__name__} disabled by environment variable"
|
||||
|
||||
platform_supported, requires_platform = kernel.is_platform_supported()
|
||||
if not platform_supported:
|
||||
return (
|
||||
False,
|
||||
f"{kernel.__name__} is not supported as it requires {requires_platform}.",
|
||||
)
|
||||
|
||||
if compute_capability is None:
|
||||
_cc = current_platform.get_device_capability()
|
||||
if _cc is not None:
|
||||
compute_capability = _cc[0] * 10 + _cc[1]
|
||||
|
||||
# If the current platform uses compute_capability,
|
||||
# make sure the kernel supports the compute cability.
|
||||
if compute_capability is not None:
|
||||
kernel_min_capability = kernel.get_min_capability()
|
||||
if (
|
||||
kernel_min_capability is not None
|
||||
and kernel_min_capability > compute_capability
|
||||
):
|
||||
return (
|
||||
False,
|
||||
f"{kernel.__name__} requires capability "
|
||||
f"{kernel_min_capability}, current compute capability "
|
||||
f"is {compute_capability}",
|
||||
)
|
||||
can_implement, failure_reason = kernel.can_implement(config)
|
||||
if not can_implement:
|
||||
return (
|
||||
False,
|
||||
f" {kernel.__name__} cannot be implement because: {failure_reason}",
|
||||
)
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def choose_scaled_mm_linear_kernel(
|
||||
config: ScaledMMLinearLayerConfig, compute_capability: int | None = None
|
||||
) -> type[ScaledMMLinearKernel]:
|
||||
config: _KernelConfigT,
|
||||
possible_kernels: dict[PlatformEnum, list[type[_KernelT]]],
|
||||
compute_capability: int | None = None,
|
||||
force_kernel: type[_KernelT] | None = None,
|
||||
) -> type[_KernelT]:
|
||||
"""
|
||||
Choose an ScaledMMLinearKernel that can implement the given config for the
|
||||
Choose a _KernelT that can implement the given config for the
|
||||
given compute capability. Attempts to choose the best kernel in terms of
|
||||
performance.
|
||||
|
||||
Args:
|
||||
config (ScaledMMLinearLayerConfig): Description of the linear layer
|
||||
config (_KernelConfigT): Description of the linear layer
|
||||
to be implemented.
|
||||
possible_kernels (dict[PlatformEnum, list[_KernelT]]): A
|
||||
dictionary of platforms and their list list of possible kernels.
|
||||
compute_capability (Optional[int], optional): The compute capability of
|
||||
the target device, if None uses `current_platform` to get the
|
||||
compute capability. Defaults to None.
|
||||
force_kernel (Optional[type[_KernelT]]): An Optional forced kernel to override
|
||||
the possible_kernels if it can be implemented. If None, it will only try the
|
||||
possible kernels.
|
||||
|
||||
Raises:
|
||||
ValueError: If no kernel can implement the given config.
|
||||
|
||||
Returns:
|
||||
type[ScaledMMLinearKernel]: Chosen kernel.
|
||||
_KernelT: Chosen kernel.
|
||||
"""
|
||||
|
||||
failure_reasons = []
|
||||
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
|
||||
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
|
||||
failure_reasons.append(f"{kernel.__name__}: disabled by env var")
|
||||
continue
|
||||
failure_reason_list = []
|
||||
|
||||
# If the current platform uses compute_capability,
|
||||
# make sure the kernel supports the compute capability.
|
||||
is_supported, reason = kernel.is_supported(compute_capability)
|
||||
if not is_supported:
|
||||
failure_reasons.append(f"{kernel.__name__}: {reason}")
|
||||
continue
|
||||
if force_kernel is not None:
|
||||
can_implement, failure_reason = is_supported_and_can_implement_kernel(
|
||||
force_kernel, config, compute_capability
|
||||
)
|
||||
if can_implement:
|
||||
return force_kernel
|
||||
|
||||
can_implement, reason = kernel.can_implement(config)
|
||||
if not can_implement:
|
||||
failure_reasons.append(f"{kernel.__name__}: {reason}")
|
||||
continue
|
||||
logger.info_once(
|
||||
"Tried to force %s, but the kernel couldn't be implemented",
|
||||
force_kernel.__name__,
|
||||
scope="global",
|
||||
)
|
||||
|
||||
return kernel
|
||||
for kernel in possible_kernels[current_platform._enum]:
|
||||
is_supported_and_can_implement, failure_reason = (
|
||||
is_supported_and_can_implement_kernel(kernel, config, compute_capability)
|
||||
)
|
||||
if is_supported_and_can_implement:
|
||||
return kernel
|
||||
failure_reason_list.append(failure_reason)
|
||||
|
||||
raise ValueError(
|
||||
"Failed to find a kernel that can implement the "
|
||||
"ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reasons)
|
||||
"ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reason_list)
|
||||
)
|
||||
|
||||
|
||||
def init_fp8_linear_kernel(
|
||||
activation_quant_key: QuantKey,
|
||||
weight_quant_key: QuantKey,
|
||||
out_dtype: torch.dtype,
|
||||
force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
|
||||
module_name: str | None = None,
|
||||
) -> FP8ScaledMMLinearKernel:
|
||||
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
||||
weight_quant_key=weight_quant_key,
|
||||
activation_quant_key=activation_quant_key,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, force_kernel=force_kernel
|
||||
)
|
||||
|
||||
if module_name:
|
||||
logger.info_once(
|
||||
"Selected %s for %s",
|
||||
kernel_type.__name__,
|
||||
module_name,
|
||||
scope="global",
|
||||
)
|
||||
|
||||
return kernel_type(
|
||||
scaled_mm_linear_kernel_config,
|
||||
layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"],
|
||||
)
|
||||
|
||||
|
||||
def init_int8_linear_kernel(
|
||||
is_channelwise: bool,
|
||||
is_static_input_scheme: bool,
|
||||
input_symmetric: bool,
|
||||
module_name: str,
|
||||
) -> Int8ScaledMMLinearKernel:
|
||||
config = Int8ScaledMMLinearLayerConfig(
|
||||
is_channelwise=is_channelwise,
|
||||
is_static_input_scheme=is_static_input_scheme,
|
||||
input_symmetric=input_symmetric,
|
||||
)
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
config,
|
||||
_POSSIBLE_INT8_KERNELS,
|
||||
)
|
||||
|
||||
logger.info_once(
|
||||
"Selected %s for %s",
|
||||
kernel_type.__class__.__name__,
|
||||
module_name,
|
||||
scope="global",
|
||||
)
|
||||
|
||||
return kernel_type(
|
||||
config,
|
||||
layer_param_names=[
|
||||
"weight",
|
||||
"weight_scale",
|
||||
"input_scale",
|
||||
"input_zero_point",
|
||||
"azp_adj",
|
||||
],
|
||||
)
|
||||
|
||||
@ -9,27 +9,22 @@ from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .cutlass import CutlassScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_rocm():
|
||||
return (
|
||||
False,
|
||||
"AiterScaledMMLinearKernel requires `aiter` which is not "
|
||||
+ "currently supported on non-ROCm platform.",
|
||||
)
|
||||
if compute_capability is None:
|
||||
_cc = current_platform.get_device_capability()
|
||||
if _cc is not None:
|
||||
compute_capability = _cc.major * 10 + _cc.minor
|
||||
if compute_capability is not None and compute_capability < 90:
|
||||
return False, f"requires capability 90, got {compute_capability}"
|
||||
def get_min_capability(cls) -> int:
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def is_platform_supported(cls) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_rocm():
|
||||
return False, "ROCm"
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
try:
|
||||
import aiter # noqa: F401 # deliberately attempt to import aiter
|
||||
except Exception:
|
||||
@ -48,10 +43,6 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
+ "`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not c.input_symmetric:
|
||||
return (
|
||||
False,
|
||||
@ -59,9 +50,6 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
)
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -78,7 +66,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
|
||||
ATIER block scaled GEMM and mix-precision GEMM.
|
||||
"""
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
|
||||
@ -14,24 +14,34 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
||||
from .ScaledMMLinearKernel import (
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
def get_min_capability(cls) -> int:
|
||||
# current_platform.get_device_capability() returns None
|
||||
# so the check will be ignored
|
||||
return -1
|
||||
|
||||
@classmethod
|
||||
def is_platform_supported(
|
||||
cls,
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cpu():
|
||||
return False, "Requires CPU."
|
||||
return False, "CPU"
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
w_q_name, _, _, _, _ = self.layer_param_names
|
||||
weight = getattr(layer, w_q_name)
|
||||
dtype = weight.dtype
|
||||
N, K = weight.size()
|
||||
if (
|
||||
@ -49,10 +59,11 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# Transpose to [K, N] for convenience
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
|
||||
weight = getattr(layer, w_q_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_q_name,
|
||||
w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||
)
|
||||
|
||||
@ -61,28 +72,27 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_s_name,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, self.i_s_name)
|
||||
input_scale = getattr(layer, i_s_name)
|
||||
|
||||
if self.config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.i_s_name,
|
||||
i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||
)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
else:
|
||||
input_zero_point = getattr(layer, self.i_zp_name)
|
||||
input_zero_point = getattr(layer, i_zp_name)
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
@ -92,20 +102,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||
replace_parameter(
|
||||
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||
)
|
||||
|
||||
azp = (
|
||||
(int8_traits.min - range_min / scale).round().to(dtype=torch.int32)
|
||||
)
|
||||
replace_parameter(
|
||||
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||
)
|
||||
|
||||
else:
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
|
||||
# Different from cutlass, oneDNN kernels only need the AZP adjustment
|
||||
# term for dynamic quantization. And s_b should be folded into the
|
||||
# term. Such as:
|
||||
@ -113,38 +119,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
|
||||
# s_a * GEMM_output - s_a * zp_a * adj + bias
|
||||
if not (self.config.input_symmetric and self.config.is_static_input_scheme):
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
weight = getattr(layer, w_q_name)
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
|
||||
azp_adj = azp_adj * weight_scale.squeeze()
|
||||
setattr(
|
||||
layer,
|
||||
self.azp_adj_name,
|
||||
azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||
)
|
||||
else:
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
weight = getattr(layer, w_q_name)
|
||||
self.dnnl_handler = ops.create_onednn_scaled_mm(
|
||||
weight,
|
||||
getattr(layer, self.w_s_name),
|
||||
getattr(layer, w_s_name),
|
||||
torch.get_default_dtype(),
|
||||
getattr(layer, self.i_s_name) is None,
|
||||
getattr(layer, i_s_name) is None,
|
||||
not self.config.input_symmetric,
|
||||
32,
|
||||
)
|
||||
# weight is prepacked and maintained by the dnnl_handler,
|
||||
# release the original weight
|
||||
setattr(layer, self.w_q_name, None)
|
||||
setattr(layer, w_q_name, None)
|
||||
del weight
|
||||
|
||||
def process_weights_for_sgl(self, layer: torch.nn.Module) -> None:
|
||||
w_q_name, w_s_name, _, _, _ = self.layer_param_names
|
||||
# WEIGHT
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
weight = getattr(layer, w_q_name)
|
||||
packed_weight = torch.ops._C.convert_weight_packed(weight)
|
||||
replace_parameter(
|
||||
layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
|
||||
layer, w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
|
||||
)
|
||||
|
||||
if layer.bias is not None:
|
||||
@ -156,19 +161,15 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
# WEIGHT SCALE
|
||||
# CPU SGL kernels only support per-channel.
|
||||
# For per-tensor quant, convert to the per-channel case.
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
if not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_s_name,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -187,7 +188,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
@ -209,7 +210,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, _, _, _ = self._get_weight_params(layer)
|
||||
w_q, w_s, _, _, _ = self._get_layer_params(layer)
|
||||
return torch.ops._C.int8_scaled_mm_with_quant(
|
||||
x,
|
||||
w_q,
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
@ -11,35 +13,51 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
def cutlass_w8a8_scaled_mm_fp8(
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
# Fused GEMM_DQ
|
||||
output = ops.cutlass_scaled_mm(
|
||||
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||
)
|
||||
return output.view(*output_shape)
|
||||
|
||||
|
||||
class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
def is_platform_supported(cls) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "Requires CUDA."
|
||||
if compute_capability is None:
|
||||
_cc = current_platform.get_device_capability()
|
||||
if _cc is not None:
|
||||
compute_capability = _cc.major * 10 + _cc.minor
|
||||
if compute_capability is not None and compute_capability < 75:
|
||||
return False, f"requires capability 75, got {compute_capability}"
|
||||
return False, "CUDA"
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
|
||||
config = self.config
|
||||
# WEIGHT
|
||||
# Cutlass kernels need transposed weight.
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
weight = getattr(layer, w_q_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_q_name,
|
||||
w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||
)
|
||||
|
||||
@ -48,28 +66,28 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
if is_fused_module and not config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_s_name,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, self.i_s_name)
|
||||
if config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, i_s_name)
|
||||
|
||||
if self.config.input_symmetric:
|
||||
if config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.i_s_name,
|
||||
i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||
)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
setattr(layer, i_zp_name, None)
|
||||
else:
|
||||
input_zero_point = getattr(layer, self.i_zp_name)
|
||||
input_zero_point = getattr(layer, i_zp_name)
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
@ -79,38 +97,32 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||
replace_parameter(
|
||||
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||
)
|
||||
|
||||
# AZP loaded as int8 but used as int32
|
||||
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
||||
replace_parameter(
|
||||
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||
)
|
||||
|
||||
else:
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
|
||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||
# It does not depend on scales or azp, so it is the same for
|
||||
# static and dynamic quantization.
|
||||
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||
if not self.config.input_symmetric:
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
if not config.input_symmetric:
|
||||
weight = getattr(layer, w_q_name)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
if self.config.is_static_input_scheme:
|
||||
if config.is_static_input_scheme:
|
||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
||||
# in the per-tensor case
|
||||
azp_adj = getattr(layer, self.i_zp_name) * azp_adj
|
||||
azp_adj = getattr(layer, i_zp_name) * azp_adj
|
||||
setattr(
|
||||
layer,
|
||||
self.azp_adj_name,
|
||||
azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||
)
|
||||
else:
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
@ -118,7 +130,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
@ -145,3 +157,21 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
return ops.cutlass_scaled_mm(
|
||||
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
|
||||
)
|
||||
|
||||
|
||||
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_platform_supported(cls) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "CUDA"
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||
return cutlass_w8a8_scaled_mm_fp8
|
||||
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
return None
|
||||
|
||||
@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_w8a8_scaled_mm(
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
return flashinfer_scaled_fp8_mm(
|
||||
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||
)
|
||||
|
||||
|
||||
class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_platform_supported(cls) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "CUDA"
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if not has_flashinfer():
|
||||
return (
|
||||
False,
|
||||
"FlashInferScaledMMLinearKernel requires "
|
||||
+ "FlashInfer to be installed.",
|
||||
)
|
||||
if not has_flashinfer():
|
||||
return (
|
||||
False,
|
||||
"FlashInferScaledMMLinearKernel requires "
|
||||
+ "FlashInfer to be installed.",
|
||||
)
|
||||
|
||||
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
||||
return (
|
||||
False,
|
||||
"FlashInferScaledMMLinearKernel requires "
|
||||
+ "per tensor activation and weight scales.",
|
||||
)
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 100
|
||||
|
||||
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||
return flashinfer_w8a8_scaled_mm
|
||||
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
return None
|
||||
@ -0,0 +1,238 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.config import CompilationMode, get_current_vllm_config
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
def torch_per_tensor_w8a8_scaled_mm(
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
output = torch._scaled_mm(
|
||||
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||
)
|
||||
# A fix for discrepancy in scaled_mm which returns tuple
|
||||
# for torch < 2.5 and a single value in torch >= 2.5
|
||||
if type(output) is tuple and len(output) == 2:
|
||||
output = output[0]
|
||||
|
||||
return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
def torch_row_wise_w8a8_scaled_mm(
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
# Note:
|
||||
# For now it has only been validated on ROCm platform.
|
||||
# fp8 rowwise scaling in torch._scaled_mm is introduced in
|
||||
# https://github.com/pytorch/pytorch/pull/144432 using
|
||||
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
|
||||
#
|
||||
# For CUDA platform please validate if the torch._scaled_mm supports
|
||||
# rowwise scaled GEMM before using it
|
||||
|
||||
# Fused GEMM_DQ Rowwise GEMM
|
||||
output = torch._scaled_mm(
|
||||
A,
|
||||
B,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=As,
|
||||
scale_b=Bs.t(),
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
output = torch.narrow(output, 0, 0, output_shape[0])
|
||||
output = output.view(*output_shape)
|
||||
return output
|
||||
|
||||
|
||||
def torch_channelwise_w8a8_scaled_mm(
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
# Use unfused DQ due to limitations with scaled_mm
|
||||
|
||||
# Symmetric quantized GEMM by definition computes the following:
|
||||
# C = (s_x * X) (s_w * W) + bias
|
||||
# This is equivalent to dequantizing the weights and activations
|
||||
# before applying a GEMM.
|
||||
#
|
||||
# In order to compute quantized operands, a quantized kernel
|
||||
# will rewrite the above like so:
|
||||
# C = s_w * s_x * (X * W) + bias
|
||||
#
|
||||
# For the scaled_mm fallback case, we break this down, since it
|
||||
# does not support s_w being a vector.
|
||||
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
# from pytorch 2.5. Allocating a dummy tensor to pass as scales
|
||||
dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device)
|
||||
|
||||
# GEMM
|
||||
# This computes C = (X * W).
|
||||
# Output in fp32 to allow subsequent ops to happen in-place
|
||||
output = torch._scaled_mm(
|
||||
A,
|
||||
B,
|
||||
scale_a=dummy_tensor,
|
||||
scale_b=dummy_tensor,
|
||||
out_dtype=torch.float32,
|
||||
)
|
||||
# A fix for discrepancy in scaled_mm which returns tuple
|
||||
# for torch < 2.5 and a single value in torch >= 2.5
|
||||
if type(output) is tuple and len(output) == 2:
|
||||
output = output[0]
|
||||
# Unpad (undo num_token_padding)
|
||||
output = torch.narrow(output, 0, 0, output_shape[0])
|
||||
x_scale = torch.narrow(As, 0, 0, output_shape[0])
|
||||
|
||||
# DQ
|
||||
# C = sw * sx * (X * W) + bias
|
||||
output = output * x_scale * Bs.t()
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(out_dtype).view(*output_shape)
|
||||
|
||||
|
||||
class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
"""
|
||||
Base class for FP8 linear kernels using Torch.
|
||||
Each subclass represents a kernel variant for
|
||||
specific device capabilities and torch versions,
|
||||
so we split them up and implement
|
||||
get_min_capability() separately for each.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_platform_supported(
|
||||
cls,
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda_alike():
|
||||
return False, "ROCm or CUDA"
|
||||
return True, None
|
||||
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
# Note: we pad the input because torch._scaled_mm is more performant
|
||||
# for matrices with batch dimension > 16.
|
||||
# This could change in the future.
|
||||
# We also don't pad when using torch.compile,
|
||||
# as it breaks with dynamic shapes.
|
||||
vllm_config = get_current_vllm_config().compilation_config
|
||||
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
|
||||
output_padding = 17 if pad_output else None
|
||||
return output_padding
|
||||
|
||||
|
||||
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
||||
return (
|
||||
False,
|
||||
"PerTensorTorchScaledMMLinearKernel requires "
|
||||
+ "per tensor activation and weight scales.",
|
||||
)
|
||||
return True, None
|
||||
|
||||
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||
return torch_per_tensor_w8a8_scaled_mm
|
||||
|
||||
|
||||
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 94
|
||||
|
||||
@classmethod
|
||||
def is_platform_supported(cls) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_rocm():
|
||||
return False, "ROCm"
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if c.out_dtype == torch.float16:
|
||||
# hipblaslt rowwise _scaled_mm only supports BFloat16
|
||||
return (
|
||||
False,
|
||||
"RowWiseTorchScaledMMLinearKernel only supports BFloat16.",
|
||||
)
|
||||
|
||||
if per_tensor_activation_scales or per_tensor_weight_scales:
|
||||
return (
|
||||
False,
|
||||
"RowWiseTorchScaledMMLinearKernel cannot be used with "
|
||||
+ "per tensor activation and weight scales.",
|
||||
)
|
||||
|
||||
if not version.parse(torch.__version__) >= version.parse("2.7"):
|
||||
return (
|
||||
False,
|
||||
"RowWiseTorchScaledMMLinearKernel requires " + "pytorch version >=2.7.",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||
return torch_row_wise_w8a8_scaled_mm
|
||||
|
||||
|
||||
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if per_tensor_activation_scales and per_tensor_weight_scales:
|
||||
return (
|
||||
False,
|
||||
"ChannelWiseTorchScaledMMLinearKernel cannot be used with "
|
||||
+ "per tensor activation and weight scales.",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||
return torch_channelwise_w8a8_scaled_mm
|
||||
@ -0,0 +1,129 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
A.shape[0] == 1
|
||||
and B.shape[1] % 16 == 0
|
||||
and ((bias is None) or (bias.dtype == out_dtype))
|
||||
):
|
||||
output = ops.wvSplitKQ(
|
||||
B.t(),
|
||||
A,
|
||||
out_dtype,
|
||||
As,
|
||||
Bs,
|
||||
get_cu_count(),
|
||||
bias,
|
||||
)
|
||||
# Fallback
|
||||
else:
|
||||
output = torch._scaled_mm(
|
||||
A,
|
||||
B,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=As,
|
||||
scale_b=Bs,
|
||||
bias=bias,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def rocm_per_tensor_float_w8a8_scaled_mm_fake(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return A.new_empty((*A.shape[:-1], B.shape[1]), dtype=out_dtype)
|
||||
|
||||
|
||||
def rocm_per_tensor_float_w8a8_scaled_mm(
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
output_shape: list[int],
|
||||
) -> torch.Tensor:
|
||||
output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||
A, B, out_dtype, As, Bs, bias
|
||||
)
|
||||
return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
|
||||
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
|
||||
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
|
||||
)
|
||||
|
||||
|
||||
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_platform_supported(cls) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_rocm():
|
||||
return False, "ROCm"
|
||||
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
|
||||
if not on_mi3xx():
|
||||
return False, "ROCm MI3xx"
|
||||
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if not envs.VLLM_ROCM_USE_SKINNY_GEMM:
|
||||
return (
|
||||
False,
|
||||
"VLLM_ROCM_USE_SKINNY_GEMM must be enabled "
|
||||
+ "to use ROCmScaledMMLinearKernel.",
|
||||
)
|
||||
|
||||
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
||||
return (
|
||||
False,
|
||||
"ROCmScaledMMLinearKernel requires "
|
||||
+ "per tensor activation and weight scales.",
|
||||
)
|
||||
return True, None
|
||||
|
||||
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||
return rocm_per_tensor_float_w8a8_scaled_mm
|
||||
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
return None
|
||||
@ -11,46 +11,49 @@ from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
||||
from .cutlass import CutlassScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import (
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
def is_platform_supported(cls) -> tuple[bool, str | None]:
|
||||
if current_platform.is_cuda_alike():
|
||||
return True, None
|
||||
return False, "Requires ROCm or CUDA."
|
||||
return False, "ROCm or CUDA"
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not c.input_symmetric:
|
||||
return False, "Only symmetric input is supported."
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
w_q, _, i_s, _, _ = self._get_layer_params(layer)
|
||||
w_q_name, _, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
|
||||
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||
w_q_name,
|
||||
torch.nn.Parameter(w_q.t().data, requires_grad=False),
|
||||
)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, self.i_s_name)
|
||||
assert i_s is not None
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||
i_s_name,
|
||||
torch.nn.Parameter(i_s.max(), requires_grad=False),
|
||||
)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
setattr(layer, i_zp_name, None)
|
||||
else:
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
setattr(layer, i_s_name, None)
|
||||
setattr(layer, i_zp_name, None)
|
||||
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
setattr(layer, azp_adj_name, None)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
@ -58,7 +61,7 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
w_q, w_s, i_s, i_zp, _ = self._get_layer_params(layer)
|
||||
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(
|
||||
x.contiguous(), i_s, i_zp, symmetric=True
|
||||
|
||||
@ -12,23 +12,21 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
||||
from .ScaledMMLinearKernel import (
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
def is_platform_supported(cls) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_tpu():
|
||||
return False, "Requires TPU."
|
||||
return False, "TPU"
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_tpu():
|
||||
return False, "ScaledMMXLA requires running on TPU."
|
||||
|
||||
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if c.is_static_input_scheme:
|
||||
return False, "ScaledMMXLA requires dynamic activation scales."
|
||||
|
||||
@ -43,9 +41,10 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# [out, in] (different than cutlass_scaled_mm)
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
|
||||
weight = getattr(layer, w_q_name)
|
||||
replace_parameter(
|
||||
layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
|
||||
layer, w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
|
||||
)
|
||||
|
||||
# WEIGHT SCALE
|
||||
@ -53,7 +52,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
|
||||
@ -61,14 +60,14 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
weight_scale = weight_scale.squeeze(-1)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_s_name,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
# Only support symmetric dynamic activation quantization.
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
setattr(layer, i_s_name, None)
|
||||
setattr(layer, i_zp_name, None)
|
||||
setattr(layer, azp_adj_name, None)
|
||||
|
||||
# Filter warning for cond usage in apply_weights. It is okay
|
||||
# to specialize the graph since bias is not dynamic.
|
||||
@ -89,7 +88,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, _, _, _ = self._get_weight_params(layer)
|
||||
w_q, w_s, _, _, _ = self._get_layer_params(layer)
|
||||
|
||||
# Required to register custom ops.
|
||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||
|
||||
@ -34,6 +34,9 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
|
||||
@ -71,10 +74,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
cutlass_fp4_supported,
|
||||
is_layer_skipped,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kFp8StaticTokenSym,
|
||||
swizzle_blockscale,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
cutlass_block_fp8_supported,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
@ -431,8 +436,11 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8StaticTensorSym,
|
||||
weight_quant_key=kFp8StaticTensorSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
@ -500,13 +508,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
|
||||
@ -520,8 +522,11 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8StaticTokenSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
@ -578,13 +583,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=None,
|
||||
bias=bias,
|
||||
)
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
|
||||
|
||||
@ -17,11 +17,13 @@ from vllm.model_executor.layers.quantization.fp8 import (
|
||||
Fp8KVCacheMethod,
|
||||
Fp8LinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
is_layer_skipped,
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped,
|
||||
kFp8DynamicTokenSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
@ -97,9 +99,11 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
|
||||
)
|
||||
super().__init__(quant_config=quant_config)
|
||||
# Force weight quantization
|
||||
self.quant_config.is_checkpoint_fp8_serialized = False
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8DynamicTokenSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
@ -126,11 +130,4 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=None,
|
||||
input_scale_ub=None,
|
||||
bias=bias,
|
||||
)
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
@ -7,10 +7,18 @@ from typing import Any, cast
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kFp8StaticTokenSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
@ -23,6 +31,8 @@ from vllm.platforms import current_platform
|
||||
|
||||
__all__ = ["QuarkW8A8Fp8"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class QuarkW8A8Fp8(QuarkScheme):
|
||||
def __init__(
|
||||
@ -35,15 +45,16 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic"))
|
||||
self.input_qscheme = cast(str, input_config.get("qscheme"))
|
||||
|
||||
per_token = (
|
||||
per_token_activation = (
|
||||
not self.is_static_input_scheme and self.input_qscheme == "per_channel"
|
||||
)
|
||||
self.act_quant_group_shape = (
|
||||
GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
|
||||
per_token_weight = self.weight_qscheme == "per_channel"
|
||||
|
||||
self.activation_quant_key = (
|
||||
kFp8DynamicTokenSym if per_token_activation else kFp8StaticTensorSym
|
||||
)
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.is_static_input_scheme,
|
||||
act_quant_group_shape=self.act_quant_group_shape,
|
||||
self.weight_quant_key = (
|
||||
kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym
|
||||
)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
@ -94,7 +105,7 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
layer.input_scale = Parameter(input_scale, requires_grad=False)
|
||||
else:
|
||||
weight_scale = layer.weight_scale.data
|
||||
if self.act_quant_group_shape == GroupShape.PER_TOKEN:
|
||||
if self.activation_quant_key.scale.group_shape == GroupShape.PER_TOKEN:
|
||||
weight_scale = weight_scale.view(-1, 1)
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
@ -163,17 +174,19 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
input_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
layer.input_scale_ub = None
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
@ -7,8 +7,7 @@ import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
ScaledMMLinearLayerConfig,
|
||||
choose_scaled_mm_linear_kernel,
|
||||
init_int8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.parameter import (
|
||||
@ -22,8 +21,6 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class QuarkW8A8Int8(QuarkScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
qscheme: str,
|
||||
@ -50,18 +47,13 @@ class QuarkW8A8Int8(QuarkScheme):
|
||||
):
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
|
||||
self.kernel = init_int8_linear_kernel(
|
||||
is_channelwise=(self.qscheme == "per_channel"),
|
||||
is_static_input_scheme=(self.is_static_input_scheme is True),
|
||||
input_symmetric=(self.input_symmetric is True),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
@ -102,25 +94,21 @@ class QuarkW8A8Int8(QuarkScheme):
|
||||
layer.register_parameter("weight_zero_point", weight_zero_point)
|
||||
|
||||
# INPUT SCALE
|
||||
input_zero_point = None
|
||||
input_scale = None
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||
)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
input_zero_point = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
||||
)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
|
||||
self.kernel = kernel_type(
|
||||
c=scaled_mm_linear_kernel_config,
|
||||
w_q_param_name="weight",
|
||||
w_s_param_name="weight_scale",
|
||||
i_s_param_name="input_scale",
|
||||
i_zp_param_name="input_zero_point",
|
||||
azp_adj_param_name="azp_adj",
|
||||
)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
if not hasattr(layer, "azp_adj"):
|
||||
layer.register_parameter("azp_adj", None)
|
||||
|
||||
# Checkpoints are serialized in quark format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
|
||||
@ -109,6 +109,9 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True)
|
||||
kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR)
|
||||
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)
|
||||
|
||||
kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN)
|
||||
kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True)
|
||||
|
||||
kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
|
||||
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
|
||||
|
||||
|
||||
@ -1,34 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.config import CompilationMode, get_current_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||
TORCH_DEVICE_IDENTITY = None
|
||||
|
||||
# The condition to determine if it is on a platform that supports
|
||||
# torch._scaled_mm rowwise feature.
|
||||
# The condition is determined once as the operations
|
||||
# are time-consuming.
|
||||
USE_ROWWISE_TORCH_SCALED_MM = (
|
||||
current_platform.is_rocm()
|
||||
and version.parse(torch.__version__) >= version.parse("2.7")
|
||||
and current_platform.has_device_capability(94)
|
||||
)
|
||||
|
||||
|
||||
def sparse_cutlass_supported() -> bool:
|
||||
@ -140,361 +117,6 @@ def requantize_with_max_scale(
|
||||
return max_w_scale, weight
|
||||
|
||||
|
||||
def maybe_create_device_identity():
|
||||
# Allocate dummy ones tensor for torch._scaled_mm
|
||||
global TORCH_DEVICE_IDENTITY
|
||||
if TORCH_DEVICE_IDENTITY is None:
|
||||
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
||||
|
||||
|
||||
def cutlass_w8a8_scaled_mm(
|
||||
*,
|
||||
qinput: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
output_shape: list,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# Fused GEMM_DQ
|
||||
output = ops.cutlass_scaled_mm(
|
||||
qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias
|
||||
)
|
||||
return output.view(*output_shape)
|
||||
|
||||
|
||||
def flashinfer_w8a8_scaled_mm(
|
||||
*,
|
||||
qinput: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
output_shape: list,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return flashinfer_scaled_fp8_mm(
|
||||
qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias
|
||||
)
|
||||
|
||||
|
||||
def rocm_per_tensor_w8a8_scaled_mm_impl(
|
||||
qinput: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
|
||||
if (
|
||||
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
||||
and on_mi3xx()
|
||||
and qinput.shape[0] == 1
|
||||
and qinput.shape[1] % 16 == 0
|
||||
and ((bias is None) or (bias.dtype == out_dtype))
|
||||
):
|
||||
output = ops.wvSplitKQ(
|
||||
weight.t(),
|
||||
qinput,
|
||||
out_dtype,
|
||||
scale_a,
|
||||
scale_b,
|
||||
get_cu_count(),
|
||||
bias,
|
||||
)
|
||||
else:
|
||||
output = torch._scaled_mm(
|
||||
qinput,
|
||||
weight,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=bias,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def rocm_per_tensor_w8a8_scaled_mm_fake(
|
||||
qinput: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), dtype=out_dtype)
|
||||
|
||||
|
||||
def rocm_per_tensor_w8a8_scaled_mm(
|
||||
*,
|
||||
qinput: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
|
||||
qinput, weight, out_dtype, scale_a, scale_b, bias
|
||||
)
|
||||
return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
|
||||
op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
|
||||
fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
|
||||
)
|
||||
|
||||
|
||||
def torch_per_tensor_w8a8_scaled_mm(
|
||||
*,
|
||||
qinput: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
output = torch._scaled_mm(
|
||||
qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias
|
||||
)
|
||||
# A fix for discrepancy in scaled_mm which returns tuple
|
||||
# for torch < 2.5 and a single value in torch >= 2.5
|
||||
if type(output) is tuple and len(output) == 2:
|
||||
output = output[0]
|
||||
|
||||
return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
def torch_per_token_w8a8_scaled_mm(
|
||||
*,
|
||||
qinput: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
output_shape: list,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
|
||||
# when using it.
|
||||
# For now it has only been validated on ROCm platform.
|
||||
# fp8 rowwise scaling in torch._scaled_mm is introduced in
|
||||
# https://github.com/pytorch/pytorch/pull/144432 using
|
||||
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
|
||||
#
|
||||
# For CUDA platform please validate if the torch._scaled_mm supports
|
||||
# rowwise scaled GEMM before using it
|
||||
|
||||
# Fused GEMM_DQ Rowwise GEMM
|
||||
output = torch._scaled_mm(
|
||||
qinput,
|
||||
weight,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b.t(),
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
output = torch.narrow(output, 0, 0, qinput.shape[0])
|
||||
output = output.view(*output_shape)
|
||||
return output
|
||||
|
||||
|
||||
def torch_channelwise_w8a8_scaled_mm(
    *,
    qinput: torch.Tensor,
    weight: torch.Tensor,
    out_dtype: torch.dtype,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
    **kwargs,
) -> torch.Tensor:
    """Per-channel/per-token W8A8 GEMM with an unfused dequant step.

    Symmetric quantized GEMM by definition computes:
        C = (s_x * X) (s_w * W) + bias
    which is equivalent to dequantizing the weights and activations
    before applying a GEMM. A quantized kernel normally rewrites this as:
        C = s_w * s_x * (X * W) + bias

    torch._scaled_mm in this fallback does not support a vector s_w, so
    the computation is broken into a raw GEMM followed by an explicit
    dequantization.
    """
    # GEMM: computes C = (X * W) with identity scales.
    # Output in fp32 so the subsequent dequant ops can happen in-place.
    gemm_out = torch._scaled_mm(
        qinput,
        weight,
        scale_a=TORCH_DEVICE_IDENTITY,
        scale_b=TORCH_DEVICE_IDENTITY,
        out_dtype=torch.float32,
    )
    # torch < 2.5 returns a 2-tuple from _scaled_mm while torch >= 2.5
    # returns the output tensor directly; normalize to a single tensor.
    if type(gemm_out) is tuple and len(gemm_out) == 2:
        gemm_out = gemm_out[0]

    # Unpad (undo num_token_padding) on both the output and the
    # activation scales.
    num_tokens = qinput.shape[0]
    gemm_out = torch.narrow(gemm_out, 0, 0, num_tokens)
    x_scale = torch.narrow(scale_a, 0, 0, num_tokens)

    # DQ: C = s_w * s_x * (X * W) + bias
    dequant = gemm_out * x_scale * scale_b.t()
    if bias is not None:
        dequant = dequant + bias
    return dequant.to(out_dtype).view(*output_shape)
|
||||
|
||||
|
||||
def dispatch_w8a8_scaled_mm(
    preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool
) -> Callable[..., torch.Tensor]:
    """Pick the w8a8 scaled-mm implementation for the given backend and
    quantization granularity (per-tensor vs. per-channel/per-token).
    """
    # Fully per-tensor: every backend has a dedicated implementation,
    # with torch._scaled_mm as the fallback.
    if per_tensor_weights and per_tensor_activations:
        per_tensor_impls = {
            "rocm": rocm_per_tensor_w8a8_scaled_mm,
            "flashinfer": flashinfer_w8a8_scaled_mm,
            "cutlass": cutlass_w8a8_scaled_mm,
        }
        return per_tensor_impls.get(
            preferred_backend, torch_per_tensor_w8a8_scaled_mm
        )

    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
    if preferred_backend in ("cutlass", "flashinfer"):
        return cutlass_w8a8_scaled_mm

    # torch._scaled_mm can handle per-channel (weights) per-token (inputs)
    # only where rowwise scaling is available.
    if (
        not per_tensor_weights
        and not per_tensor_activations
        and USE_ROWWISE_TORCH_SCALED_MM
    ):
        return torch_per_token_w8a8_scaled_mm

    # Normally, torch.scaled_mm supports per tensor weights + activations
    # only, so fall back to the naive unfused-dequant path for per-channel
    # or per-token scales.
    return torch_channelwise_w8a8_scaled_mm
|
||||
|
||||
|
||||
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
|
||||
# https://github.com/vllm-project/vllm/issues/14397
|
||||
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
class Fp8LinearOp:
    """
    This class executes a FP8 linear layer using cutlass if supported and
    torch.scaled_mm otherwise.
    It needs to be a class instead of a method so that config can be read
    in the __init__ method, as reading config is not allowed inside forward.
    """

    def __init__(
        self,
        act_quant_static: bool,
        act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
        pad_output: bool | None = None,
    ):
        """Select the preferred backend and build the activation quantizer.

        Args:
            act_quant_static: Whether activation quantization uses a static
                (precomputed) scale rather than a dynamic one.
            act_quant_group_shape: Group shape for activation quantization
                (per-tensor by default).
            pad_output: Force output padding on/off; if None, it is derived
                from the compilation config and the chosen backend.
        """
        # Backend priority: ROCm > flashinfer (SM100+ with flashinfer
        # installed) > cutlass (CUDA with cutlass fp8 support) > torch.
        if current_platform.is_rocm():
            self.preferred_backend = "rocm"
        elif current_platform.is_cuda() and cutlass_fp8_supported():
            if has_flashinfer() and current_platform.has_device_capability(100):
                self.preferred_backend = "flashinfer"
            else:
                self.preferred_backend = "cutlass"
        else:
            self.preferred_backend = "torch"

        # Note: we pad the input because torch._scaled_mm is more performant
        # for matrices with batch dimension > 16.
        # This could change in the future.
        # We also don't pad when using torch.compile,
        # as it breaks with dynamic shapes.
        if pad_output is None:
            config = get_current_vllm_config().compilation_config
            pad_output = (
                config.mode < CompilationMode.VLLM_COMPILE
                and self.preferred_backend == "torch"
            )

        # 17 = smallest batch dimension > 16 (see padding note above).
        self.output_padding = 17 if pad_output else None
        self.act_quant_static = act_quant_static
        self.act_quant_group_shape = act_quant_group_shape
        self.quant_fp8 = QuantFP8(
            static=act_quant_static,
            group_shape=act_quant_group_shape,
            num_token_padding=self.output_padding,
        )

    def apply(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype | None = None,
        input_scale: torch.Tensor | None = None,
        input_scale_ub: torch.Tensor | None = None,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer: quantize the activation if needed,
        then dispatch to the appropriate w8a8 scaled-mm implementation.

        Args:
            input: Activation tensor; quantized here unless already fp8.
            weight: FP8 weight matrix.
            weight_scale: Weight dequant scale (scalar for per-tensor,
                vector for per-channel).
            out_dtype: Output dtype; defaults to ``input.dtype``.
            input_scale: Static activation scale, or None for dynamic.
            input_scale_ub: Optional upper bound for dynamic scaling.
            bias: Optional bias added by the GEMM.

        Returns:
            The linear output reshaped to ``[*input.shape[:-1], N]``.
        """
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        # If dynamic, layer.input_scale is None and x_scale computed from x.
        # If static, layer.input_scale is scalar and x_scale is input_scale.

        # View input as 2D matrix for fp8 methods
        input_2d = input.view(-1, input.shape[-1])
        output_shape = [*input.shape[:-1], weight.shape[1]]

        if out_dtype is None:
            out_dtype = input.dtype

        # If input not quantized, quantize it here; otherwise pass it
        # through with the caller-provided scale.
        # TODO(luka) remove this path if not used anymore
        if input.dtype != current_platform.fp8_dtype():
            qinput, x_scale = self.quant_fp8(
                input_2d,
                input_scale,
                input_scale_ub,
            )
        else:
            qinput, x_scale = input_2d, input_scale

        # Must have dim() conditions:
        # In the per-token quant scenario, when the number of tokens is 1,
        # the scale will only have 1 element. Without checking dim(),
        # we cannot distinguish between per-tensor and per-token quant.
        # Example:
        # When the number of tokens is 1, the per-token scale is [[1]],
        # whereas a per-tensor scale is [1] or ().
        per_tensor_weights = weight_scale.numel() == 1
        per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2

        # TODO(luka) do this dispatch during init (after ScaledMM refactor)
        w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
            self.preferred_backend, per_tensor_weights, per_tensor_activations
        )

        return w8a8_scaled_mm_func(
            qinput=qinput,
            weight=weight,
            out_dtype=out_dtype,
            scale_a=x_scale,
            scale_b=weight_scale,
            bias=bias,
            output_shape=output_shape,
        )
|
||||
|
||||
|
||||
def normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user