mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 21:51:24 +08:00
implement apply func in base FP8ScaledMMLinearKernel class
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
aaa0d55587
commit
7fb465744c
@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Generic, TypeVar
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
|||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
QuantKey,
|
QuantKey,
|
||||||
)
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -98,12 +99,9 @@ class FP8ScaledMMLinearKernel(
|
|||||||
group_shape=act_scale_descriptor.group_shape,
|
group_shape=act_scale_descriptor.group_shape,
|
||||||
num_token_padding=self.get_ouput_padding(),
|
num_token_padding=self.get_ouput_padding(),
|
||||||
)
|
)
|
||||||
|
self.fp8_dtype = current_platform.fp8_dtype()
|
||||||
super().__init__(c, layer_param_names)
|
super().__init__(c, layer_param_names)
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_ouput_padding(self) -> int | None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
# lovelace and up
|
# lovelace and up
|
||||||
@ -121,6 +119,53 @@ class FP8ScaledMMLinearKernel(
|
|||||||
getattr(layer, x_s_ub),
|
getattr(layer, x_s_ub),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
scaled_mm_func = self.get_scaled_mm_func()
|
||||||
|
quant_fp8 = self.quant_fp8
|
||||||
|
fp8_dtype = self.fp8_dtype
|
||||||
|
maybe_out_dtype = self.config.out_dtype
|
||||||
|
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||||
|
|
||||||
|
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||||
|
# If dynamic, layer.input_scale is None and x_s computed from x.
|
||||||
|
# If static, layer.input_scale is scalar and x_s is input_scale.
|
||||||
|
# View input as 2D matrix for fp8 methods
|
||||||
|
x_2d = x.view(-1, x.shape[-1])
|
||||||
|
output_shape = [*x.shape[:-1], w.shape[1]]
|
||||||
|
out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
|
||||||
|
|
||||||
|
# If input not quantized
|
||||||
|
# TODO(luka) remove this path if not used anymore
|
||||||
|
x_2d_q = x_2d
|
||||||
|
if x.dtype != fp8_dtype:
|
||||||
|
x_2d_q, x_s = quant_fp8(
|
||||||
|
x_2d,
|
||||||
|
x_s,
|
||||||
|
x_s_ub,
|
||||||
|
)
|
||||||
|
return scaled_mm_func(
|
||||||
|
A=x_2d_q,
|
||||||
|
B=w,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
As=x_s,
|
||||||
|
Bs=w_s,
|
||||||
|
bias=bias,
|
||||||
|
output_shape=output_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_ouput_padding(self) -> int | None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class Int8ScaledMMLinearKernel(
|
class Int8ScaledMMLinearKernel(
|
||||||
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC
|
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC
|
||||||
|
|||||||
@ -2,6 +2,8 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
@ -17,7 +19,6 @@ from .ScaledMMLinearKernel import (
|
|||||||
Int8ScaledMMLinearKernel,
|
Int8ScaledMMLinearKernel,
|
||||||
Int8ScaledMMLinearLayerConfig,
|
Int8ScaledMMLinearLayerConfig,
|
||||||
)
|
)
|
||||||
from .utils import apply_weights_fp8
|
|
||||||
|
|
||||||
|
|
||||||
def cutlass_w8a8_scaled_mm_fp8(
|
def cutlass_w8a8_scaled_mm_fp8(
|
||||||
@ -160,9 +161,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
|||||||
|
|
||||||
|
|
||||||
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||||
def get_ouput_padding(self) -> int | None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
if not current_platform.is_cuda():
|
if not current_platform.is_cuda():
|
||||||
@ -174,21 +172,8 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
|
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def apply_weights(
|
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||||
self,
|
return cutlass_w8a8_scaled_mm_fp8
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
def get_ouput_padding(self) -> int | None:
|
||||||
bias: torch.Tensor | None = None,
|
return None
|
||||||
):
|
|
||||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
|
||||||
return apply_weights_fp8(
|
|
||||||
cutlass_w8a8_scaled_mm_fp8,
|
|
||||||
self.quant_fp8,
|
|
||||||
w,
|
|
||||||
x,
|
|
||||||
w_s,
|
|
||||||
x_s,
|
|
||||||
bias,
|
|
||||||
x_s_ub,
|
|
||||||
self.config.out_dtype,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -10,7 +12,6 @@ from .ScaledMMLinearKernel import (
|
|||||||
FP8ScaledMMLinearKernel,
|
FP8ScaledMMLinearKernel,
|
||||||
FP8ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
)
|
)
|
||||||
from .utils import apply_weights_fp8
|
|
||||||
|
|
||||||
|
|
||||||
def flashinfer_w8a8_scaled_mm(
|
def flashinfer_w8a8_scaled_mm(
|
||||||
@ -29,13 +30,6 @@ def flashinfer_w8a8_scaled_mm(
|
|||||||
|
|
||||||
|
|
||||||
class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||||
def get_ouput_padding(self) -> int | None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_min_capability(cls) -> int:
|
|
||||||
return 100
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
per_tensor_activation_scales = (
|
per_tensor_activation_scales = (
|
||||||
@ -71,21 +65,12 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
)
|
)
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def apply_weights(
|
@classmethod
|
||||||
self,
|
def get_min_capability(cls) -> int:
|
||||||
layer: torch.nn.Module,
|
return 100
|
||||||
x: torch.Tensor,
|
|
||||||
bias: torch.Tensor | None = None,
|
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||||
):
|
return flashinfer_w8a8_scaled_mm
|
||||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
|
||||||
return apply_weights_fp8(
|
def get_ouput_padding(self) -> int | None:
|
||||||
flashinfer_w8a8_scaled_mm,
|
return None
|
||||||
self.quant_fp8,
|
|
||||||
w,
|
|
||||||
x,
|
|
||||||
w_s,
|
|
||||||
x_s,
|
|
||||||
bias,
|
|
||||||
x_s_ub,
|
|
||||||
self.config.out_dtype,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
@ -11,7 +13,6 @@ from .ScaledMMLinearKernel import (
|
|||||||
FP8ScaledMMLinearKernel,
|
FP8ScaledMMLinearKernel,
|
||||||
FP8ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
)
|
)
|
||||||
from .utils import apply_weights_fp8
|
|
||||||
|
|
||||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||||
@ -155,24 +156,8 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
)
|
)
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def apply_weights(
|
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||||
self,
|
return torch_per_tensor_w8a8_scaled_mm
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
bias: torch.Tensor | None = None,
|
|
||||||
):
|
|
||||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
|
||||||
return apply_weights_fp8(
|
|
||||||
torch_per_tensor_w8a8_scaled_mm,
|
|
||||||
self.quant_fp8,
|
|
||||||
w,
|
|
||||||
x,
|
|
||||||
w_s,
|
|
||||||
x_s,
|
|
||||||
bias,
|
|
||||||
x_s_ub,
|
|
||||||
self.config.out_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||||
@ -209,24 +194,8 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
|
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def apply_weights(
|
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||||
self,
|
return torch_row_wise_w8a8_scaled_mm
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
bias: torch.Tensor | None = None,
|
|
||||||
):
|
|
||||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
|
||||||
return apply_weights_fp8(
|
|
||||||
torch_row_wise_w8a8_scaled_mm,
|
|
||||||
self.quant_fp8,
|
|
||||||
w,
|
|
||||||
x,
|
|
||||||
w_s,
|
|
||||||
x_s,
|
|
||||||
bias,
|
|
||||||
x_s_ub,
|
|
||||||
self.config.out_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||||
@ -245,21 +214,5 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
|
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def apply_weights(
|
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||||
self,
|
return torch_channelwise_w8a8_scaled_mm
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
bias: torch.Tensor | None = None,
|
|
||||||
):
|
|
||||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
|
||||||
return apply_weights_fp8(
|
|
||||||
torch_channelwise_w8a8_scaled_mm,
|
|
||||||
self.quant_fp8,
|
|
||||||
w,
|
|
||||||
x,
|
|
||||||
w_s,
|
|
||||||
x_s,
|
|
||||||
bias,
|
|
||||||
x_s_ub,
|
|
||||||
self.config.out_dtype,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
@ -12,7 +14,6 @@ from .ScaledMMLinearKernel import (
|
|||||||
FP8ScaledMMLinearKernel,
|
FP8ScaledMMLinearKernel,
|
||||||
FP8ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
)
|
)
|
||||||
from .utils import apply_weights_fp8
|
|
||||||
|
|
||||||
|
|
||||||
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||||
@ -86,9 +87,6 @@ direct_register_custom_op(
|
|||||||
|
|
||||||
|
|
||||||
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||||
def get_ouput_padding(self) -> int | None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
# TODO: check if this causes an issue on non-ROCM platforms
|
# TODO: check if this causes an issue on non-ROCM platforms
|
||||||
@ -125,21 +123,8 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
)
|
)
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def apply_weights(
|
def get_scaled_mm_func(self) -> Callable[..., torch.Tensor]:
|
||||||
self,
|
return rocm_per_tensor_float_w8a8_scaled_mm
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
def get_ouput_padding(self) -> int | None:
|
||||||
bias: torch.Tensor | None = None,
|
return None
|
||||||
):
|
|
||||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
|
||||||
return apply_weights_fp8(
|
|
||||||
rocm_per_tensor_float_w8a8_scaled_mm,
|
|
||||||
self.quant_fp8,
|
|
||||||
w,
|
|
||||||
x,
|
|
||||||
w_s,
|
|
||||||
x_s,
|
|
||||||
bias,
|
|
||||||
x_s_ub,
|
|
||||||
self.config.out_dtype,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,49 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
|
|
||||||
def apply_weights_fp8(
|
|
||||||
scaled_mm_func: Callable[..., torch.Tensor],
|
|
||||||
quant_fp8_func: QuantFP8,
|
|
||||||
w: torch.Tensor,
|
|
||||||
x: torch.Tensor,
|
|
||||||
w_s: torch.Tensor,
|
|
||||||
x_s: torch.Tensor,
|
|
||||||
bias: torch.Tensor,
|
|
||||||
x_s_ub: torch.Tensor | None,
|
|
||||||
maybe_out_dtype: torch.dtype | None = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
|
||||||
# If dynamic, layer.input_scale is None and x_s computed from x.
|
|
||||||
# If static, layer.input_scale is scalar and x_s is input_scale.
|
|
||||||
# View input as 2D matrix for fp8 methods
|
|
||||||
x_2d = x.view(-1, x.shape[-1])
|
|
||||||
output_shape = [*x.shape[:-1], w.shape[1]]
|
|
||||||
|
|
||||||
out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
|
|
||||||
|
|
||||||
# If input not quantized
|
|
||||||
# TODO(luka) remove this path if not used anymore
|
|
||||||
x_2d_q = x_2d
|
|
||||||
if x.dtype != current_platform.fp8_dtype():
|
|
||||||
x_2d_q, x_s = quant_fp8_func(
|
|
||||||
x_2d,
|
|
||||||
x_s,
|
|
||||||
x_s_ub,
|
|
||||||
)
|
|
||||||
|
|
||||||
return scaled_mm_func(
|
|
||||||
A=x_2d_q,
|
|
||||||
B=w,
|
|
||||||
out_dtype=out_dtype,
|
|
||||||
As=x_s,
|
|
||||||
Bs=w_s,
|
|
||||||
bias=bias,
|
|
||||||
output_shape=output_shape,
|
|
||||||
)
|
|
||||||
Loading…
x
Reference in New Issue
Block a user