diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index b7aed6105d10c..8c0f0e1d57fb3 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -133,6 +133,14 @@ def torch_channelwise_w8a8_scaled_mm( class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel): + """ + Base class for FP8 linear kernels using Torch. + Each subclass represents a kernel variant for + specific device capabilities and torch versions, + so we split them up and implement + get_min_capability() separately for each. + """ + def get_ouput_padding(self) -> int | None: vllm_config = get_current_vllm_config().compilation_config pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE