first try

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-10-28 16:26:51 +00:00
parent a806c14cc7
commit 974e6820ce
8 changed files with 924 additions and 125 deletions

View File

@ -10,6 +10,14 @@ from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
_POSSIBLE_FP8_KERNELS,
choose_scaled_mm_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
check_aiter_fp8_linear_support,
@ -24,7 +32,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_block_fp8_supported,
maybe_create_device_identity,
)
@ -72,9 +79,32 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape,
param_name_list = ["weight", "weight_scale", "input_scale"]
layer_mapping_function = lambda layer: (
tuple(getattr(layer, param_name) for param_name in param_name_list),
param_name_list,
)
# TODO: clean up
if self.strategy == QuantizationStrategy.TENSOR:
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
elif self.strategy == QuantizationStrategy.CHANNEL:
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
is_static_input_scheme=self.is_static_input_scheme,
input_symmetric=True,
weight_quant_strategy=weight_quant_strategy,
activation_group_shape=self.act_q_group_shape,
out_dtype=self.out_dtype,
)
kernel = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config,
_POSSIBLE_FP8_KERNELS,
)
self.fp8_linear = kernel(
scaled_mm_linear_kernel_config, layer_mapping_function
)
@classmethod
@ -190,11 +220,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
bias=bias,
)
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias)

View File

@ -2,16 +2,36 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
class ScaledMMLinearQuantStrategy(Enum):
    """Granularity of the weight quantization scales for scaled-mm linear kernels."""

    # One scale for the whole weight tensor.
    TENSOR = "tensor"
    # One scale per output channel.
    CHANNEL = "channel"
    # One scale per 2-D block of the weight.
    BLOCK = "block"

    # NOTE: the previous `is_per_token` / `is_per_group` methods referenced
    # `self.row` / `self.col`, which Enum members do not have (those are
    # GroupShape attributes) — any call raised AttributeError, so they have
    # been removed.
@dataclass
class ScaledMMLinearLayerConfig:
    """Static description of a scaled-mm linear layer's quantization scheme.

    Passed to ScaledMMLinearKernel.can_implement to select a kernel and kept
    on the kernel instance as `self.config`.
    """

    # TODO: remove is channelwise
    # Redundant with weight_quant_strategy == CHANNEL (callers set it that way).
    is_channelwise: bool
    # True when the input scale comes from the checkpoint rather than being
    # computed from the activation at runtime.
    is_static_input_scheme: bool
    # True when input quantization has no zero point.
    input_symmetric: bool
    # GEMM output dtype; None means "use the activation's dtype".
    out_dtype: torch.dtype | None
    # Granularity of the weight scales.
    weight_quant_strategy: ScaledMMLinearQuantStrategy
    # Granularity of the activation scales (per-tensor by default).
    activation_group_shape: GroupShape | None = GroupShape.PER_TENSOR
class ScaledMMLinearKernel(ABC):
@ -26,21 +46,11 @@ class ScaledMMLinearKernel(ABC):
raise NotImplementedError
def __init__(
self,
c: ScaledMMLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
i_s_param_name: str,
i_zp_param_name: str,
azp_adj_param_name: str,
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
) -> None:
assert self.can_implement(c)
self.config = c
self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name
self.i_s_name = i_s_param_name
self.i_zp_name = i_zp_param_name
self.azp_adj_name = azp_adj_param_name
self.layer_mapping_function = layer_mapping_function
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@ -55,19 +65,19 @@ class ScaledMMLinearKernel(ABC):
) -> torch.Tensor:
raise NotImplementedError
def _get_weight_params(
self, layer: torch.nn.Module
) -> tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_zp
torch.Tensor | None, # azp_adj
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),
getattr(layer, self.i_s_name),
getattr(layer, self.i_zp_name),
getattr(layer, self.azp_adj_name),
)
# def _get_weight_params(
# self, layer: torch.nn.Module
# ) -> tuple[
# torch.Tensor, # weight
# torch.Tensor, # weight_scale
# torch.Tensor | None, # input_scale,
# torch.Tensor | None, # input_zp
# torch.Tensor | None, # azp_adj
# ]:
# return (
# getattr(layer, self.w_q_name),
# getattr(layer, self.w_s_name),
# getattr(layer, self.i_s_name),
# getattr(layer, self.i_zp_name),
# getattr(layer, self.azp_adj_name),
# )

