first try

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-10-28 16:26:51 +00:00
parent a806c14cc7
commit 974e6820ce
8 changed files with 924 additions and 125 deletions

View File

@ -10,6 +10,14 @@ from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
_POSSIBLE_FP8_KERNELS,
choose_scaled_mm_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
check_aiter_fp8_linear_support, check_aiter_fp8_linear_support,
@ -24,7 +32,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_block_fp8_supported, cutlass_block_fp8_supported,
maybe_create_device_identity, maybe_create_device_identity,
) )
@ -72,9 +79,32 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
use_aiter_and_is_supported=self.use_aiter_and_is_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported,
) )
else: else:
self.fp8_linear = Fp8LinearOp( param_name_list = ["weight", "weight_scale", "input_scale"]
act_quant_static=self.is_static_input_scheme, layer_mapping_function = lambda layer: (
act_quant_group_shape=self.act_q_group_shape, tuple(getattr(layer, param_name) for param_name in param_name_list),
param_name_list,
)
# TODO: clean up
if self.strategy == QuantizationStrategy.TENSOR:
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
elif self.strategy == QuantizationStrategy.CHANNEL:
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
is_static_input_scheme=self.is_static_input_scheme,
input_symmetric=True,
weight_quant_strategy=weight_quant_strategy,
activation_group_shape=self.act_q_group_shape,
out_dtype=self.out_dtype,
)
kernel = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config,
_POSSIBLE_FP8_KERNELS,
)
self.fp8_linear = kernel(
scaled_mm_linear_kernel_config, layer_mapping_function
) )
@classmethod @classmethod
@ -190,11 +220,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
bias=bias, bias=bias,
) )
return self.fp8_linear.apply( return self.fp8_linear.apply_weights(layer, x, bias)
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
)

View File

@ -2,16 +2,36 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
import torch import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
class ScaledMMLinearQuantStrategy(Enum):
    """Granularity of the weight quantization scales handed to a scaled-mm kernel.

    Members are plain string-valued tags. The ``is_per_token`` and
    ``is_per_group`` helpers that used to live here read ``self.row`` and
    ``self.col``, attributes an enum member does not have, so every call
    raised ``AttributeError``; they are dropped rather than given guessed
    semantics.
    """

    TENSOR = "tensor"
    CHANNEL = "channel"
    BLOCK = "block"
@dataclass @dataclass
class ScaledMMLinearLayerConfig: class ScaledMMLinearLayerConfig:
# TODO: remove is channelwise
is_channelwise: bool is_channelwise: bool
is_static_input_scheme: bool is_static_input_scheme: bool
input_symmetric: bool input_symmetric: bool
out_dtype: torch.dtype | None
weight_quant_strategy: ScaledMMLinearQuantStrategy
activation_group_shape: GroupShape | None = GroupShape.PER_TENSOR
class ScaledMMLinearKernel(ABC): class ScaledMMLinearKernel(ABC):
@ -26,21 +46,11 @@ class ScaledMMLinearKernel(ABC):
raise NotImplementedError raise NotImplementedError
def __init__( def __init__(
self, self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
c: ScaledMMLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
i_s_param_name: str,
i_zp_param_name: str,
azp_adj_param_name: str,
) -> None: ) -> None:
assert self.can_implement(c) assert self.can_implement(c)
self.config = c self.config = c
self.w_q_name = w_q_param_name self.layer_mapping_function = layer_mapping_function
self.w_s_name = w_s_param_name
self.i_s_name = i_s_param_name
self.i_zp_name = i_zp_param_name
self.azp_adj_name = azp_adj_param_name
@abstractmethod @abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@ -55,19 +65,19 @@ class ScaledMMLinearKernel(ABC):
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def _get_weight_params( # def _get_weight_params(
self, layer: torch.nn.Module # self, layer: torch.nn.Module
) -> tuple[ # ) -> tuple[
torch.Tensor, # weight # torch.Tensor, # weight
torch.Tensor, # weight_scale # torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale, # torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_zp # torch.Tensor | None, # input_zp
torch.Tensor | None, # azp_adj # torch.Tensor | None, # azp_adj
]: # ]:
return ( # return (
getattr(layer, self.w_q_name), # getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name), # getattr(layer, self.w_s_name),
getattr(layer, self.i_s_name), # getattr(layer, self.i_s_name),
getattr(layer, self.i_zp_name), # getattr(layer, self.i_zp_name),
getattr(layer, self.azp_adj_name), # getattr(layer, self.azp_adj_name),
) # )

View File

