reduce logging boilerplate; update fp8 path

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
vllmellm 2025-10-31 14:07:09 +00:00
parent 38825fce0f
commit 423e2a625e
7 changed files with 41 additions and 43 deletions
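
Note on the pattern repeated across the schemes below: each scheme previously kept a class-level _kernel_backends_being_used set purely so the selected kernel would be logged only once. This commit deletes that per-scheme bookkeeping and has choose_scaled_mm_linear_kernel log the selection itself via logger.info_once, keyed by the new module_name argument. A minimal standalone sketch of the deduplication idea, using the stdlib logging module rather than vLLM's logger (the info_once helper here is illustrative, not vLLM's implementation):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("scaled_mm")

_seen: set[tuple] = set()

def info_once(msg: str, *args) -> None:
    # Emit each distinct (msg, args) pair once; later duplicates are dropped.
    key = (msg, args)
    if key not in _seen:
        _seen.add(key)
        logger.info(msg, *args)

# Every linear layer of a model triggers kernel selection, but the log line
# appears only once:
for _ in range(3):
    info_once("Selected %s for %s", "CutlassScaledMMLinearKernel", "Fp8LinearMethod")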

View File

@@ -54,8 +54,6 @@ logger = init_logger(__name__)

 class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
-    _kernel_backends_being_used: set[str] = set()
-
     def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
         self.weight_quant = weight_quant
         self.strategy = weight_quant.strategy
@@ -95,14 +93,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         kernel_type = choose_scaled_mm_linear_kernel(
             scaled_mm_linear_kernel_config,
             _POSSIBLE_FP8_KERNELS,
+            module_name=self.__class__.__name__,
         )
-        if kernel_type.__name__ not in self._kernel_backends_being_used:
-            logger.info(
-                "Using %s for CompressedTensorsW8A8FP8", kernel_type.__name__
-            )
-            self._kernel_backends_being_used.add(kernel_type.__name__)
-
         self.kernel = kernel_type(
             scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
         )

View File

@@ -28,8 +28,6 @@ logger = init_logger(__name__)

 class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
-    _kernel_backends_being_used: set[str] = set()
-
     def __init__(
         self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
     ):
@@ -60,13 +58,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         )
         kernel_type = choose_scaled_mm_linear_kernel(
-            scaled_mm_linear_kernel_config, _POSSIBLE_INT8_KERNELS
+            scaled_mm_linear_kernel_config,
+            _POSSIBLE_INT8_KERNELS,
+            module_name=self.__class__.__name__,
         )
-        if kernel_type.__name__ not in self._kernel_backends_being_used:
-            logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__)
-            self._kernel_backends_being_used.add(kernel_type.__name__)
-
         # WEIGHT
         weight = ModelWeightParameter(
             data=torch.empty(

View File

@@ -42,6 +42,14 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
+from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
+    _POSSIBLE_FP8_KERNELS,
+    choose_scaled_mm_linear_kernel,
+)
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
+    FP8ScaledMMLinearLayerConfig,
+    ScaledMMLinearQuantStrategy,
+)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     FlashinferMoeBackend,
@@ -77,7 +85,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped,
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp,
     all_close_1d,
     cutlass_block_fp8_supported,
     cutlass_fp8_supported,
@@ -387,9 +394,21 @@ class Fp8LinearMethod(LinearMethodBase):
                 use_aiter_and_is_supported=self.use_aiter_and_is_supported,
             )
         else:
-            self.fp8_linear = Fp8LinearOp(
-                act_quant_static=self.act_q_static,
-                act_quant_group_shape=self.act_q_group_shape,
+            scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
+                is_static_input_scheme=self.act_q_static,
+                weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
+                activation_group_shape=self.act_q_group_shape,
+                out_dtype=self.out_dtype,
             )
+            kernel_type = choose_scaled_mm_linear_kernel(
+                scaled_mm_linear_kernel_config,
+                _POSSIBLE_FP8_KERNELS,
+                module_name=self.__class__.__name__,
+            )
+            self.fp8_linear_kernel = kernel_type(
+                scaled_mm_linear_kernel_config,
+                layer_param_names=["weight", "weight_scale", "input_scale"],
+            )

     def create_weights(
@@ -674,14 +693,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 bias=bias,
             )

-        return self.fp8_linear.apply(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            out_dtype=self.out_dtype,
-            input_scale=layer.input_scale,
-            bias=bias,
-        )
+        return self.fp8_linear_kernel.apply_weights(layer, x, bias)


 class Fp8MoEMethod(FusedMoEMethodBase):
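
In short, the non-blockwise FP8 linear path no longer goes through Fp8LinearOp: Fp8LinearMethod now builds an FP8ScaledMMLinearLayerConfig, picks a kernel class with choose_scaled_mm_linear_kernel, instantiates it with the layer's parameter names, and dispatches through apply_weights(layer, x, bias). A self-contained sketch of that wiring; FP8Config and DemoKernel are illustrative stand-ins, not vLLM classes:

from dataclasses import dataclass

import torch

@dataclass
class FP8Config:  # stand-in for FP8ScaledMMLinearLayerConfig
    is_static_input_scheme: bool
    out_dtype: torch.dtype

class DemoKernel:  # stand-in for a ScaledMMLinearKernel subclass
    def __init__(self, config: FP8Config, layer_param_names: list[str]):
        self.config = config
        self.param_names = layer_param_names

    def apply_weights(self, layer, x: torch.Tensor, bias=None) -> torch.Tensor:
        # Real kernels run a scaled FP8 matmul; this stand-in is a plain matmul.
        w = getattr(layer, self.param_names[0])
        out = x @ w.t()
        return out + bias if bias is not None else out

def choose_kernel(config, kernels, module_name: str):
    # The real chooser filters by platform and can_implement(); take the first here.
    return kernels[0]

# Mirrors Fp8LinearMethod.__init__ / .apply after this commit:
config = FP8Config(is_static_input_scheme=True, out_dtype=torch.bfloat16)
kernel_cls = choose_kernel(config, [DemoKernel], module_name="Fp8LinearMethod")
fp8_linear_kernel = kernel_cls(
    config, layer_param_names=["weight", "weight_scale", "input_scale"]
)

class Layer:
    pass

layer = Layer()
layer.weight = torch.randn(8, 16)
print(fp8_linear_kernel.apply_weights(layer, torch.randn(2, 16)).shape)  # (2, 8)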

View File

@@ -3,6 +3,7 @@
 import os

+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
     AiterScaledMMLinearKernel,
 )
@@ -32,6 +33,8 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
 )
 from vllm.platforms import PlatformEnum, current_platform

+logger = init_logger(__name__)
+
 # in priority/performance order (when available)
 _POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
     PlatformEnum.CPU: [CPUScaledMMLinearKernel],
@@ -54,6 +57,7 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
 def choose_scaled_mm_linear_kernel(
     config: ScaledMMLinearLayerConfig,
     possible_kernels: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]],
+    module_name: str,
     compute_capability: int | None = None,
 ) -> type[ScaledMMLinearKernel]:
     """
@@ -105,6 +109,9 @@ def choose_scaled_mm_linear_kernel(
         can_implement, failure_reason = kernel.can_implement(config)

         if can_implement:
+            logger.info_once(
+                "Selected %s for %s", kernel.__name__, module_name, scope="global"
+            )
             return kernel
         else:
             failure_reasons.append(
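
The chooser walks the platform's kernel list in priority order and returns the first kernel whose can_implement(config) check passes, logging the winner once per module; every rejected kernel contributes a failure reason instead. A rough sketch of the selection loop's shape (platform lookup and compute-capability filtering elided; StubKernel is a placeholder, not a vLLM class):

class StubKernel:  # placeholder; real kernels subclass ScaledMMLinearKernel
    @classmethod
    def can_implement(cls, config) -> tuple[bool, str | None]:
        return True, None

def choose_kernel_sketch(config, kernels, module_name: str):
    failure_reasons: list[str] = []
    for kernel in kernels:
        ok, reason = kernel.can_implement(config)
        if ok:
            # real code: logger.info_once("Selected %s for %s", ..., scope="global")
            print(f"Selected {kernel.__name__} for {module_name}")
            return kernel
        failure_reasons.append(f"  {kernel.__name__}: {reason}")
    raise ValueError("No kernel can implement config:\n" + "\n".join(failure_reasons))

choose_kernel_sketch(object(), [StubKernel], "CompressedTensorsW8A8Fp8")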

View File

@@ -4,15 +4,15 @@
 from collections.abc import Callable

 import torch

+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.platforms import current_platform

 FP8ScaledMMCallBack = Callable[..., torch.Tensor]
-FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]]


 def apply_weights_fp8(
     scaled_mm_func: FP8ScaledMMCallBack,
-    quant_fp8_func: FP8QuantCallback,
+    quant_fp8_func: QuantFP8,
     w: torch.Tensor,
     x: torch.Tensor,
     w_s: torch.Tensor,
View File

@@ -40,8 +40,6 @@ QUANT_STRATEGY_MAP = {
 class QuarkW8A8Fp8(QuarkScheme):
-    _kernel_backends_being_used: set[str] = set()
-
     def __init__(
         self, weight_config: dict[str, Any], input_config: dict[str, Any] | None
     ):
@@ -187,12 +185,9 @@ class QuarkW8A8Fp8(QuarkScheme):
         kernel_type = choose_scaled_mm_linear_kernel(
             scaled_mm_linear_kernel_config,
             _POSSIBLE_FP8_KERNELS,
+            module_name=self.__class__.__name__,
         )
-        if kernel_type.__name__ not in self._kernel_backends_being_used:
-            logger.info("Using %s for QuarkW8A8FP8", kernel_type.__name__)
-            self._kernel_backends_being_used.add(kernel_type.__name__)
-
         layer_param_names = ["weight", "weight_scale", "input_scale"]
         self.kernel = kernel_type(
             c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names

View File

@@ -25,8 +25,6 @@ logger = init_logger(__name__)

 class QuarkW8A8Int8(QuarkScheme):
-    _kernel_backends_being_used: set[str] = set()
-
     def __init__(
         self,
         qscheme: str,
@@ -60,13 +58,11 @@ class QuarkW8A8Int8(QuarkScheme):
         )
         kernel_type = choose_scaled_mm_linear_kernel(
-            scaled_mm_linear_kernel_config, possible_kernels=_POSSIBLE_INT8_KERNELS
+            scaled_mm_linear_kernel_config,
+            possible_kernels=_POSSIBLE_INT8_KERNELS,
+            module_name=self.__class__.__name__,
         )
-        if kernel_type.__name__ not in self._kernel_backends_being_used:
-            logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
-            self._kernel_backends_being_used.add(kernel_type.__name__)
-
         # WEIGHT
         weight = ModelWeightParameter(
             data=torch.empty(