View File

@ -12,10 +12,18 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
ROCmScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
ChannelWiseTorchScaledMMLinearKernel,
PerTensorTorchScaledMMLinearKernel,
RowWiseTorchScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel,
)
@ -25,16 +33,28 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
}
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
PlatformEnum.ROCM: [
ROCmScaledMMLinearKernel,
PerTensorTorchScaledMMLinearKernel,
RowWiseTorchScaledMMLinearKernel,
ChannelWiseTorchScaledMMLinearKernel,
],
}
def choose_scaled_mm_linear_kernel(
config: ScaledMMLinearLayerConfig, compute_capability: int | None = None
config: ScaledMMLinearLayerConfig,
possible_kernels: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]],
compute_capability: int | None = None,
) -> type[ScaledMMLinearKernel]:
"""
Choose an ScaledMMLinearKernel that can implement the given config for the
@ -61,7 +81,7 @@ def choose_scaled_mm_linear_kernel(
compute_capability = _cc[0] * 10 + _cc[1]
failure_reasons = []
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
for kernel in possible_kernels[current_platform._enum]:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
failure_reasons.append(
f" {kernel.__name__} disabled by environment variable"

View File

@ -9,8 +9,8 @@ from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
from .cutlass import process_weights_after_loading
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
def rocm_aiter_gemm_w8a8_impl(
@ -52,7 +52,7 @@ if current_platform.is_rocm():
)
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 90
@ -92,7 +92,9 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
_, param_names = self.layer_mapping_function(layer)
process_weights_after_loading(self.config, layer, *param_names)
def apply_weights(
self,
@ -110,7 +112,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
AITER block scaled GEMM and mixed-precision GEMM.
"""
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
(w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.

View File

@ -2,10 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
@ -14,6 +18,111 @@ from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
def cutlass_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
) -> torch.Tensor:
    """Run the fused CUTLASS GEMM + dequant kernel and reshape the result.

    `As`/`Bs` are the activation/weight scales; `output_shape` restores the
    caller's logical (pre-flattened) shape.
    """
    # Fused GEMM_DQ
    gemm_out = ops.cutlass_scaled_mm(
        A,
        B,
        out_dtype=out_dtype,
        scale_a=As,
        scale_b=Bs,
        bias=bias,
    )
    return gemm_out.view(*output_shape)