@ -12,10 +12,18 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassScaledMMLinearKernel, CutlassScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
ROCmScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearKernel, ScaledMMLinearKernel,
ScaledMMLinearLayerConfig, ScaledMMLinearLayerConfig,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
ChannelWiseTorchScaledMMLinearKernel,
PerTensorTorchScaledMMLinearKernel,
RowWiseTorchScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel, TritonScaledMMLinearKernel,
) )
@ -25,16 +33,28 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
from vllm.platforms import PlatformEnum, current_platform from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available) # in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { _POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CPU: [CPUScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.TPU: [XLAScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel],
} }
# Candidate FP8 scaled-mm kernels per platform, listed in the order
# choose_scaled_mm_linear_kernel iterates them.
# NOTE(review): the CUDA entry is the int8 Cutlass kernel. Its apply_weights
# unpacks five layer parameters and its process_weights_after_loading forwards
# the mapped names to a helper that needs five, while the FP8 scheme's mapping
# function supplies only three ("weight", "weight_scale", "input_scale").
# Presumably CutlassFP8ScaledMMLinearKernel was intended here - confirm.
# Platforms without an entry (CPU, TPU) will raise KeyError on lookup.
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
    PlatformEnum.ROCM: [
        ROCmScaledMMLinearKernel,
        PerTensorTorchScaledMMLinearKernel,
        RowWiseTorchScaledMMLinearKernel,
        ChannelWiseTorchScaledMMLinearKernel,
    ],
}
def choose_scaled_mm_linear_kernel( def choose_scaled_mm_linear_kernel(
config: ScaledMMLinearLayerConfig, compute_capability: int | None = None config: ScaledMMLinearLayerConfig,
possible_kernels: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]],
compute_capability: int | None = None,
) -> type[ScaledMMLinearKernel]: ) -> type[ScaledMMLinearKernel]:
""" """
Choose an ScaledMMLinearKernel that can implement the given config for the Choose an ScaledMMLinearKernel that can implement the given config for the
@ -61,7 +81,7 @@ def choose_scaled_mm_linear_kernel(
compute_capability = _cc[0] * 10 + _cc[1] compute_capability = _cc[0] * 10 + _cc[1]
failure_reasons = [] failure_reasons = []
for kernel in _POSSIBLE_KERNELS[current_platform._enum]: for kernel in possible_kernels[current_platform._enum]:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
failure_reasons.append( failure_reasons.append(
f" {kernel.__name__} disabled by environment variable" f" {kernel.__name__} disabled by environment variable"

View File

@ -9,8 +9,8 @@ from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from .cutlass import CutlassScaledMMLinearKernel from .cutlass import process_weights_after_loading
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
def rocm_aiter_gemm_w8a8_impl( def rocm_aiter_gemm_w8a8_impl(
@ -52,7 +52,7 @@ if current_platform.is_rocm():
) )
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 90 return 90
@ -92,7 +92,9 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
return True, None return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer) _, param_names = self.layer_mapping_function(layer)
process_weights_after_loading(self.config, layer, *param_names)
def apply_weights( def apply_weights(
self, self,
@ -110,7 +112,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
ATIER block scaled GEMM and mix-precision GEMM. ATIER block scaled GEMM and mix-precision GEMM.
""" """
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) (w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer)
# ops.scaled_int8_quant supports both dynamic and static quant: # ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x. # * dynamic, i_s is None and x_s computed from x.

View File

@ -2,10 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, convert_to_channelwise,
) )
@ -14,6 +18,111 @@ from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
def cutlass_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
) -> torch.Tensor:
    """Run ``ops.cutlass_scaled_mm`` and reshape the result.

    Args:
        A: Quantized activation matrix.
        B: Quantized weight matrix.
        out_dtype: dtype requested for the GEMM output.
        As: Scale passed as ``scale_a`` (activation scale).
        Bs: Scale passed as ``scale_b`` (weight scale).
        bias: Optional bias forwarded to the kernel; callers pass ``None``
            when the layer has no bias.
        output_shape: Shape the kernel output is viewed as.

    Returns:
        The GEMM output viewed as ``output_shape``.
    """
    # Fused GEMM_DQ
    output = ops.cutlass_scaled_mm(
        A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
    )
    return output.view(*output_shape)
def process_weights_after_loading(
    config: ScaledMMLinearLayerConfig,
    layer: torch.nn.Module,
    w_q_name: str,
    w_s_name: str,
    i_s_name: str,
    i_zp_name: str,
    azp_adj_name: str,
) -> None:
    """Rewrite a layer's quantization parameters in place for the Cutlass kernels.

    The layer is mutated through ``replace_parameter`` / ``setattr``; nothing
    is returned. Steps run in a fixed order because the azp adjustment is
    computed from the weight *after* it has been transposed.

    Args:
        config: Kernel config; ``is_channelwise``, ``is_static_input_scheme``
            and ``input_symmetric`` select the branches below.
        layer: Module holding the parameters; ``layer.logical_widths`` is
            read to detect fused modules.
        w_q_name: Attribute name of the quantized weight.
        w_s_name: Attribute name of the weight scale.
        i_s_name: Attribute name of the input scale.
        i_zp_name: Attribute name of the input zero point.
        azp_adj_name: Attribute name under which the azp adjustment is stored.

    All five names are required positionally, so a caller that splats a
    shorter name list into this function fails with ``TypeError``.
    """
    # WEIGHT
    # Cutlass kernels need transposed weight.
    weight = getattr(layer, w_q_name)
    replace_parameter(
        layer,
        w_q_name,
        torch.nn.Parameter(weight.t().data, requires_grad=False),
    )

    # WEIGHT SCALE
    # Cutlass kernels support only per-tensor and per-channel.
    # If we have a fused module (QKV, MLP) with per tensor scales (thus N
    # scales being passed to the kernel), convert to the per-channel case.
    is_fused_module = len(layer.logical_widths) > 1
    weight_scale = getattr(layer, w_s_name)
    if is_fused_module and not config.is_channelwise:
        weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
    replace_parameter(
        layer,
        w_s_name,
        torch.nn.Parameter(weight_scale.data, requires_grad=False),
    )

    # INPUT SCALE
    if config.is_static_input_scheme:
        input_scale = getattr(layer, i_s_name)

        if config.input_symmetric:
            # Collapse the loaded scales to their maximum; no zero point is
            # needed in the symmetric case.
            replace_parameter(
                layer,
                i_s_name,
                torch.nn.Parameter(input_scale.max(), requires_grad=False),
            )
            setattr(layer, i_zp_name, None)
        else:
            input_zero_point = getattr(layer, i_zp_name)

            # reconstruct the ranges
            int8_traits = torch.iinfo(torch.int8)
            azps = input_zero_point.to(dtype=torch.int32)
            range_max = (input_scale * (int8_traits.max - azps)).max()
            range_min = (input_scale * (int8_traits.min - azps)).min()

            # One scale / zero point covering the widest reconstructed range.
            scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
            replace_parameter(
                layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
            )

            # AZP loaded as int8 but used as int32
            azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
            replace_parameter(
                layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
            )
    else:
        # Dynamic activation quantization: no stored scale or zero point.
        setattr(layer, i_s_name, None)
        setattr(layer, i_zp_name, None)

    # azp_adj is the AZP adjustment term, used to account for weights.
    # It does not depend on scales or azp, so it is the same for
    # static and dynamic quantization.
    # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
    # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
    if not config.input_symmetric:
        # Re-read the weight: this is the transposed parameter stored above.
        weight = getattr(layer, w_q_name)
        azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
        if config.is_static_input_scheme:
            # cutlass_w8a8 requires azp to be folded into azp_adj
            # in the per-tensor case
            azp_adj = getattr(layer, i_zp_name) * azp_adj
        setattr(
            layer,
            azp_adj_name,
            torch.nn.Parameter(azp_adj, requires_grad=False),
        )
    else:
        setattr(layer, azp_adj_name, None)
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
@ -27,83 +136,9 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
return True, None return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT _, param_names = self.layer_mapping_function(layer)
# Cutlass kernels need transposed weight.
weight = getattr(layer, self.w_q_name)
replace_parameter(
layer,
self.w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
)
# WEIGHT SCALE process_weights_after_loading(self.config, layer, *param_names)
# Cutlass kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE
if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name)
if self.config.input_symmetric:
replace_parameter(
layer,
self.i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
setattr(layer, self.i_zp_name, None)
else:
input_zero_point = getattr(layer, self.i_zp_name)
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = input_zero_point.to(dtype=torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter(
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
)
# AZP loaded as int8 but used as int32
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
replace_parameter(
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
else:
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
if not self.config.input_symmetric:
weight = getattr(layer, self.w_q_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
if self.config.is_static_input_scheme:
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj = getattr(layer, self.i_zp_name) * azp_adj
setattr(
layer,
self.azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False),
)
else:
setattr(layer, self.azp_adj_name, None)
def apply_weights( def apply_weights(
self, self,
@ -111,7 +146,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) (w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer)
# ops.scaled_int8_quant supports both dynamic and static quant: # ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x. # * dynamic, i_s is None and x_s computed from x.
@ -138,3 +173,70 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
return ops.cutlass_scaled_mm( return ops.cutlass_scaled_mm(
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
) )
class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
    """FP8 W8A8 linear kernel that dispatches to ``ops.cutlass_scaled_mm``."""

    def __init__(
        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
    ) -> None:
        """Build the activation quantizer, then defer to the base initializer.

        Args:
            c: Layer config; the activation quantizer follows its
                ``is_static_input_scheme`` and ``activation_group_shape``.
            layer_mapping_function: Callable returning
                ``((weight, weight_scale, input_scale), names)`` for a layer.
        """
        # Quantize activations at the granularity the layer was configured
        # with; ``can_implement`` does not restrict this kernel to per-tensor
        # activations, so hard-coding PER_TENSOR would silently coarsen them.
        act_group_shape = c.activation_group_shape
        if act_group_shape is None:
            act_group_shape = GroupShape.PER_TENSOR
        self.quant_fp8 = QuantFP8(
            static=c.is_static_input_scheme,
            group_shape=act_group_shape,
            num_token_padding=None,
        )
        super().__init__(c, layer_mapping_function)

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Return ``(True, None)`` on CUDA, else ``(False, reason)``."""
        if not current_platform.is_cuda():
            return (
                False,
                "CutlassFP8ScaledMMLinearKernel is supported "
                + "on CUDA platforms Only.",
            )

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """No-op: this kernel leaves the loaded parameters untouched."""
        pass

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` if needed and run the fused scaled GEMM.

        Args:
            layer: Module whose weight / scales the mapping function reads.
            x: Activations; any leading dims are flattened for the GEMM and
                restored on the result.
            bias: Optional bias forwarded to the kernel.

        Returns:
            Tensor shaped ``[*x.shape[:-1], w.shape[1]]``.
        """
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        #   If dynamic, layer.input_scale is None and x_scale computed from x.
        #   If static, layer.input_scale is scalar and x_scale is input_scale.
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)

        # View input as 2D matrix for fp8 methods
        x_2d = x.view(-1, x.shape[-1])
        # Shape the result after the caller's tensor, not the flattened view,
        # so inputs with more than two dims get their leading dims back.
        output_shape = [*x.shape[:-1], w.shape[1]]

        out_dtype = self.config.out_dtype
        out_dtype = x.dtype if out_dtype is None else out_dtype

        # If input not quantized
        # TODO(luka) remove this path if not used anymore
        x_2d_q = x_2d
        if x.dtype != current_platform.fp8_dtype():
            x_2d_q, x_s = self.quant_fp8(
                x_2d,
                x_s,
            )

        return cutlass_w8a8_scaled_mm(
            A=x_2d_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
            output_shape=output_shape,
        )

View File

@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from .ScaledMMLinearKernel import (
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
def flashinfer_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list | None = None,
) -> torch.Tensor:
    """Run ``flashinfer_scaled_fp8_mm`` and optionally reshape the result.

    Args:
        A: Quantized activation matrix.
        B: Quantized weight matrix.
        out_dtype: dtype requested for the GEMM output.
        As: Scale passed as ``scale_a``.
        Bs: Scale passed as ``scale_b``.
        bias: Optional bias forwarded to the kernel.
        output_shape: When given, the result is viewed as this shape. It is
            accepted because the kernel class in this module passes it, as
            the sibling scaled-mm wrappers require; without the parameter
            that call raised ``TypeError``.

    Returns:
        The GEMM output, viewed as ``output_shape`` when one is supplied.
    """
    output = flashinfer_scaled_fp8_mm(
        A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
    )
    if output_shape is None:
        return output
    return output.view(*output_shape)
class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
    """FP8 W8A8 linear kernel that dispatches to FlashInfer's scaled FP8 GEMM."""

    def __init__(
        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
    ) -> None:
        """Build a per-tensor activation quantizer, then defer to the base class.

        Per-tensor is fixed here because ``can_implement`` only accepts
        per-tensor activation scales.
        """
        self.quant_fp8 = QuantFP8(
            static=c.is_static_input_scheme,
            group_shape=GroupShape.PER_TENSOR,
            num_token_padding=None,
        )
        super().__init__(c, layer_mapping_function)

    @classmethod
    def get_min_capability(cls) -> int:
        return 100

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Return ``(True, None)`` if usable, else ``(False, reason)``.

        Requires a CUDA platform, an installed FlashInfer, and per-tensor
        scales for both activations and weights.
        """
        if not current_platform.is_cuda():
            return (
                False,
                "FlashInferScaledMMLinearKernel is supported "
                + "on CUDA platforms Only.",
            )

        if not has_flashinfer():
            return (
                False,
                "FlashInferScaledMMLinearKernel requires "
                + "FlashInfer to be installed.",
            )

        # The config allows a missing group shape; report it as unsupported
        # instead of failing on ``None.is_per_tensor()``.
        act_group_shape = c.activation_group_shape
        per_tensor_activation_scales = (
            act_group_shape is not None and act_group_shape.is_per_tensor()
        )
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )
        if not (per_tensor_activation_scales and per_tensor_weight_scales):
            return (
                False,
                "FlashInferScaledMMLinearKernel requires "
                + "per tensor activation and weight scales.",
            )
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """No-op: this kernel leaves the loaded parameters untouched."""
        pass

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` if needed and run the FlashInfer scaled GEMM.

        Args:
            layer: Module whose weight / scales the mapping function reads.
            x: Activations; leading dims are flattened for the GEMM and
                restored on the result.
            bias: Optional bias forwarded to the kernel.

        Returns:
            Tensor shaped ``[*x.shape[:-1], w.shape[1]]``.
        """
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        #   If dynamic, layer.input_scale is None and x_scale computed from x.
        #   If static, layer.input_scale is scalar and x_scale is input_scale.
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)

        # View input as 2D matrix for fp8 methods
        x_2d = x.view(-1, x.shape[-1])
        # Shape the result after the caller's tensor, not the flattened view.
        output_shape = [*x.shape[:-1], w.shape[1]]

        out_dtype = self.config.out_dtype
        out_dtype = x.dtype if out_dtype is None else out_dtype

        # If input not quantized
        # TODO(luka) remove this path if not used anymore
        x_2d_q = x_2d
        if x.dtype != current_platform.fp8_dtype():
            x_2d_q, x_s = self.quant_fp8(
                x_2d,
                x_s,
            )

        # The reshape is done here instead of via an ``output_shape`` keyword:
        # passing one to a wrapper that does not declare it raises TypeError.
        output = flashinfer_w8a8_scaled_mm(
            A=x_2d_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
        )
        return output.view(*output_shape)

View File

@ -0,0 +1,179 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .ScaledMMLinearKernel import (
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Per-tensor scaled GEMM body: ``ops.wvSplitKQ`` when eligible, else torch.

    ``ops.wvSplitKQ`` is used only for a single-row ``A`` whose weight has a
    second dim divisible by 16 and whose bias, if present, already has the
    output dtype. Every other input goes through ``torch._scaled_mm``.
    """
    single_row = A.shape[0] == 1
    width_aligned = B.shape[1] % 16 == 0
    bias_compatible = bias is None or bias.dtype == out_dtype

    if single_row and width_aligned and bias_compatible:
        return ops.wvSplitKQ(
            B.t(),
            A,
            out_dtype,
            As,
            Bs,
            current_platform.get_cu_count(),
            bias,
        )

    # Fallback
    return torch._scaled_mm(
        A,
        B,
        out_dtype=out_dtype,
        scale_a=As,
        scale_b=Bs,
        bias=bias,
    )
def rocm_per_tensor_float_w8a8_scaled_mm_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Shape-only stand-in registered as the custom op's fake implementation.

    Returns an uninitialized tensor with ``A``'s leading dims, ``B.shape[1]``
    as the last dim, and dtype ``out_dtype``; the scales and bias are unused.
    """
    result_shape = (*A.shape[:-1], B.shape[1])
    return A.new_empty(result_shape, dtype=out_dtype)
def rocm_per_tensor_float_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list[int],
) -> torch.Tensor:
    """Call the registered ROCm per-tensor scaled-mm op and reshape its output.

    Args:
        A: Quantized activation matrix.
        B: Quantized weight matrix.
        out_dtype: dtype requested for the output.
        As: Activation scale.
        Bs: Weight scale.
        bias: Optional bias forwarded to the op.
        output_shape: Shape the result is viewed as.

    Returns:
        The op output, limited to ``A.shape[0]`` rows and viewed as
        ``output_shape``.
    """
    # The op is registered as "rocm_per_tensor_float_w8a8_scaled_mm_impl";
    # looking it up under any other name raises AttributeError.
    output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl(
        A, B, out_dtype, As, Bs, bias
    )
    return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
# Register the per-tensor scaled-mm implementation and its fake (shape-only)
# counterpart as a custom op. Registration happens on ROCm only, so the op
# does not exist on other platforms.
if current_platform.is_rocm():
    direct_register_custom_op(
        op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
        op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
        fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
    )
class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
    """FP8 W8A8 linear kernel for ROCm using the per-tensor scaled-mm custom op."""

    def __init__(
        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
    ) -> None:
        """Build a per-tensor activation quantizer, then defer to the base class.

        Per-tensor is fixed here because ``can_implement`` only accepts
        per-tensor activation scales.
        """
        self.quant_fp8 = QuantFP8(
            static=c.is_static_input_scheme,
            group_shape=GroupShape.PER_TENSOR,
            num_token_padding=None,
        )
        super().__init__(c, layer_mapping_function)

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Return ``(True, None)`` if usable, else ``(False, reason)``.

        Requires ROCm on an MI3xx device, ``VLLM_ROCM_USE_SKINNY_GEMM``
        enabled, and per-tensor scales for both activations and weights.
        """
        if not current_platform.is_rocm():
            return (
                False,
                "ROCmScaledMMLinearKernel is supported " + "on ROCm platforms Only.",
            )

        # Imported only after the platform check so that probing this kernel
        # on a non-ROCm host never loads the ROCm platform module.
        from vllm.platforms.rocm import on_mi3xx

        if not on_mi3xx():
            return (
                False,
                "ROCmScaledMMLinearKernel is supported "
                + "on MI3xx architures only.",
            )
        if not envs.VLLM_ROCM_USE_SKINNY_GEMM:
            return (
                False,
                "VLLM_ROCM_USE_SKINNY_GEMM must be enabled "
                + "to use ROCmScaledMMLinearKernel ",
            )

        # The config allows a missing group shape; report it as unsupported
        # instead of failing on ``None.is_per_tensor()``.
        act_group_shape = c.activation_group_shape
        per_tensor_activation_scales = (
            act_group_shape is not None and act_group_shape.is_per_tensor()
        )
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )
        if not (per_tensor_activation_scales and per_tensor_weight_scales):
            return (
                False,
                "ROCmScaledMMLinearKernel requires "
                + "per tensor activation and weight scales.",
            )
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """No-op: this kernel leaves the loaded parameters untouched."""
        pass

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` if needed and run the ROCm per-tensor scaled GEMM.

        Args:
            layer: Module whose weight / scales the mapping function reads.
            x: Activations; leading dims are flattened for the GEMM and
                restored on the result.
            bias: Optional bias forwarded to the op.

        Returns:
            Tensor shaped ``[*x.shape[:-1], w.shape[1]]``.
        """
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        #   If dynamic, layer.input_scale is None and x_scale computed from x.
        #   If static, layer.input_scale is scalar and x_scale is input_scale.
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)

        # View input as 2D matrix for fp8 methods
        x_2d = x.view(-1, x.shape[-1])
        # Shape the result after the caller's tensor, not the flattened view.
        output_shape = [*x.shape[:-1], w.shape[1]]

        out_dtype = self.config.out_dtype
        out_dtype = x.dtype if out_dtype is None else out_dtype

        # If input not quantized
        # TODO(luka) remove this path if not used anymore
        x_2d_q = x_2d
        if x.dtype != current_platform.fp8_dtype():
            x_2d_q, x_s = self.quant_fp8(
                x_2d,
                x_s,
            )

        return rocm_per_tensor_float_w8a8_scaled_mm(
            A=x_2d_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
            output_shape=output_shape,
        )

View File

@ -0,0 +1,343 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from packaging import version
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
# torch._scaled_mm has required its scale arguments since PyTorch 2.5, so a
# dummy tensor of ones is kept here to pass where no real scale applies.
TORCH_DEVICE_IDENTITY = None


def maybe_create_device_identity():
    """Create the shared ones tensor on first use; later calls do nothing."""
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is not None:
        return
    # Dummy ones tensor for torch._scaled_mm
    TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
def torch_per_tensor_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
) -> torch.Tensor:
    """Per-tensor scaled GEMM through ``torch._scaled_mm``.

    Args:
        A: Quantized activations; may carry extra padded rows.
        B: Quantized weight matrix.
        out_dtype: dtype requested for the output.
        As: Activation scale.
        Bs: Weight scale.
        bias: Optional bias forwarded to ``torch._scaled_mm``.
        output_shape: Shape of the result; its leading dims give the number
            of real (un-padded) rows.

    Returns:
        The GEMM output cut to the rows implied by ``output_shape`` and
        viewed as that shape.
    """
    output = torch._scaled_mm(
        A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
    )
    # A fix for discrepancy in scaled_mm which returns tuple
    # for torch < 2.5 and a single value in torch >= 2.5
    if type(output) is tuple and len(output) == 2:
        output = output[0]

    # Un-pad using the row count the caller asked for. ``A`` is the possibly
    # padded input, so narrowing to ``A.shape[0]`` could never drop padding.
    # When output_shape was itself derived from ``A`` the two counts agree.
    num_rows = 1
    for dim in output_shape[:-1]:
        num_rows *= dim
    return torch.narrow(output, 0, 0, num_rows).view(*output_shape)
