implement apply func in base FP8ScaledMMLinearKernel class

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-11-07 07:17:41 +00:00
parent aaa0d55587
commit 7fb465744c
6 changed files with 83 additions and 179 deletions

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Generic, TypeVar
@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.platforms import current_platform
@dataclass
@ -98,12 +99,9 @@ class FP8ScaledMMLinearKernel(
group_shape=act_scale_descriptor.group_shape,
num_token_padding=self.get_ouput_padding(),
)
self.fp8_dtype = current_platform.fp8_dtype()
super().__init__(c, layer_param_names)
@abstractmethod
def get_ouput_padding(self) -> int | None:
    """Return the token-padding size handed to QuantFP8, or None for no padding.

    NOTE(review): name keeps the existing "ouput" typo since callers use it.
    """
    raise NotImplementedError
@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
@ -121,6 +119,53 @@ class FP8ScaledMMLinearKernel(
getattr(layer, x_s_ub),
)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
scaled_mm_func = self.get_scaled_mm_func()
quant_fp8 = self.quant_fp8
fp8_dtype = self.fp8_dtype
maybe_out_dtype = self.config.out_dtype
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_s computed from x.
# If static, layer.input_scale is scalar and x_s is input_scale.
# View input as 2D matrix for fp8 methods
x_2d = x.view(-1, x.shape[-1])
output_shape = [*x.shape[:-1], w.shape[1]]
out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
# If input not quantized
# TODO(luka) remove this path if not used anymore
x_2d_q = x_2d
if x.dtype != fp8_dtype:
x_2d_q, x_s = quant_fp8(
x_2d,
x_s,
x_s_ub,
)
return scaled_mm_func(
A=x_2d_q,
B=w,
out_dtype=out_dtype,
As=x_s,
Bs=w_s,
bias=bias,
output_shape=output_shape,
)
@abstractmethod
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
    """Return the kernel-specific scaled-mm callable used by apply_weights.

    The callable is invoked with keyword arguments
    (A, B, out_dtype, As, Bs, bias, output_shape).
    """
    raise NotImplementedError

@abstractmethod
def get_ouput_padding(self) -> int | None:
    """Return the token-padding size handed to QuantFP8, or None for no padding.

    NOTE(review): name keeps the existing "ouput" typo since callers use it.
    """
    raise NotImplementedError