def process_weights_after_loading(
    config: ScaledMMLinearLayerConfig,
    layer: torch.nn.Module,
    w_q_name: str,
    w_s_name: str,
    i_s_name: str,
    i_zp_name: str,
    azp_adj_name: str,
):
    """Rewrite a loaded layer's quantized params into CUTLASS-kernel layout.

    Mutates `layer` in place: transposes the weight, widens per-tensor scales
    of fused modules to per-channel, folds the static input scale/zero-point
    into the form the kernel consumes, and precomputes the AZP adjustment
    term for asymmetric input quantization. The `*_name` arguments are the
    attribute names on `layer` holding each parameter.
    """
    # WEIGHT
    # Cutlass kernels need transposed weight.
    weight = getattr(layer, w_q_name)
    replace_parameter(
        layer,
        w_q_name,
        torch.nn.Parameter(weight.t().data, requires_grad=False),
    )

    # WEIGHT SCALE
    # Cutlass kernels support only per-tensor and per-channel.
    # If we have a fused module (QKV, MLP) with per tensor scales (thus N
    # scales being passed to the kernel), convert to the per-channel case.
    is_fused_module = len(layer.logical_widths) > 1
    weight_scale = getattr(layer, w_s_name)
    if is_fused_module and not config.is_channelwise:
        weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
    replace_parameter(
        layer,
        w_s_name,
        torch.nn.Parameter(weight_scale.data, requires_grad=False),
    )

    # INPUT SCALE
    if config.is_static_input_scheme:
        input_scale = getattr(layer, i_s_name)

        if config.input_symmetric:
            # Symmetric static: a single max scale, no zero point.
            replace_parameter(
                layer,
                i_s_name,
                torch.nn.Parameter(input_scale.max(), requires_grad=False),
            )
            setattr(layer, i_zp_name, None)
        else:
            input_zero_point = getattr(layer, i_zp_name)

            # reconstruct the ranges
            int8_traits = torch.iinfo(torch.int8)
            azps = input_zero_point.to(dtype=torch.int32)
            range_max = (input_scale * (int8_traits.max - azps)).max()
            range_min = (input_scale * (int8_traits.min - azps)).min()

            scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
            replace_parameter(
                layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
            )

            # AZP loaded as int8 but used as int32
            azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
            replace_parameter(
                layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
            )

    else:
        # Dynamic input quantization: scales computed at runtime.
        setattr(layer, i_s_name, None)
        setattr(layer, i_zp_name, None)

    # azp_adj is the AZP adjustment term, used to account for weights.
    # It does not depend on scales or azp, so it is the same for
    # static and dynamic quantization.
    # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
    # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
    if not config.input_symmetric:
        weight = getattr(layer, w_q_name)
        azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
        if config.is_static_input_scheme:
            # cutlass_w8a8 requires azp to be folded into azp_adj
            # in the per-tensor case
            azp_adj = getattr(layer, i_zp_name) * azp_adj

        setattr(
            layer,
            azp_adj_name,
            torch.nn.Parameter(azp_adj, requires_grad=False),
        )
    else:
        setattr(layer, azp_adj_name, None)
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
@ -27,83 +136,9 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Cutlass kernels need transposed weight.
weight = getattr(layer, self.w_q_name)
replace_parameter(
layer,
self.w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
)
_, param_names = self.layer_mapping_function(layer)
# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE
if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name)
if self.config.input_symmetric:
replace_parameter(
layer,
self.i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
setattr(layer, self.i_zp_name, None)
else:
input_zero_point = getattr(layer, self.i_zp_name)
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = input_zero_point.to(dtype=torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter(
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
)
# AZP loaded as int8 but used as int32
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
replace_parameter(
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
else:
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
if not self.config.input_symmetric:
weight = getattr(layer, self.w_q_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
if self.config.is_static_input_scheme:
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj = getattr(layer, self.i_zp_name) * azp_adj
setattr(
layer,
self.azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False),
)
else:
setattr(layer, self.azp_adj_name, None)
process_weights_after_loading(self.config, layer, *param_names)
def apply_weights(
self,
@ -111,7 +146,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
(w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
@ -138,3 +173,70 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
return ops.cutlass_scaled_mm(
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
)
class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
    """FP8 scaled-mm kernel backed by the CUTLASS fused GEMM (CUDA only)."""

    def __init__(
        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
    ) -> None:
        # Quantizer used when the activation arrives unquantized.
        self.quant_fp8 = QuantFP8(
            static=c.is_static_input_scheme,
            group_shape=GroupShape.PER_TENSOR,
            num_token_padding=None,
        )
        super().__init__(c, layer_mapping_function)

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        if current_platform.is_cuda():
            return True, None
        return (
            False,
            "CutlassFP8ScaledMMLinearKernel is supported "
            + "on CUDA platforms Only.",
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Checkpoint layout is consumed as-is; nothing to rewrite here.
        pass

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        """Quantize the activation if needed, then run the CUTLASS FP8 GEMM.

        With dynamic quantization the stored input scale is None and the
        scale is computed from `x`; with static quantization it is the
        stored scalar.
        """
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)

        # Collapse leading dims so the GEMM sees a 2-D activation.
        x_2d = x.view(-1, x.shape[-1])
        out_dtype = x.dtype if self.config.out_dtype is None else self.config.out_dtype

        # Quantize on the fly unless the caller already handed us fp8.
        # TODO(luka) remove this path if not used anymore
        if x.dtype == current_platform.fp8_dtype():
            x_2d_q, act_scale = x_2d, x_s
        else:
            x_2d_q, act_scale = self.quant_fp8(x_2d, x_s)

        output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
        return cutlass_w8a8_scaled_mm(
            A=x_2d_q,
            B=w,
            out_dtype=out_dtype,
            As=act_scale,
            Bs=w_s,
            bias=bias,
            output_shape=output_shape,
        )

View File

@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from .ScaledMMLinearKernel import (
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
def flashinfer_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list | None = None,
) -> torch.Tensor:
    """Run FlashInfer's fused FP8 scaled GEMM.

    `As`/`Bs` are the activation/weight scales. If `output_shape` is given,
    the 2-D GEMM result is viewed back to that logical shape; previously the
    caller passed `output_shape=` but this signature did not accept it,
    which raised TypeError at runtime.
    """
    output = flashinfer_scaled_fp8_mm(
        A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
    )
    if output_shape is not None:
        output = output.view(*output_shape)
    return output
class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
    """FP8 scaled-mm kernel backed by FlashInfer's fused GEMM.

    Requires CUDA with compute capability >= 100, FlashInfer installed,
    and per-tensor activation and weight scales.
    """

    def __init__(
        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
    ) -> None:
        # Quantizer used when the activation arrives unquantized.
        self.quant_fp8 = QuantFP8(
            static=c.is_static_input_scheme,
            group_shape=GroupShape.PER_TENSOR,
            num_token_padding=None,
        )
        super().__init__(c, layer_mapping_function)

    @classmethod
    def get_min_capability(cls) -> int:
        return 100

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )
        if not current_platform.is_cuda():
            return (
                False,
                "FlashInferScaledMMLinearKernel is supported "
                + "on CUDA platforms Only.",
            )
        # (A duplicated copy of this check was removed.)
        if not has_flashinfer():
            return (
                False,
                "FlashInferScaledMMLinearKernel requires "
                + "FlashInfer to be installed.",
            )
        if not (per_tensor_activation_scales and per_tensor_weight_scales):
            return (
                False,
                "FlashInferScaledMMLinearKernel requires "
                + "per tensor activation and weight scales.",
            )
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Checkpoint layout is consumed as-is; nothing to rewrite here.
        pass

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        """Quantize the activation if needed, then run the FlashInfer GEMM.

        With dynamic quantization the stored input scale is None and the
        scale is computed from `x`; with static quantization it is the
        stored scalar.
        """
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)

        # View input as 2D matrix for fp8 methods
        x_2d = x.view(-1, x.shape[-1])
        out_dtype = self.config.out_dtype
        out_dtype = x.dtype if out_dtype is None else out_dtype

        # If input not quantized
        # TODO(luka) remove this path if not used anymore
        x_2d_q = x_2d
        if x.dtype != current_platform.fp8_dtype():
            x_2d_q, x_s = self.quant_fp8(
                x_2d,
                x_s,
            )

        output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
        # flashinfer_w8a8_scaled_mm returns the raw 2-D GEMM result; restore
        # the logical shape here. (Previously an unsupported `output_shape`
        # kwarg was passed, raising TypeError.)
        output = flashinfer_w8a8_scaled_mm(
            A=x_2d_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
        )
        return output.view(*output_shape)

View File

@ -0,0 +1,179 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .ScaledMMLinearKernel import (
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Per-tensor FP8 scaled GEMM implementation for ROCm.

    Dispatches to the wvSplitKQ kernel for single-row inputs (M == 1) whose
    N dim is a multiple of 16 and whose bias dtype matches the output dtype;
    otherwise falls back to torch._scaled_mm.
    """
    # Fast path: single-token batch with a 16-aligned N dim; wvSplitKQ takes
    # the transposed weight and the device CU count.
    if (
        A.shape[0] == 1
        and B.shape[1] % 16 == 0
        and ((bias is None) or (bias.dtype == out_dtype))
    ):
        output = ops.wvSplitKQ(
            B.t(),
            A,
            out_dtype,
            As,
            Bs,
            current_platform.get_cu_count(),
            bias,
        )
    # Fallback
    else:
        output = torch._scaled_mm(
            A,
            B,
            out_dtype=out_dtype,
            scale_a=As,
            scale_b=Bs,
            bias=bias,
        )
    return output
def rocm_per_tensor_float_w8a8_scaled_mm_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Shape/dtype-only fake impl used to trace the custom op (no compute)."""
    result_shape = (*A.shape[:-1], B.shape[1])
    return A.new_empty(result_shape, dtype=out_dtype)
def rocm_per_tensor_float_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list[int],
) -> torch.Tensor:
    """Call the registered ROCm per-tensor FP8 scaled-mm custom op.

    Trims the result to the true batch size (the impl may return padded
    rows) and views it back to `output_shape`.
    """
    # FIX: must match the op_name used in direct_register_custom_op below.
    # The previous call used "rocm_per_tensor_w8a8_scaled_mm_impl" (missing
    # the "_float" infix), which raised AttributeError at call time.
    output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl(
        A, B, out_dtype, As, Bs, bias
    )
    return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
# Register the ROCm impl as a vLLM custom op so torch.compile can trace
# calls through the fake implementation instead of executing the kernel.
if current_platform.is_rocm():
    direct_register_custom_op(
        op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
        op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
        fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
    )
class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
    """FP8 scaled-mm kernel for ROCm, using the skinny-GEMM path when possible.

    Requires ROCm on MI3xx with VLLM_ROCM_USE_SKINNY_GEMM enabled, and
    per-tensor activation and weight scales.
    """

    def __init__(
        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
    ) -> None:
        # Quantizer used when the activation arrives unquantized.
        self.quant_fp8 = QuantFP8(
            static=c.is_static_input_scheme,
            group_shape=GroupShape.PER_TENSOR,
            num_token_padding=None,
        )
        super().__init__(c, layer_mapping_function)

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        # TODO: check if this causes an issue on non-ROCM platforms
        from vllm.platforms.rocm import on_mi3xx

        per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
        per_tensor_weight_scales = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )
        # FIX: messages previously referred to a nonexistent class name
        # ("ROCmScaledMMLinearFP8Kernel") and misspelled "architectures".
        if not current_platform.is_rocm():
            return (
                False,
                "ROCmScaledMMLinearKernel is supported " + "on ROCm platforms Only.",
            )
        if not on_mi3xx():
            return (
                False,
                "ROCmScaledMMLinearKernel is supported "
                + "on MI3xx architectures only.",
            )
        if not envs.VLLM_ROCM_USE_SKINNY_GEMM:
            return (
                False,
                "VLLM_ROCM_USE_SKINNY_GEMM must be enabled "
                + "to use ROCmScaledMMLinearKernel ",
            )
        if not (per_tensor_activation_scales and per_tensor_weight_scales):
            return (
                False,
                "ROCmScaledMMLinearKernel requires "
                + "per tensor activation and weight scales.",
            )
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Checkpoint layout is consumed as-is; nothing to rewrite here.
        pass

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        """Quantize the activation if needed, then run the ROCm FP8 GEMM.

        With dynamic quantization the stored input scale is None and the
        scale is computed from `x`; with static quantization it is the
        stored scalar.
        """
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)

        # View input as 2D matrix for fp8 methods
        x_2d = x.view(-1, x.shape[-1])
        out_dtype = self.config.out_dtype
        out_dtype = x.dtype if out_dtype is None else out_dtype

        # If input not quantized
        # TODO(luka) remove this path if not used anymore
        x_2d_q = x_2d
        if x.dtype != current_platform.fp8_dtype():
            x_2d_q, x_s = self.quant_fp8(
                x_2d,
                x_s,
            )

        output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
        return rocm_per_tensor_float_w8a8_scaled_mm(
            A=x_2d_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
            output_shape=output_shape,
        )

View File

@ -0,0 +1,343 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from packaging import version
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None


def maybe_create_device_identity():
    """Lazily allocate the dummy ones scale tensor for torch._scaled_mm.

    Idempotent: the tensor is created once and reused on later calls.
    """
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is not None:
        return
    TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
def torch_per_tensor_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
) -> torch.Tensor:
    """Per-tensor scaled GEMM via torch._scaled_mm, reshaped to `output_shape`."""
    result = torch._scaled_mm(
        A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
    )
    # A fix for discrepancy in scaled_mm which returns tuple
    # for torch < 2.5 and a single value in torch >= 2.5
    if type(result) is tuple and len(result) == 2:
        result = result[0]
    # Trim any padded rows back to the true batch, then restore the shape.
    return torch.narrow(result, 0, 0, A.shape[0]).view(*output_shape)