def torch_row_wise_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
) -> torch.Tensor:
    """Row-wise scaled GEMM through ``torch._scaled_mm``.

    Args:
        A: Quantized activations; may carry extra padded rows.
        B: Quantized weight matrix.
        out_dtype: dtype requested for the output.
        As: Activation scales, passed as ``scale_a``.
        Bs: Weight scales; transposed before being passed as ``scale_b``.
        bias: Optional bias forwarded to ``torch._scaled_mm``.
        output_shape: Shape of the result; its leading dims give the number
            of real (un-padded) rows.

    Returns:
        The GEMM output cut to the rows implied by ``output_shape`` and
        viewed as that shape.
    """
    # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
    #  when using it.
    #  For now it has only been validated on ROCm platform.
    #  fp8 rowwise scaling in torch._scaled_mm is introduced in
    #  https://github.com/pytorch/pytorch/pull/144432 using
    #  hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
    #
    #  For CUDA platform please validate if the torch._scaled_mm supports
    #  rowwise scaled GEMM before using it

    # Fused GEMM_DQ Rowwise GEMM
    output = torch._scaled_mm(
        A,
        B,
        out_dtype=out_dtype,
        scale_a=As,
        scale_b=Bs.t(),
        bias=bias,
    )

    # Un-pad using the row count the caller asked for. ``A`` is the possibly
    # padded input, so narrowing to ``A.shape[0]`` could never drop padding.
    # When output_shape was itself derived from ``A`` the two counts agree.
    num_rows = 1
    for dim in output_shape[:-1]:
        num_rows *= dim
    output = torch.narrow(output, 0, 0, num_rows)
    return output.view(*output_shape)
