mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 14:37:03 +08:00
reduce kernel init boilerplate
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
423e2a625e
commit
dd001064c0
@ -12,12 +12,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
|||||||
CompressedTensorsScheme,
|
CompressedTensorsScheme,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||||
_POSSIBLE_FP8_KERNELS,
|
init_fp8_linear_kernel,
|
||||||
choose_scaled_mm_linear_kernel,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||||
QUANT_STRATEGY_MAP,
|
QUANT_STRATEGY_MAP,
|
||||||
FP8ScaledMMLinearLayerConfig,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
W8A8BlockFp8LinearOp,
|
W8A8BlockFp8LinearOp,
|
||||||
@ -82,22 +80,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_param_names = ["weight", "weight_scale", "input_scale"]
|
|
||||||
weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy]
|
weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy]
|
||||||
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
self.fp8_linear_kernel = init_fp8_linear_kernel(
|
||||||
is_static_input_scheme=self.is_static_input_scheme,
|
is_static_input_scheme=self.is_static_input_scheme,
|
||||||
weight_quant_strategy=weight_quant_strategy,
|
weight_quant_strategy=weight_quant_strategy,
|
||||||
activation_group_shape=self.act_q_group_shape,
|
activation_group_shape=self.act_q_group_shape,
|
||||||
out_dtype=self.out_dtype,
|
out_dtype=self.out_dtype,
|
||||||
)
|
)
|
||||||
kernel_type = choose_scaled_mm_linear_kernel(
|
|
||||||
scaled_mm_linear_kernel_config,
|
|
||||||
_POSSIBLE_FP8_KERNELS,
|
|
||||||
module_name=self.__class__.__name__,
|
|
||||||
)
|
|
||||||
self.kernel = kernel_type(
|
|
||||||
scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
@ -212,4 +201,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.kernel.apply_weights(layer, x, bias)
|
return self.fp8_linear_kernel.apply_weights(layer, x, bias)
|
||||||
|
|||||||
@ -43,8 +43,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||||
_POSSIBLE_FP8_KERNELS,
|
init_fp8_linear_kernel,
|
||||||
choose_scaled_mm_linear_kernel,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501
|
||||||
FP8ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
@ -394,22 +393,12 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
self.fp8_linear_kernel = init_fp8_linear_kernel(
|
||||||
is_static_input_scheme=self.act_q_static,
|
is_static_input_scheme=self.act_q_static,
|
||||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||||
activation_group_shape=self.act_q_group_shape,
|
activation_group_shape=self.act_q_group_shape,
|
||||||
out_dtype=self.out_dtype,
|
out_dtype=self.out_dtype,
|
||||||
)
|
)
|
||||||
kernel_type = choose_scaled_mm_linear_kernel(
|
|
||||||
scaled_mm_linear_kernel_config,
|
|
||||||
_POSSIBLE_FP8_KERNELS,
|
|
||||||
module_name=self.__class__.__name__,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.fp8_linear_kernel = kernel_type(
|
|
||||||
scaled_mm_linear_kernel_config,
|
|
||||||
layer_param_names=["weight", "weight_scale", "input_scale"],
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -3,6 +3,8 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
|
||||||
AiterScaledMMLinearKernel,
|
AiterScaledMMLinearKernel,
|
||||||
@ -17,8 +19,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
|
|||||||
ROCmScaledMMLinearKernel,
|
ROCmScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||||
|
FP8ScaledMMLinearKernel,
|
||||||
|
FP8ScaledMMLinearLayerConfig,
|
||||||
ScaledMMLinearKernel,
|
ScaledMMLinearKernel,
|
||||||
ScaledMMLinearLayerConfig,
|
ScaledMMLinearLayerConfig,
|
||||||
|
ScaledMMLinearQuantStrategy,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
|
||||||
ChannelWiseTorchScaledMMLinearKernel,
|
ChannelWiseTorchScaledMMLinearKernel,
|
||||||
@ -32,6 +37,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
|
|||||||
XLAScaledMMLinearKernel,
|
XLAScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
from vllm.platforms import PlatformEnum, current_platform
|
from vllm.platforms import PlatformEnum, current_platform
|
||||||
|
from vllm.vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -122,3 +128,28 @@ def choose_scaled_mm_linear_kernel(
|
|||||||
"Failed to find a kernel that can implement the "
|
"Failed to find a kernel that can implement the "
|
||||||
"ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reasons)
|
"ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reasons)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def init_fp8_linear_kernel(
|
||||||
|
act_q_static: bool,
|
||||||
|
act_q_group_shape: GroupShape,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
module_name: str,
|
||||||
|
) -> FP8ScaledMMLinearKernel:
|
||||||
|
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
||||||
|
is_static_input_scheme=act_q_static,
|
||||||
|
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||||
|
activation_group_shape=act_q_group_shape,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
kernel_type = choose_scaled_mm_linear_kernel(
|
||||||
|
scaled_mm_linear_kernel_config,
|
||||||
|
_POSSIBLE_FP8_KERNELS,
|
||||||
|
module_name=module_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
return kernel_type(
|
||||||
|
scaled_mm_linear_kernel_config,
|
||||||
|
layer_param_names=["weight", "weight_scale", "input_scale"],
|
||||||
|
)
|
||||||
|
|||||||
@ -9,11 +9,9 @@ from torch.nn import Parameter
|
|||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||||
_POSSIBLE_FP8_KERNELS,
|
init_fp8_linear_kernel,
|
||||||
choose_scaled_mm_linear_kernel,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||||
FP8ScaledMMLinearLayerConfig,
|
|
||||||
ScaledMMLinearQuantStrategy,
|
ScaledMMLinearQuantStrategy,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||||
@ -174,24 +172,13 @@ class QuarkW8A8Fp8(QuarkScheme):
|
|||||||
input_scale[:] = torch.finfo(torch.float32).min
|
input_scale[:] = torch.finfo(torch.float32).min
|
||||||
layer.register_parameter("input_scale", input_scale)
|
layer.register_parameter("input_scale", input_scale)
|
||||||
|
|
||||||
layer_param_names = ["weight", "weight_scale", "input_scale"]
|
|
||||||
weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme]
|
weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme]
|
||||||
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
self.fp8_linear_kernel = init_fp8_linear_kernel(
|
||||||
is_static_input_scheme=self.is_static_input_scheme,
|
is_static_input_scheme=self.is_static_input_scheme,
|
||||||
weight_quant_strategy=weight_quant_strategy,
|
weight_quant_strategy=weight_quant_strategy,
|
||||||
activation_group_shape=self.act_quant_group_shape,
|
activation_group_shape=self.act_quant_group_shape,
|
||||||
out_dtype=self.out_dtype,
|
out_dtype=self.out_dtype,
|
||||||
)
|
)
|
||||||
kernel_type = choose_scaled_mm_linear_kernel(
|
|
||||||
scaled_mm_linear_kernel_config,
|
|
||||||
_POSSIBLE_FP8_KERNELS,
|
|
||||||
module_name=self.__class__.__name__,
|
|
||||||
)
|
|
||||||
|
|
||||||
layer_param_names = ["weight", "weight_scale", "input_scale"]
|
|
||||||
self.kernel = kernel_type(
|
|
||||||
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
|
|
||||||
)
|
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
self,
|
self,
|
||||||
@ -199,4 +186,4 @@ class QuarkW8A8Fp8(QuarkScheme):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.kernel.apply_weights(layer, x, bias)
|
return self.fp8_linear_kernel.apply_weights(layer, x, bias)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user