def torch_row_wise_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
) -> torch.Tensor:
    """Row-wise scaled FP8 GEMM via torch._scaled_mm.

    `As` holds per-row activation scales and `Bs` per-channel weight scales
    (transposed before being passed to the kernel).
    """
    # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
    # when using it.
    # For now it has only been validated on ROCm platform.
    # fp8 rowwise scaling in torch._scaled_mm is introduced in
    # https://github.com/pytorch/pytorch/pull/144432 using
    # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
    #
    # For CUDA platform please validate if the torch._scaled_mm supports
    # rowwise scaled GEMM before using it

    # Fused GEMM_DQ Rowwise GEMM
    output = torch._scaled_mm(
        A,
        B,
        out_dtype=out_dtype,
        scale_a=As,
        scale_b=Bs.t(),
        bias=bias,
    )

    # Trim any padded rows back to the true batch size.
    output = torch.narrow(output, 0, 0, A.shape[0])
    output = output.view(*output_shape)
    return output
def torch_channelwise_w8a8_scaled_mm(
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
    output_shape: list,
) -> torch.Tensor:
    """Channel-wise scaled GEMM with unfused dequantization.

    torch._scaled_mm does not support a vector weight scale, so the GEMM is
    run with identity scales in fp32 and the dequantization (and bias) is
    applied afterwards.

    NOTE(review): relies on the module-global TORCH_DEVICE_IDENTITY having
    been allocated (maybe_create_device_identity()) — confirm callers do so.
    """
    # Use unfused DQ due to limitations with scaled_mm

    # Symmetric quantized GEMM by definition computes the following:
    # C = (s_x * X) (s_w * W) + bias
    # This is equivalent to dequantizing the weights and activations
    # before applying a GEMM.
    #
    # In order to compute quantized operands, a quantized kernel
    # will rewrite the above like so:
    # C = s_w * s_x * (X * W) + bias
    #
    # For the scaled_mm fallback case, we break this down, since it
    # does not support s_w being a vector.

    # GEMM
    # This computes C = (X * W).
    # Output in fp32 to allow subsequent ops to happen in-place
    output = torch._scaled_mm(
        A,
        B,
        scale_a=TORCH_DEVICE_IDENTITY,
        scale_b=TORCH_DEVICE_IDENTITY,
        out_dtype=torch.float32,
    )
    # A fix for discrepancy in scaled_mm which returns tuple
    # for torch < 2.5 and a single value in torch >= 2.5
    if type(output) is tuple and len(output) == 2:
        output = output[0]
    # Unpad (undo num_token_padding)
    output = torch.narrow(output, 0, 0, A.shape[0])
    x_scale = torch.narrow(As, 0, 0, A.shape[0])

    # DQ
    # C = sw * sx * (X * W) + bias
    output = output * x_scale * Bs.t()
    if bias is not None:
        output = output + bias
    return output.to(out_dtype).view(*output_shape)
