implement apply func in base FP8ScaledMMLinearKernel class

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-11-07 07:17:41 +00:00
parent aaa0d55587
commit 7fb465744c
6 changed files with 83 additions and 179 deletions

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Callable, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, TypeVar from typing import Generic, TypeVar
@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
) )
from vllm.platforms import current_platform
@dataclass @dataclass
@ -98,12 +99,9 @@ class FP8ScaledMMLinearKernel(
group_shape=act_scale_descriptor.group_shape, group_shape=act_scale_descriptor.group_shape,
num_token_padding=self.get_ouput_padding(), num_token_padding=self.get_ouput_padding(),
) )
self.fp8_dtype = current_platform.fp8_dtype()
super().__init__(c, layer_param_names) super().__init__(c, layer_param_names)
@abstractmethod
def get_ouput_padding(self) -> int | None:
raise NotImplementedError
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
# lovelace and up # lovelace and up
@ -121,6 +119,53 @@ class FP8ScaledMMLinearKernel(
getattr(layer, x_s_ub), getattr(layer, x_s_ub),
) )
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
scaled_mm_func = self.get_scaled_mm_func()
quant_fp8 = self.quant_fp8
fp8_dtype = self.fp8_dtype
maybe_out_dtype = self.config.out_dtype
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_s computed from x.
# If static, layer.input_scale is scalar and x_s is input_scale.
# View input as 2D matrix for fp8 methods
x_2d = x.view(-1, x.shape[-1])
output_shape = [*x.shape[:-1], w.shape[1]]
out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
# If input not quantized
# TODO(luka) remove this path if not used anymore
x_2d_q = x_2d
if x.dtype != fp8_dtype:
x_2d_q, x_s = quant_fp8(
x_2d,
x_s,
x_s_ub,
)
return scaled_mm_func(
A=x_2d_q,
B=w,
out_dtype=out_dtype,
As=x_s,
Bs=w_s,
bias=bias,
output_shape=output_shape,
)
@abstractmethod
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
raise NotImplementedError
@abstractmethod
def get_ouput_padding(self) -> int | None:
raise NotImplementedError
class Int8ScaledMMLinearKernel( class Int8ScaledMMLinearKernel(
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC

View File

@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
@ -17,7 +19,6 @@ from .ScaledMMLinearKernel import (
Int8ScaledMMLinearKernel, Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig, Int8ScaledMMLinearLayerConfig,
) )
from .utils import apply_weights_fp8
def cutlass_w8a8_scaled_mm_fp8( def cutlass_w8a8_scaled_mm_fp8(
@ -160,9 +161,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
def get_ouput_padding(self) -> int | None:
return None
@classmethod @classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cuda(): if not current_platform.is_cuda():
@ -174,21 +172,8 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
return True, None return True, None
def apply_weights( def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
self, return cutlass_w8a8_scaled_mm_fp8
layer: torch.nn.Module,
x: torch.Tensor, def get_ouput_padding(self) -> int | None:
bias: torch.Tensor | None = None, return None
):
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
return apply_weights_fp8(
cutlass_w8a8_scaled_mm_fp8,
self.quant_fp8,
w,
x,
w_s,
x_s,
bias,
x_s_ub,
self.config.out_dtype,
)

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -10,7 +12,6 @@ from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel, FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
) )
from .utils import apply_weights_fp8
def flashinfer_w8a8_scaled_mm( def flashinfer_w8a8_scaled_mm(
@ -29,13 +30,6 @@ def flashinfer_w8a8_scaled_mm(
class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel): class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
def get_ouput_padding(self) -> int | None:
return None
@classmethod
def get_min_capability(cls) -> int:
return 100
@classmethod @classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = ( per_tensor_activation_scales = (
@ -71,21 +65,12 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
) )
return True, None return True, None
def apply_weights( @classmethod
self, def get_min_capability(cls) -> int:
layer: torch.nn.Module, return 100
x: torch.Tensor,
bias: torch.Tensor | None = None, def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
): return flashinfer_w8a8_scaled_mm
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
return apply_weights_fp8( def get_ouput_padding(self) -> int | None:
flashinfer_w8a8_scaled_mm, return None
self.quant_fp8,
w,
x,
w_s,
x_s,
bias,
x_s_ub,
self.config.out_dtype,
)

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
from packaging import version from packaging import version
@ -11,7 +13,6 @@ from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel, FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
) )
from .utils import apply_weights_fp8
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@ -155,24 +156,8 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
) )
return True, None return True, None
def apply_weights( def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
self, return torch_per_tensor_w8a8_scaled_mm
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
return apply_weights_fp8(
torch_per_tensor_w8a8_scaled_mm,
self.quant_fp8,
w,
x,
w_s,
x_s,
bias,
x_s_ub,
self.config.out_dtype,
)
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@ -209,24 +194,8 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
return True, None return True, None
def apply_weights( def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
self, return torch_row_wise_w8a8_scaled_mm
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
return apply_weights_fp8(
torch_row_wise_w8a8_scaled_mm,
self.quant_fp8,
w,
x,
w_s,
x_s,
bias,
x_s_ub,
self.config.out_dtype,
)
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@ -245,21 +214,5 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
return True, None return True, None
def apply_weights( def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
self, return torch_channelwise_w8a8_scaled_mm
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
return apply_weights_fp8(
torch_channelwise_w8a8_scaled_mm,
self.quant_fp8,
w,
x,
w_s,
x_s,
bias,
x_s_ub,
self.config.out_dtype,
)

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
import vllm.envs as envs import vllm.envs as envs
@ -12,7 +14,6 @@ from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel, FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
) )
from .utils import apply_weights_fp8
def rocm_per_tensor_float_w8a8_scaled_mm_impl( def rocm_per_tensor_float_w8a8_scaled_mm_impl(
@ -86,9 +87,6 @@ direct_register_custom_op(
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
def get_ouput_padding(self) -> int | None:
return None
@classmethod @classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
# TODO: check if this causes an issue on non-ROCM platforms # TODO: check if this causes an issue on non-ROCM platforms
@ -125,21 +123,8 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
) )
return True, None return True, None
def apply_weights( def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
self, return rocm_per_tensor_float_w8a8_scaled_mm
layer: torch.nn.Module,
x: torch.Tensor, def get_ouput_padding(self) -> int | None:
bias: torch.Tensor | None = None, return None
):
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
return apply_weights_fp8(
rocm_per_tensor_float_w8a8_scaled_mm,
self.quant_fp8,
w,
x,
w_s,
x_s,
bias,
x_s_ub,
self.config.out_dtype,
)

View File

@ -1,49 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.platforms import current_platform
def apply_weights_fp8(
scaled_mm_func: Callable[..., torch.Tensor],
quant_fp8_func: QuantFP8,
w: torch.Tensor,
x: torch.Tensor,
w_s: torch.Tensor,
x_s: torch.Tensor,
bias: torch.Tensor,
x_s_ub: torch.Tensor | None,
maybe_out_dtype: torch.dtype | None = None,
) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_s computed from x.
# If static, layer.input_scale is scalar and x_s is input_scale.
# View input as 2D matrix for fp8 methods
x_2d = x.view(-1, x.shape[-1])
output_shape = [*x.shape[:-1], w.shape[1]]
out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
# If input not quantized
# TODO(luka) remove this path if not used anymore
x_2d_q = x_2d
if x.dtype != current_platform.fp8_dtype():
x_2d_q, x_s = quant_fp8_func(
x_2d,
x_s,
x_s_ub,
)
return scaled_mm_func(
A=x_2d_q,
B=w,
out_dtype=out_dtype,
As=x_s,
Bs=w_s,
bias=bias,
output_shape=output_shape,
)