def torch_channelwise_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
) -> torch.Tensor:
    """Channel-wise scaled GEMM with dequantization done outside ``_scaled_mm``.

    Args:
        A: Quantized activations; may carry extra padded rows.
        B: Quantized weight matrix.
        out_dtype: dtype of the returned tensor.
        As: Activation scale(s).
        Bs: Weight scales; transposed before being multiplied in.
        bias: Optional bias added after dequantization.
        output_shape: Shape of the result; its leading dims give the number
            of real (un-padded) rows.

    Returns:
        ``s_w * s_x * (A @ B) + bias`` cast to ``out_dtype`` and viewed as
        ``output_shape``.
    """
    # Use unfused DQ due to limitations with scaled_mm

    # Symmetric quantized GEMM by definition computes the following:
    #   C = (s_x * X) (s_w * W) + bias
    # This is equivalent to dequantizing the weights and activations
    # before applying a GEMM.
    #
    # In order to compute quantized operands, a quantized kernel
    # will rewrite the above like so:
    #   C = s_w * s_x * (X * W) + bias
    #
    # For the scaled_mm fallback case, we break this down, since it
    # does not support s_w being a vector.

    # Reuse the shared identity scale when it exists on A's device; otherwise
    # build one locally so the GEMM never receives a missing scale or one on
    # the wrong device.
    identity = TORCH_DEVICE_IDENTITY
    if identity is None or identity.device != A.device:
        identity = torch.ones(1, dtype=torch.float32, device=A.device)

    # GEMM
    # This computes C = (X * W).
    # Output in fp32 to allow subsequent ops to happen in-place
    output = torch._scaled_mm(
        A,
        B,
        scale_a=identity,
        scale_b=identity,
        out_dtype=torch.float32,
    )
    # A fix for discrepancy in scaled_mm which returns tuple
    # for torch < 2.5 and a single value in torch >= 2.5
    if type(output) is tuple and len(output) == 2:
        output = output[0]

    # Unpad (undo num_token_padding) using the row count the caller asked
    # for; ``A`` is the possibly padded input, so its own row count cannot
    # identify the padding.
    num_rows = 1
    for dim in output_shape[:-1]:
        num_rows *= dim
    output = torch.narrow(output, 0, 0, num_rows)
    # A single-element scale broadcasts as is; only per-row scales need
    # cutting down to the real rows.
    x_scale = As if As.numel() == 1 else torch.narrow(As, 0, 0, num_rows)

    # DQ
    # C = sw * sx * (X * W) + bias
    output = output * x_scale * Bs.t()
    if bias is not None:
        output = output + bias
    return output.to(out_dtype).view(*output_shape)