class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
    """Base class for torch._scaled_mm-backed FP8 kernels.

    Subclasses select the supported scale granularity via can_implement and
    implement apply_weights.
    """

    def __init__(
        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
    ) -> None:
        vllm_config = get_current_vllm_config().compilation_config
        # Only pad tokens when NOT running under VLLM_COMPILE.
        # NOTE(review): 17 looks like an arbitrary padding size inherited
        # from the previous Fp8LinearOp implementation — confirm rationale.
        pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
        output_padding = 17 if pad_output else None
        # Quantizer used when the activation arrives unquantized.
        self.quant_fp8 = QuantFP8(
            static=c.is_static_input_scheme,
            group_shape=GroupShape.PER_TENSOR,
            num_token_padding=output_padding,
        )
        super().__init__(c, layer_mapping_function)

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Checkpoint layout is consumed as-is; nothing to rewrite here.
        return
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    """torch._scaled_mm kernel for per-tensor activation and weight scales."""

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        assert c.activation_group_shape is not None
        act_is_per_tensor = c.activation_group_shape.is_per_tensor()
        weight_is_per_tensor = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )
        if act_is_per_tensor and weight_is_per_tensor:
            return True, None
        return (
            False,
            "PerTensorTorchScaledMMLinearKernel requires "
            + "per tensor activation and weight scales.",
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        """Quantize the activation if needed, then run the per-tensor GEMM.

        With dynamic quantization the stored input scale is None and the
        scale is computed from `x`; with static quantization it is the
        stored scalar.
        """
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)

        # Flatten leading dims so the GEMM sees a 2-D activation.
        flat_x = x.view(-1, x.shape[-1])
        result_dtype = self.config.out_dtype
        if result_dtype is None:
            result_dtype = x.dtype

        # Quantize on the fly unless the input is already fp8.
        # TODO(luka) remove this path if not used anymore
        if x.dtype == current_platform.fp8_dtype():
            q_x, act_scale = flat_x, x_s
        else:
            q_x, act_scale = self.quant_fp8(flat_x, x_s)

        final_shape = [*q_x.shape[:-1], w.shape[1]]
        return torch_per_tensor_w8a8_scaled_mm(
            A=q_x,
            B=w,
            out_dtype=result_dtype,
            As=act_scale,
            Bs=w_s,
            bias=bias,
            output_shape=final_shape,
        )
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    """torch._scaled_mm kernel using row-wise scales (ROCm, torch >= 2.7)."""

    @classmethod
    def get_min_capability(cls) -> int:
        return 94

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        assert c.activation_group_shape is not None
        act_is_per_tensor = c.activation_group_shape.is_per_tensor()
        weight_is_per_tensor = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )
        # The fully per-tensor case belongs to the per-tensor kernel.
        if act_is_per_tensor and weight_is_per_tensor:
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel cannot be used with "
                + "per tensor activation and weight scales.",
            )
        if not current_platform.is_rocm():
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel is only supported "
                + "in ROCm platforms.",
            )
        # fp8 row-wise scaling in torch._scaled_mm needs torch >= 2.7.
        if version.parse(torch.__version__) < version.parse("2.7"):
            return (
                False,
                "RowWiseTorchScaledMMLinearKernel requires " + "pytorch version >=2.7.",
            )
        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        """Quantize the activation if needed, then run the row-wise GEMM.

        With dynamic quantization the stored input scale is None and the
        scales are computed from `x`; with static quantization they are the
        stored values.
        """
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)

        # Flatten leading dims so the GEMM sees a 2-D activation.
        flat_x = x.view(-1, x.shape[-1])
        result_dtype = self.config.out_dtype
        if result_dtype is None:
            result_dtype = x.dtype

        # Quantize on the fly unless the input is already fp8.
        # TODO(luka) remove this path if not used anymore
        if x.dtype == current_platform.fp8_dtype():
            q_x, act_scale = flat_x, x_s
        else:
            q_x, act_scale = self.quant_fp8(flat_x, x_s)

        final_shape = [*q_x.shape[:-1], w.shape[1]]
        return torch_row_wise_w8a8_scaled_mm(
            A=q_x,
            B=w,
            out_dtype=result_dtype,
            As=act_scale,
            Bs=w_s,
            bias=bias,
            output_shape=final_shape,
        )
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
    """torch._scaled_mm kernel for channel-wise weight scales (unfused dequant)."""

    @classmethod
    def get_min_capability(cls) -> int:
        return 94

    @classmethod
    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        assert c.activation_group_shape is not None
        act_is_per_tensor = c.activation_group_shape.is_per_tensor()
        weight_is_per_tensor = (
            c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
        )
        # The fully per-tensor case belongs to the per-tensor kernel.
        if act_is_per_tensor and weight_is_per_tensor:
            return (
                False,
                "ChannelWiseTorchScaledMMLinearKernel cannot be used with "
                + "per tensor activation and weight scales.",
            )
        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        """Quantize the activation if needed, then run the channel-wise GEMM.

        With dynamic quantization the stored input scale is None and the
        scales are computed from `x`; with static quantization they are the
        stored values.
        """
        (w, w_s, x_s), _ = self.layer_mapping_function(layer)

        # Flatten leading dims so the GEMM sees a 2-D activation.
        flat_x = x.view(-1, x.shape[-1])
        result_dtype = self.config.out_dtype
        if result_dtype is None:
            result_dtype = x.dtype

        # Quantize on the fly unless the input is already fp8.
        # TODO(luka) remove this path if not used anymore
        if x.dtype == current_platform.fp8_dtype():
            q_x, act_scale = flat_x, x_s
        else:
            q_x, act_scale = self.quant_fp8(flat_x, x_s)

        final_shape = [*q_x.shape[:-1], w.shape[1]]
        return torch_channelwise_w8a8_scaled_mm(
            A=q_x,
            B=w,
            out_dtype=result_dtype,
            As=act_scale,
            Bs=w_s,
            bias=bias,
            output_shape=final_shape,
        )