mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 06:37:04 +08:00
implement apply func in base FP8ScaledMMLinearKernel class
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
aaa0d55587
commit
7fb465744c
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -98,12 +99,9 @@ class FP8ScaledMMLinearKernel(
|
||||
group_shape=act_scale_descriptor.group_shape,
|
||||
num_token_padding=self.get_ouput_padding(),
|
||||
)
|
||||
self.fp8_dtype = current_platform.fp8_dtype()
|
||||
super().__init__(c, layer_param_names)
|
||||
|
||||
@abstractmethod
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
@ -121,6 +119,53 @@ class FP8ScaledMMLinearKernel(
|
||||
getattr(layer, x_s_ub),
|
||||
)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
scaled_mm_func = self.get_scaled_mm_func()
|
||||
quant_fp8 = self.quant_fp8
|
||||
fp8_dtype = self.fp8_dtype
|
||||
maybe_out_dtype = self.config.out_dtype
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_s computed from x.
|
||||
# If static, layer.input_scale is scalar and x_s is input_scale.
|
||||
# View input as 2D matrix for fp8 methods
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
output_shape = [*x.shape[:-1], w.shape[1]]
|
||||
out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
|
||||
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
x_2d_q = x_2d
|
||||
if x.dtype != fp8_dtype:
|
||||
x_2d_q, x_s = quant_fp8(
|
||||
x_2d,
|
||||
x_s,
|
||||
x_s_ub,
|
||||
)
|
||||
return scaled_mm_func(
|
||||
A=x_2d_q,
|
||||
B=w,
|
||||
out_dtype=out_dtype,
|
||||
As=x_s,
|
||||
Bs=w_s,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Int8ScaledMMLinearKernel(
|
||||
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
@ -17,7 +19,6 @@ from .ScaledMMLinearKernel import (
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from .utils import apply_weights_fp8
|
||||
|
||||
|
||||
def cutlass_w8a8_scaled_mm_fp8(
|
||||
@ -160,9 +161,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
|
||||
|
||||
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
@ -174,21 +172,8 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
|
||||
return True, None
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
cutlass_w8a8_scaled_mm_fp8,
|
||||
self.quant_fp8,
|
||||
w,
|
||||
x,
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
x_s_ub,
|
||||
self.config.out_dtype,
|
||||
)
|
||||
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||
return cutlass_w8a8_scaled_mm_fp8
|
||||
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
return None
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
@ -10,7 +12,6 @@ from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from .utils import apply_weights_fp8
|
||||
|
||||
|
||||
def flashinfer_w8a8_scaled_mm(
|
||||
@ -29,13 +30,6 @@ def flashinfer_w8a8_scaled_mm(
|
||||
|
||||
|
||||
class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 100
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = (
|
||||
@ -71,21 +65,12 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
)
|
||||
return True, None
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
flashinfer_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
w,
|
||||
x,
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
x_s_ub,
|
||||
self.config.out_dtype,
|
||||
)
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 100
|
||||
|
||||
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||
return flashinfer_w8a8_scaled_mm
|
||||
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
return None
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
@ -11,7 +13,6 @@ from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from .utils import apply_weights_fp8
|
||||
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||
@ -155,24 +156,8 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
)
|
||||
return True, None
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
torch_per_tensor_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
w,
|
||||
x,
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
x_s_ub,
|
||||
self.config.out_dtype,
|
||||
)
|
||||
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||
return torch_per_tensor_w8a8_scaled_mm
|
||||
|
||||
|
||||
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
@ -209,24 +194,8 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
|
||||
return True, None
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
torch_row_wise_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
w,
|
||||
x,
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
x_s_ub,
|
||||
self.config.out_dtype,
|
||||
)
|
||||
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||
return torch_row_wise_w8a8_scaled_mm
|
||||
|
||||
|
||||
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
@ -245,21 +214,5 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
|
||||
return True, None
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
torch_channelwise_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
w,
|
||||
x,
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
x_s_ub,
|
||||
self.config.out_dtype,
|
||||
)
|
||||
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||
return torch_channelwise_w8a8_scaled_mm
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
@ -12,7 +14,6 @@ from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from .utils import apply_weights_fp8
|
||||
|
||||
|
||||
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||
@ -86,9 +87,6 @@ direct_register_custom_op(
|
||||
|
||||
|
||||
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
# TODO: check if this causes an issue on non-ROCM platforms
|
||||
@ -125,21 +123,8 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
)
|
||||
return True, None
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
rocm_per_tensor_float_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
w,
|
||||
x,
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
x_s_ub,
|
||||
self.config.out_dtype,
|
||||
)
|
||||
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||
return rocm_per_tensor_float_w8a8_scaled_mm
|
||||
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
return None
|
||||
|
||||
@ -1,49 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def apply_weights_fp8(
|
||||
scaled_mm_func: Callable[..., torch.Tensor],
|
||||
quant_fp8_func: QuantFP8,
|
||||
w: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_s: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
x_s_ub: torch.Tensor | None,
|
||||
maybe_out_dtype: torch.dtype | None = None,
|
||||
) -> torch.Tensor:
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_s computed from x.
|
||||
# If static, layer.input_scale is scalar and x_s is input_scale.
|
||||
# View input as 2D matrix for fp8 methods
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
output_shape = [*x.shape[:-1], w.shape[1]]
|
||||
|
||||
out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
|
||||
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
x_2d_q = x_2d
|
||||
if x.dtype != current_platform.fp8_dtype():
|
||||
x_2d_q, x_s = quant_fp8_func(
|
||||
x_2d,
|
||||
x_s,
|
||||
x_s_ub,
|
||||
)
|
||||
|
||||
return scaled_mm_func(
|
||||
A=x_2d_q,
|
||||
B=w,
|
||||
out_dtype=out_dtype,
|
||||
As=x_s,
|
||||
Bs=w_s,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
)
|
||||
Loading…
x
Reference in New Issue
Block a user