class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
    """Shared base for the FP8 kernels built on ``torch._scaled_mm``.

    Subclasses supply ``can_implement`` and ``apply_weights``; this class
    owns the activation quantizer and the no-op weight post-processing.
    """

    def __init__(
        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
    ) -> None:
        """Build the activation quantizer, then defer to the base initializer.

        Args:
            c: Layer config; the activation quantizer follows its
                ``is_static_input_scheme`` and ``activation_group_shape``.
            layer_mapping_function: Callable returning the layer's
                parameters and their names.
        """
        compilation_config = get_current_vllm_config().compilation_config
        # Token padding is applied only below the VLLM_COMPILE mode.
        # NOTE(review): 17 presumably targets a faster torch._scaled_mm path
        # for more than 16 rows - confirm.
        pad_output = compilation_config.mode < CompilationMode.VLLM_COMPILE
        output_padding = 17 if pad_output else None

        # Quantize activations at the granularity the layer was configured
        # with, so subclasses that take per-row scales actually receive them;
        # fall back to per-tensor when the config carries no group shape.
        act_group_shape = c.activation_group_shape
        if act_group_shape is None:
            act_group_shape = GroupShape.PER_TENSOR

        self.quant_fp8 = QuantFP8(
            static=c.is_static_input_scheme,
            group_shape=act_group_shape,
            num_token_padding=output_padding,
        )
        super().__init__(c, layer_mapping_function)

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """No-op: these kernels leave the loaded parameters untouched."""
        return
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    """``torch._scaled_mm`` kernel for per-tensor activation and weight scales."""

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Accept only configs where both scales are per-tensor.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``.
        """
        assert c.activation_group_shape is not None

        per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )

        if not (per_tensor_activation_scales and per_tensor_weight_scales):
            return (
                False,
                "PerTensorTorchScaledMMLinearKernel requires "
                "per tensor activation and weight scales.",
            )
        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` if needed and run the per-tensor scaled GEMM.

        Args:
            layer: Module holding the weight, weight scale and input scale.
            x: Activations; the last dimension is the input feature size.
            bias: Optional bias forwarded to the GEMM helper.

        Returns:
            A tensor of shape ``(*x.shape[:-1], w.shape[1])``.
        """
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        #   If dynamic, layer.input_scale is None and x_scale computed from x.
        #   If static, layer.input_scale is scalar and x_scale is input_scale.
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)

        # View input as 2D matrix for fp8 methods
        x_2d = x.view(-1, x.shape[-1])

        out_dtype = self.config.out_dtype
        if out_dtype is None:
            out_dtype = x.dtype

        # If input not quantized
        # TODO(luka) remove this path if not used anymore
        x_2d_q = x_2d
        if x.dtype != current_platform.fp8_dtype():
            x_2d_q, x_s = self.quant_fp8(
                x_2d,
                x_s,
            )

        out_features = w.shape[1]
        # The quantized operand may have been padded with extra rows
        # (num_token_padding), so the GEMM is run on the padded shape.
        output = torch_per_tensor_w8a8_scaled_mm(
            A=x_2d_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
            output_shape=[x_2d_q.shape[0], out_features],
        )
        # Drop any padding rows and restore the caller's leading dimensions;
        # deriving the result shape from x_2d_q would leak the padding and
        # flatten inputs with more than two dimensions.
        output = torch.narrow(output, 0, 0, x_2d.shape[0])
        return output.view(*x.shape[:-1], out_features)
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    """``torch._scaled_mm`` kernel using fused row-wise scaling.

    Only offered on ROCm with PyTorch >= 2.7 (see ``can_implement``).
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required by this kernel."""
        return 94

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Accept only non-per-tensor scales on ROCm with PyTorch >= 2.7.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``.
        """
        assert c.activation_group_shape is not None

        per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )

        # Row-wise scaling in torch._scaled_mm needs one scale per activation
        # row AND one scale per weight column. If either side is per-tensor
        # the fused call cannot be used, so mixed configs are rejected too.
        if per_tensor_activation_scales or per_tensor_weight_scales:
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel cannot be used with "
                "per tensor activation or weight scales.",
            )

        if not current_platform.is_rocm():
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel is only supported "
                "in ROCm platforms.",
            )

        if version.parse(torch.__version__) < version.parse("2.7"):
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel requires pytorch version >=2.7.",
            )

        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` if needed and run the row-wise scaled GEMM.

        Args:
            layer: Module holding the weight, weight scale and input scale.
            x: Activations; the last dimension is the input feature size.
            bias: Optional bias forwarded to the GEMM helper.

        Returns:
            A tensor of shape ``(*x.shape[:-1], w.shape[1])``.
        """
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        #   If dynamic, layer.input_scale is None and x_scale computed from x.
        #   If static, layer.input_scale is scalar and x_scale is input_scale.
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)

        # View input as 2D matrix for fp8 methods
        x_2d = x.view(-1, x.shape[-1])

        out_dtype = self.config.out_dtype
        if out_dtype is None:
            out_dtype = x.dtype

        # If input not quantized
        # TODO(luka) remove this path if not used anymore
        x_2d_q = x_2d
        if x.dtype != current_platform.fp8_dtype():
            x_2d_q, x_s = self.quant_fp8(
                x_2d,
                x_s,
            )

        out_features = w.shape[1]
        # The quantized operand may have been padded with extra rows
        # (num_token_padding), so the GEMM is run on the padded shape.
        output = torch_row_wise_w8a8_scaled_mm(
            A=x_2d_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
            output_shape=[x_2d_q.shape[0], out_features],
        )
        # Drop any padding rows and restore the caller's leading dimensions;
        # deriving the result shape from x_2d_q would leak the padding and
        # flatten inputs with more than two dimensions.
        output = torch.narrow(output, 0, 0, x_2d.shape[0])
        return output.view(*x.shape[:-1], out_features)
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    """Unfused ``torch._scaled_mm`` kernel for non-per-tensor scales.

    The GEMM is executed with identity scales and the activation / weight
    scales are applied afterwards by ``torch_channelwise_w8a8_scaled_mm``.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required by this kernel."""
        # NOTE(review): 94 matches the row-wise kernel, but this unfused path
        # only issues a per-tensor torch._scaled_mm call, for which the base
        # class advertises 89. Looks like a copy-paste - confirm before
        # relaxing.
        return 94

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Reject configs where both scales are per-tensor; accept the rest.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``.
        """
        assert c.activation_group_shape is not None

        per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )

        if per_tensor_activation_scales and per_tensor_weight_scales:
            return (
                False,
                "ChannelWiseTorchScaledMMLinearKernel cannot be used with "
                "per tensor activation and weight scales.",
            )

        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` if needed and run the unfused channel-wise GEMM.

        Args:
            layer: Module holding the weight, weight scale and input scale.
            x: Activations; the last dimension is the input feature size.
            bias: Optional bias forwarded to the GEMM helper.

        Returns:
            A tensor of shape ``(*x.shape[:-1], w.shape[1])``.
        """
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        #   If dynamic, layer.input_scale is None and x_scale computed from x.
        #   If static, layer.input_scale is scalar and x_scale is input_scale.
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)

        # View input as 2D matrix for fp8 methods
        x_2d = x.view(-1, x.shape[-1])

        out_dtype = self.config.out_dtype
        if out_dtype is None:
            out_dtype = x.dtype

        # If input not quantized
        # TODO(luka) remove this path if not used anymore
        x_2d_q = x_2d
        if x.dtype != current_platform.fp8_dtype():
            x_2d_q, x_s = self.quant_fp8(
                x_2d,
                x_s,
            )

        out_features = w.shape[1]
        # The quantized operand may have been padded with extra rows
        # (num_token_padding), so the GEMM is run on the padded shape.
        output = torch_channelwise_w8a8_scaled_mm(
            A=x_2d_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
            output_shape=[x_2d_q.shape[0], out_features],
        )
        # Drop any padding rows and restore the caller's leading dimensions;
        # deriving the result shape from x_2d_q would leak the padding and
        # flatten inputs with more than two dimensions.
        output = torch.narrow(output, 0, 0, x_2d.shape[0])
        return output.view(*x.shape[:-1], out_features)