class Int8ScaledMMLinearKernel(
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC

View File

@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm import _custom_ops as ops
@ -17,7 +19,6 @@ from .ScaledMMLinearKernel import (
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
)
from .utils import apply_weights_fp8
def cutlass_w8a8_scaled_mm_fp8(
@ -160,9 +161,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
def get_ouput_padding(self) -> int | None:
    """No token padding for the CUTLASS FP8 kernel."""
    return None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
@ -174,21 +172,8 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
return True, None
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Apply the FP8 matmul for ``layer`` via cutlass_w8a8_scaled_mm_fp8.

    Fetches the layer's weight/scale params and delegates to the shared
    apply_weights_fp8 helper.
    """
    w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
    return apply_weights_fp8(
        cutlass_w8a8_scaled_mm_fp8,
        self.quant_fp8,
        w,
        x,
        w_s,
        x_s,
        bias,
        x_s_ub,
        self.config.out_dtype,
    )
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
    """Use cutlass_w8a8_scaled_mm_fp8 as the scaled-mm implementation."""
    return cutlass_w8a8_scaled_mm_fp8

def get_ouput_padding(self) -> int | None:
    """No token padding for the CUTLASS FP8 kernel."""
    return None

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.platforms import current_platform
@ -10,7 +12,6 @@ from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
from .utils import apply_weights_fp8
def flashinfer_w8a8_scaled_mm(
@ -29,13 +30,6 @@ def flashinfer_w8a8_scaled_mm(
class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
def get_ouput_padding(self) -> int | None:
    """No token padding for the FlashInfer kernel."""
    return None

@classmethod
def get_min_capability(cls) -> int:
    # Minimum required GPU compute capability for this kernel.
    return 100
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = (
@ -71,21 +65,12 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
)
return True, None
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Apply the FP8 matmul for ``layer`` via flashinfer_w8a8_scaled_mm.

    Fetches the layer's weight/scale params and delegates to the shared
    apply_weights_fp8 helper.
    """
    w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
    return apply_weights_fp8(
        flashinfer_w8a8_scaled_mm,
        self.quant_fp8,
        w,
        x,
        w_s,
        x_s,
        bias,
        x_s_ub,
        self.config.out_dtype,
    )
@classmethod
def get_min_capability(cls) -> int:
    """Minimum required GPU compute capability for this kernel."""
    return 100

def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
    """Use flashinfer_w8a8_scaled_mm as the scaled-mm implementation."""
    return flashinfer_w8a8_scaled_mm

def get_ouput_padding(self) -> int | None:
    """No token padding for the FlashInfer kernel."""
    return None

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from packaging import version
@ -11,7 +13,6 @@ from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
from .utils import apply_weights_fp8
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@ -155,24 +156,8 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
)
return True, None
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Apply the FP8 matmul for ``layer`` via torch_per_tensor_w8a8_scaled_mm.

    Fetches the layer's weight/scale params and delegates to the shared
    apply_weights_fp8 helper.
    """
    w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
    return apply_weights_fp8(
        torch_per_tensor_w8a8_scaled_mm,
        self.quant_fp8,
        w,
        x,
        w_s,
        x_s,
        bias,
        x_s_ub,
        self.config.out_dtype,
    )
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
    """Use torch_per_tensor_w8a8_scaled_mm as the scaled-mm implementation."""
    return torch_per_tensor_w8a8_scaled_mm
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@ -209,24 +194,8 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
return True, None
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Apply the FP8 matmul for ``layer`` via torch_row_wise_w8a8_scaled_mm.

    Fetches the layer's weight/scale params and delegates to the shared
    apply_weights_fp8 helper.
    """
    w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
    return apply_weights_fp8(
        torch_row_wise_w8a8_scaled_mm,
        self.quant_fp8,
        w,
        x,
        w_s,
        x_s,
        bias,
        x_s_ub,
        self.config.out_dtype,
    )
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
    """Use torch_row_wise_w8a8_scaled_mm as the scaled-mm implementation."""
    return torch_row_wise_w8a8_scaled_mm
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@ -245,21 +214,5 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
return True, None
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Apply the FP8 matmul for ``layer`` via torch_channelwise_w8a8_scaled_mm.

    Fetches the layer's weight/scale params and delegates to the shared
    apply_weights_fp8 helper.
    """
    w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
    return apply_weights_fp8(
        torch_channelwise_w8a8_scaled_mm,
        self.quant_fp8,
        w,
        x,
        w_s,
        x_s,
        bias,
        x_s_ub,
        self.config.out_dtype,
    )
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
    """Use torch_channelwise_w8a8_scaled_mm as the scaled-mm implementation."""
    return torch_channelwise_w8a8_scaled_mm

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import vllm.envs as envs
@ -12,7 +14,6 @@ from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
from .utils import apply_weights_fp8
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
@ -86,9 +87,6 @@ direct_register_custom_op(
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
def get_ouput_padding(self) -> int | None:
    """No token padding for the ROCm kernel."""
    return None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
# TODO: check if this causes an issue on non-ROCM platforms
@ -125,21 +123,8 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
)
return True, None
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Apply the FP8 matmul for ``layer`` via rocm_per_tensor_float_w8a8_scaled_mm.

    Fetches the layer's weight/scale params and delegates to the shared
    apply_weights_fp8 helper.
    """
    w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
    return apply_weights_fp8(
        rocm_per_tensor_float_w8a8_scaled_mm,
        self.quant_fp8,
        w,
        x,
        w_s,
        x_s,
        bias,
        x_s_ub,
        self.config.out_dtype,
    )
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
    """Use rocm_per_tensor_float_w8a8_scaled_mm as the scaled-mm implementation."""
    return rocm_per_tensor_float_w8a8_scaled_mm

def get_ouput_padding(self) -> int | None:
    """No token padding for the ROCm kernel."""
    return None

View File

@ -1,49 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.platforms import current_platform
def apply_weights_fp8(
    scaled_mm_func: Callable[..., torch.Tensor],
    quant_fp8_func: QuantFP8,
    w: torch.Tensor,
    x: torch.Tensor,
    w_s: torch.Tensor,
    x_s: torch.Tensor | None,
    bias: torch.Tensor | None,
    x_s_ub: torch.Tensor | None,
    maybe_out_dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Shared FP8 scaled-matmul helper.

    Flattens ``x`` to 2-D, quantizes it to the platform FP8 dtype when it
    is not already FP8, then calls ``scaled_mm_func`` with keyword args
    (A, B, out_dtype, As, Bs, bias, output_shape).

    Args:
        scaled_mm_func: kernel-specific scaled-mm implementation.
        quant_fp8_func: QuantFP8 callable used to quantize the activation.
        w: quantized weight; its second dim gives the output feature size.
        x: activation tensor; leading dims are flattened for the matmul.
        w_s: weight scale tensor.
        x_s: activation scale — None for dynamic quant (computed here),
            a scalar for static quant.
        bias: optional bias forwarded to ``scaled_mm_func``.
        x_s_ub: optional upper bound on the activation scale.
        maybe_out_dtype: output dtype; defaults to ``x.dtype`` when None.
    """
    # ops.scaled_fp8_quant supports both dynamic and static quant.
    # If dynamic, layer.input_scale is None and x_s computed from x.
    # If static, layer.input_scale is scalar and x_s is input_scale.
    # View input as 2D matrix for fp8 methods
    x_2d = x.view(-1, x.shape[-1])
    output_shape = [*x.shape[:-1], w.shape[1]]
    out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
    # If input not quantized
    # TODO(luka) remove this path if not used anymore
    x_2d_q = x_2d
    if x.dtype != current_platform.fp8_dtype():
        x_2d_q, x_s = quant_fp8_func(
            x_2d,
            x_s,
            x_s_ub,
        )
    return scaled_mm_func(
        A=x_2d_q,
        B=w,
        out_dtype=out_dtype,
        As=x_s,
        Bs=w_s,
        bias=bias,
        output_shape=output_shape,
    )