clean up; fix quark path

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-10-30 12:27:04 +00:00
parent e54e572085
commit c05027f67a
12 changed files with 357 additions and 433 deletions

View File

@ -6,6 +6,7 @@ from collections.abc import Callable
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from torch.nn import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
@ -17,6 +18,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
QUANT_STRATEGY_MAP,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
@ -49,8 +51,11 @@ strategy_to_parameter_type = {
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
logger = init_logger(__name__)
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
_kernel_backends_being_used: set[str] = set()
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
self.weight_quant = weight_quant
self.strategy = weight_quant.strategy
@ -79,19 +84,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
param_name_list = ["weight", "weight_scale", "input_scale"]
layer_mapping_function = lambda layer: (
tuple(getattr(layer, param_name) for param_name in param_name_list),
param_name_list,
)
# TODO: clean up
if self.strategy == QuantizationStrategy.TENSOR:
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
elif self.strategy == QuantizationStrategy.CHANNEL:
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
layer_param_names = ["weight", "weight_scale", "input_scale"]
weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy]
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
is_static_input_scheme=self.is_static_input_scheme,
weight_quant_strategy=weight_quant_strategy,
activation_group_shape=self.act_q_group_shape,
out_dtype=self.out_dtype,
@ -101,9 +97,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
_POSSIBLE_FP8_KERNELS,
)
self.fp8_linear = kernel(
scaled_mm_linear_kernel_config, layer_mapping_function
scaled_mm_linear_kernel_config, layer_param_names = layer_param_names
)
if kernel.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW8A8FP8", kernel.__name__)
self._kernel_backends_being_used.add(kernel.__name__)
@classmethod
def get_min_capability(cls) -> int:
# lovelace and up

View File

@ -113,15 +113,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
if not hasattr(layer, "azp_adj"):
layer.register_parameter("azp_adj", None)
param_name_list = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
layer_mapping_function = lambda layer: (
tuple(getattr(layer, param_name) for param_name in param_name_list),
param_name_list,
)
self.kernel = kernel_type(
c=scaled_mm_linear_kernel_config,
layer_mapping_function = layer_mapping_function
layer_param_names = layer_param_names
)
# Checkpoints are serialized in compressed-tensors format, which is

View File

@ -5,10 +5,12 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
from typing import Generic, TypeVar
from typing import Generic, Sequence, TypeVar
import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
class ScaledMMLinearQuantStrategy(Enum):
@ -16,21 +18,19 @@ class ScaledMMLinearQuantStrategy(Enum):
CHANNEL = "channel"
BLOCK = "block"
def is_per_token(self) -> bool:
return self.row == 1 and self.col == -1
def is_per_group(self) -> bool:
return self.row == 1 and self.col >= 1
# Maps compressed-tensors weight quantization strategies onto the
# scaled-mm kernels' own strategy enum (see ScaledMMLinearQuantStrategy).
QUANT_STRATEGY_MAP = {
    QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR,
    # BUG FIX: CHANNEL must map to CHANNEL. The previous literal listed the
    # CHANNEL key twice (second value BLOCK); Python keeps the last entry,
    # so per-channel checkpoints were silently treated as block-quantized.
    QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL,
}
@dataclass
class ScaledMMLinearLayerConfig:
pass
is_static_input_scheme: bool
@dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
is_channelwise: bool
is_static_input_scheme: bool
input_symmetric: bool
@dataclass
@ -40,10 +40,24 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
out_dtype: torch.dtype
Int8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
]
FP8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_zp
torch.Tensor | None, # azp_adj
]
ParamsT = TypeVar('ParamsT', Int8ParamsT, FP8ParamsT)
ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig)
class ScaledMMLinearKernel(Generic[ConfigT], ABC):
class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
@ -55,11 +69,11 @@ class ScaledMMLinearKernel(Generic[ConfigT], ABC):
raise NotImplementedError
def __init__(
self, c: ConfigT, layer_mapping_function: Callable
self, c: ConfigT, layer_param_names: Sequence[str]
) -> None:
assert self.can_implement(c)
self.config = c
self.layer_mapping_function = layer_mapping_function
self.layer_param_names = layer_param_names
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@ -72,4 +86,53 @@ class ScaledMMLinearKernel(Generic[ConfigT], ABC):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError
raise NotImplementedError
# return a covariant type in the subclass
@abstractmethod
def _get_layer_params(self, layer) -> ParamsT:
raise NotImplementedError
class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC):
    """Shared base for FP8 W8A8 scaled-mm kernels.

    Builds the activation quantizer (``QuantFP8``) from the layer config and
    resolves the three layer parameters (weight, weight_scale, input_scale)
    by the attribute names supplied at construction time.
    """

    def __init__(
        self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
    ) -> None:
        # Activation quantizer: static vs. dynamic quantization and the
        # activation group shape come straight from the scheme's config;
        # the padding hook lets subclasses request padded token counts
        # (e.g. for torch.compile-unfriendly shapes).
        self.quant_fp8 = QuantFP8(
            static=c.is_static_input_scheme,
            group_shape=c.activation_group_shape,
            num_token_padding=self.get_ouput_padding(),
        )
        super().__init__(c, layer_param_names)

    # NOTE(review): "ouput" is a typo for "output", but every subclass
    # overrides it under the same spelling, so a rename must be coordinated
    # across all kernel files.
    @abstractmethod
    def get_ouput_padding(self) -> int | None:
        raise NotImplementedError

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # FP8 kernels consume the checkpoint weights as loaded; no
        # post-load rewriting is needed (unlike the int8/cutlass path).
        pass

    def _get_layer_params(self, layer: torch.nn.Module) -> FP8ParamsT:
        # Resolve (weight, weight_scale, input_scale) by attribute name;
        # input_scale may be None for dynamic activation quantization.
        w, w_s, x_s = self.layer_param_names
        return (
            getattr(layer, w),
            getattr(layer, w_s),
            getattr(layer, x_s),
        )
class Int8ScaledMMLinearKernel(ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC):
    """Shared base for int8 W8A8 scaled-mm kernels.

    Resolves the five layer parameters (weight, weight_scale, input_scale,
    input_zero_point, azp_adj) by the attribute names supplied at
    construction time via ``layer_param_names``.
    """

    def _get_layer_params(self, layer: torch.nn.Module) -> Int8ParamsT:
        # layer_param_names must hold exactly five attribute names, in the
        # order registered by the int8 scheme; the scale/zero-point/azp
        # entries may resolve to None depending on static vs. dynamic and
        # symmetric vs. asymmetric input quantization.
        w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
        return (
            getattr(layer, w_q),
            getattr(layer, w_s),
            getattr(layer, i_s),
            getattr(layer, i_zp),
            getattr(layer, azp_adj),
        )

View File

@ -9,8 +9,8 @@ from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .cutlass import process_weights_after_loading
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
def rocm_aiter_gemm_w8a8_impl(
@ -52,7 +52,7 @@ if current_platform.is_rocm():
)
class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 90
@ -91,11 +91,6 @@ class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
)
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
_, param_names = self.layer_mapping_function(layer)
process_weights_after_loading(self.config, layer, *param_names)
def apply_weights(
self,
layer: torch.nn.Module,
@ -112,7 +107,7 @@ class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
ATIER block scaled GEMM and mix-precision GEMM.
"""
(w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer)
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.

View File

@ -14,10 +14,10 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 75
@ -30,7 +30,8 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = getattr(layer, self.w_q_name)
w_q_name, _, _, _, _ = self.layer_param_names
weight = getattr(layer, w_q_name)
dtype = weight.dtype
N, K = weight.size()
if (
@ -48,10 +49,13 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Transpose to [K, N] for convenience
weight = getattr(layer, self.w_q_name)
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
self.layer_param_names
)
weight = getattr(layer, w_q_name)
replace_parameter(
layer,
self.w_q_name,
w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
)
@ -60,28 +64,27 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
weight_scale = getattr(layer, w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
self.w_s_name,
w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE
if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name)
input_scale = getattr(layer, i_s_name)
if self.config.input_symmetric:
replace_parameter(
layer,
self.i_s_name,
i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
setattr(layer, self.i_zp_name, None)
else:
input_zero_point = getattr(layer, self.i_zp_name)
input_zero_point = getattr(layer, i_zp_name)
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
@ -91,20 +94,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter(
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
)
azp = (
(int8_traits.min - range_min / scale).round().to(dtype=torch.int32)
)
replace_parameter(
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
else:
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
# Different from cutlass, oneDNN kernels only need the AZP adjustment
# term for dynamic quantization. And s_b should be folded into the
# term. Such as:
@ -112,38 +111,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
# s_a * GEMM_output - s_a * zp_a * adj + bias
if not (self.config.input_symmetric and self.config.is_static_input_scheme):
weight = getattr(layer, self.w_q_name)
weight_scale = getattr(layer, self.w_s_name)
weight = getattr(layer, w_q_name)
weight_scale = getattr(layer, w_s_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
azp_adj = azp_adj * weight_scale.squeeze()
setattr(
layer,
self.azp_adj_name,
azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False),
)
else:
setattr(layer, self.azp_adj_name, None)
weight = getattr(layer, self.w_q_name)
weight = getattr(layer, w_q_name)
self.dnnl_handler = ops.create_onednn_scaled_mm(
weight,
getattr(layer, self.w_s_name),
getattr(layer, w_s_name),
torch.get_default_dtype(),
getattr(layer, self.i_s_name) is None,
getattr(layer, i_s_name) is None,
not self.config.input_symmetric,
32,
)
# weight is prepacked and maintained by the dnnl_handler,
# release the original weight
setattr(layer, self.w_q_name, None)
setattr(layer, w_q_name, None)
del weight
def process_weights_for_sgl(self, layer: torch.nn.Module) -> None:
w_q_name, w_s_name, _, _, _ = self.layer_param_names
# WEIGHT
weight = getattr(layer, self.w_q_name)
weight = getattr(layer, w_q_name)
packed_weight = torch.ops._C.convert_weight_packed(weight)
replace_parameter(
layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
layer, w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
)
if layer.bias is not None:
@ -155,19 +153,15 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# WEIGHT SCALE
# CPU SGL kernels only support per-channel.
# For per-tensor quant, convert to the per-channel case.
weight_scale = getattr(layer, self.w_s_name)
weight_scale = getattr(layer, w_s_name)
if not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
self.w_s_name,
w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
setattr(layer, self.azp_adj_name, None)
def apply_weights(
self,
layer: torch.nn.Module,
@ -186,7 +180,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
@ -208,7 +202,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, _, _, _ = self._get_weight_params(layer)
w_q, w_s, _, _, _ = self._get_layer_params(layer)
return torch.ops._C.int8_scaled_mm_with_quant(
x,
w_q,

View File

@ -15,10 +15,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearKernel, Int8ScaledMMLinearKernel
from .utils import apply_weights_fp8
def cutlass_w8a8_scaled_mm(
def cutlass_w8a8_scaled_mm_fp8(
*,
A: torch.Tensor,
B: torch.Tensor,
@ -34,91 +34,7 @@ def cutlass_w8a8_scaled_mm(
)
return output.view(*output_shape)
def process_weights_after_loading(
config: Int8ScaledMMLinearLayerConfig,
layer: torch.nn.Module,
w_q_name: str,
w_s_name: str,
i_s_name: str,
i_zp_name: str,
azp_adj_name: str,
):
# WEIGHT
# Cutlass kernels need transposed weight.
weight = getattr(layer, w_q_name)
replace_parameter(
layer,
w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
)
# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, w_s_name)
if is_fused_module and not config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE
if config.is_static_input_scheme:
input_scale = getattr(layer, i_s_name)
if config.input_symmetric:
replace_parameter(
layer,
i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
setattr(layer, i_zp_name, None)
else:
input_zero_point = getattr(layer, i_zp_name)
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = input_zero_point.to(dtype=torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter(
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
)
# AZP loaded as int8 but used as int32
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
replace_parameter(
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
if not config.input_symmetric:
weight = getattr(layer, w_q_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
if config.is_static_input_scheme:
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj = getattr(layer, i_zp_name) * azp_adj
setattr(
layer,
azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False),
)
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 75
@ -131,9 +47,83 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
_, param_names = self.layer_mapping_function(layer)
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
self.layer_param_names
)
config = self.config
# WEIGHT
# Cutlass kernels need transposed weight.
weight = getattr(layer, w_q_name)
replace_parameter(
layer,
w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
)
# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, w_s_name)
if is_fused_module and not config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE
if config.is_static_input_scheme:
input_scale = getattr(layer, i_s_name)
if config.input_symmetric:
replace_parameter(
layer,
i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
setattr(layer, i_zp_name, None)
else:
input_zero_point = getattr(layer, i_zp_name)
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = input_zero_point.to(dtype=torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter(
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
)
# AZP loaded as int8 but used as int32
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
replace_parameter(
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
if not config.input_symmetric:
weight = getattr(layer, w_q_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
if config.is_static_input_scheme:
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj = getattr(layer, i_zp_name) * azp_adj
setattr(
layer,
azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False),
)
process_weights_after_loading(self.config, layer, *param_names)
def apply_weights(
self,
@ -141,7 +131,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
(w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer)
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
@ -170,21 +160,10 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
)
class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
def __init__(
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
) -> None:
self.quant_fp8 = QuantFP8(
static=c.is_static_input_scheme,
group_shape=GroupShape.PER_TENSOR,
num_token_padding=None,
)
super().__init__(c, layer_mapping_function)
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
return 89
def get_ouput_padding(self) -> int | None:
return None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
@ -197,41 +176,20 @@ class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
# View input as 2D matrix for fp8 methods
x_2d = x.view(-1, x.shape[-1])
out_dtype = self.config.out_dtype
out_dtype = x.dtype if out_dtype is None else out_dtype
# If input not quantized
# TODO(luka) remove this path if not used anymore
x_2d_q = x_2d
if x.dtype != current_platform.fp8_dtype():
x_2d_q, x_s = self.quant_fp8(
x_2d,
x_s,
)
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
return cutlass_w8a8_scaled_mm(
A=x_2d_q,
B=w,
out_dtype=out_dtype,
As=x_s,
Bs=w_s,
bias=bias,
output_shape=output_shape,
)
w, w_s, x_s = self._get_layer_params(layer)
return apply_weights_fp8(
cutlass_w8a8_scaled_mm_fp8,
self.quant_fp8,
w,
x,
w_s,
x_s,
bias,
self.config.out_dtype
)

View File

@ -10,10 +10,11 @@ from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from .ScaledMMLinearKernel import (
ScaledMMLinearKernel,
FP8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from .utils import apply_weights_fp8
def flashinfer_w8a8_scaled_mm(
@ -30,16 +31,10 @@ def flashinfer_w8a8_scaled_mm(
)
class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
def __init__(
self, c: Int8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
) -> None:
self.quant_fp8 = QuantFP8(
static=c.is_static_input_scheme,
group_shape=GroupShape.PER_TENSOR,
num_token_padding=None,
)
super().__init__(c, layer_mapping_function)
class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
def get_ouput_padding(self) -> int | None:
return None
@classmethod
def get_min_capability(cls) -> int:
@ -80,41 +75,20 @@ class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
)
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
# View input as 2D matrix for fp8 methods
x_2d = x.view(-1, x.shape[-1])
out_dtype = self.config.out_dtype
out_dtype = x.dtype if out_dtype is None else out_dtype
# If input not quantized
# TODO(luka) remove this path if not used anymore
x_2d_q = x_2d
if x.dtype != current_platform.fp8_dtype():
x_2d_q, x_s = self.quant_fp8(
x_2d,
x_s,
)
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
return flashinfer_w8a8_scaled_mm(
A=x_2d_q,
B=w,
out_dtype=out_dtype,
As=x_s,
Bs=w_s,
bias=bias,
output_shape=output_shape,
)
w, w_s, x_s = self._get_layer_params(layer)
return apply_weights_fp8(
flashinfer_w8a8_scaled_mm,
self.quant_fp8,
w,
x,
w_s,
x_s,
bias,
self.config.out_dtype
)

View File

@ -12,11 +12,11 @@ from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .ScaledMMLinearKernel import (
ScaledMMLinearKernel,
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from .utils import apply_weights_fp8
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
A: torch.Tensor,
@ -88,20 +88,9 @@ if current_platform.is_rocm():
)
class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
def __init__(
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
) -> None:
self.quant_fp8 = QuantFP8(
static=c.is_static_input_scheme,
group_shape=GroupShape.PER_TENSOR,
num_token_padding=None,
)
super().__init__(c, layer_mapping_function)
@classmethod
def get_min_capability(cls) -> int:
return 90
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
def get_ouput_padding(self) -> int | None:
return None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
@ -128,7 +117,7 @@ class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
return (
False,
"VLLM_ROCM_USE_SKINNY_GEMM must be enabled "
+ "to use ROCmScaledMMLinearKernel ",
+ "to use ROCmScaledMMLinearKernel.",
)
if not (per_tensor_activation_scales and per_tensor_weight_scales):
@ -139,41 +128,20 @@ class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
)
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
# View input as 2D matrix for fp8 methods
x_2d = x.view(-1, x.shape[-1])
out_dtype = self.config.out_dtype
out_dtype = x.dtype if out_dtype is None else out_dtype
# If input not quantized
# TODO(luka) remove this path if not used anymore
x_2d_q = x_2d
if x.dtype != current_platform.fp8_dtype():
x_2d_q, x_s = self.quant_fp8(
x_2d,
x_s,
)
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
return rocm_per_tensor_float_w8a8_scaled_mm(
A=x_2d_q,
B=w,
out_dtype=out_dtype,
As=x_s,
Bs=w_s,
bias=bias,
output_shape=output_shape,
w, w_s, x_s = self._get_layer_params(layer)
return apply_weights_fp8(
rocm_per_tensor_float_w8a8_scaled_mm,
self.quant_fp8,
w,
x,
w_s,
x_s,
bias,
self.config.out_dtype
)

View File

@ -11,11 +11,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (
ScaledMMLinearKernel,
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from .utils import apply_weights_fp8
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None
@ -134,35 +135,16 @@ def torch_channelwise_w8a8_scaled_mm(
return output.to(out_dtype).view(*output_shape)
class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
def __init__(
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
) -> None:
class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel):
def get_ouput_padding(self) -> int | None:
vllm_config = get_current_vllm_config().compilation_config
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
output_padding = 17 if pad_output else None
self.quant_fp8 = QuantFP8(
static=c.is_static_input_scheme,
group_shape=GroupShape.PER_TENSOR,
num_token_padding=output_padding,
)
super().__init__(c, layer_mapping_function)
@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
return 89
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
return
return output_padding
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
assert c.activation_group_shape is not None
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
per_tensor_weight_scales = (
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
@ -182,36 +164,18 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
# View input as 2D matrix for fp8 methods
x_2d = x.view(-1, x.shape[-1])
out_dtype = self.config.out_dtype
out_dtype = x.dtype if out_dtype is None else out_dtype
# If input not quantized
# TODO(luka) remove this path if not used anymore
x_2d_q = x_2d
if x.dtype != current_platform.fp8_dtype():
x_2d_q, x_s = self.quant_fp8(
x_2d,
x_s,
)
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
return torch_per_tensor_w8a8_scaled_mm(
A=x_2d_q,
B=w,
out_dtype=out_dtype,
As=x_s,
Bs=w_s,
bias=bias,
output_shape=output_shape,
w, w_s, x_s = self._get_layer_params(layer)
return apply_weights_fp8(
torch_per_tensor_w8a8_scaled_mm,
self.quant_fp8,
w,
x,
w_s,
x_s,
bias,
self.config.out_dtype
)
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
@ -219,14 +183,12 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
assert c.activation_group_shape is not None
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
per_tensor_weight_scales = (
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
)
if per_tensor_activation_scales and per_tensor_weight_scales:
if per_tensor_activation_scales or per_tensor_weight_scales:
return (
False,
"RowWiseTorchScaledMMLinearKernel cannot be used with "
@ -254,33 +216,16 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
# View input as 2D matrix for fp8 methods
x_2d = x.view(-1, x.shape[-1])
out_dtype = self.config.out_dtype
out_dtype = x.dtype if out_dtype is None else out_dtype
# If input not quantized
# TODO(luka) remove this path if not used anymore
x_2d_q = x_2d
if x.dtype != current_platform.fp8_dtype():
x_2d_q, x_s = self.quant_fp8(
x_2d,
x_s,
)
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
return torch_row_wise_w8a8_scaled_mm(
A=x_2d_q,
B=w,
out_dtype=out_dtype,
As=x_s,
Bs=w_s,
bias=bias,
output_shape=output_shape,
w, w_s, x_s = self._get_layer_params(layer)
return apply_weights_fp8(
torch_row_wise_w8a8_scaled_mm,
self.quant_fp8,
w,
x,
w_s,
x_s,
bias,
self.config.out_dtype
)
@ -291,8 +236,6 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
assert c.activation_group_shape is not None
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
per_tensor_weight_scales = (
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
@ -313,31 +256,14 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
):
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
# View input as 2D matrix for fp8 methods
x_2d = x.view(-1, x.shape[-1])
out_dtype = self.config.out_dtype
out_dtype = x.dtype if out_dtype is None else out_dtype
# If input not quantized
# TODO(luka) remove this path if not used anymore
x_2d_q = x_2d
if x.dtype != current_platform.fp8_dtype():
x_2d_q, x_s = self.quant_fp8(
x_2d,
x_s,
)
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
return torch_channelwise_w8a8_scaled_mm(
A=x_2d_q,
B=w,
out_dtype=out_dtype,
As=x_s,
Bs=w_s,
bias=bias,
output_shape=output_shape,
w, w_s, x_s = self._get_layer_params(layer)
return apply_weights_fp8(
torch_channelwise_w8a8_scaled_mm,
self.quant_fp8,
w,
x,
w_s,
x_s,
bias,
self.config.out_dtype
)

View File

@ -0,0 +1,44 @@
from collections.abc import Callable
import torch
from vllm.platforms import current_platform
# Scaled-matmul kernel callable; invoked with keyword args
# (A, B, out_dtype, As, Bs, bias, output_shape) and returns the result tensor
# (see apply_weights_fp8 below for the exact call site).
FP8ScaledMMCallBack = Callable[..., torch.Tensor]
# FP8 quantizer callable; invoked as (x_2d, scale_or_None) and returns the
# quantized activation together with its scale.
FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]]
def apply_weights_fp8(
    scaled_mm_func: FP8ScaledMMCallBack,
    quant_fp8_func: FP8QuantCallback,
    w: torch.Tensor,
    x: torch.Tensor,
    w_s: torch.Tensor,
    x_s: torch.Tensor | None,
    bias: torch.Tensor | None,
    maybe_out_dtype: torch.dtype | None,
) -> torch.Tensor:
    """Quantize the activation (if needed) and run an FP8 scaled matmul.

    Shared helper for the per-tensor / row-wise / channel-wise
    ScaledMM kernels: flattens ``x`` to 2D, quantizes it to the platform
    FP8 dtype when it is not already quantized, and dispatches to
    ``scaled_mm_func``.

    Args:
        scaled_mm_func: Kernel called as ``scaled_mm_func(A=..., B=...,
            out_dtype=..., As=..., Bs=..., bias=..., output_shape=...)``.
        quant_fp8_func: FP8 quantizer; called as
            ``quant_fp8_func(x_2d, x_s)``, returning the quantized
            activation and its scale.
        w: FP8 weight matrix; the output dim is taken from ``w.shape[1]``.
        x: Activation tensor; the trailing dim is the contraction dim.
        w_s: Weight scale(s).
        x_s: Activation scale. ``None`` for dynamic quantization (the
            scale is then computed by ``quant_fp8_func``); a scalar
            input_scale for static quantization.
        bias: Optional bias forwarded to the kernel.
        maybe_out_dtype: Output dtype; falls back to ``x.dtype`` when
            ``None``.

    Returns:
        The matmul result, shaped ``[*x.shape[:-1], w.shape[1]]``.
    """
    # ops.scaled_fp8_quant supports both dynamic and static quant:
    # if dynamic, x_s is None and is computed from x;
    # if static, x_s is the scalar input_scale.
    # View input as a 2D matrix for the fp8 kernels.
    x_2d = x.view(-1, x.shape[-1])
    output_shape = [*x.shape[:-1], w.shape[1]]
    out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype

    # If the input is not already FP8, quantize it here.
    # TODO(luka) remove this path if not used anymore
    x_2d_q = x_2d
    if x.dtype != current_platform.fp8_dtype():
        x_2d_q, x_s = quant_fp8_func(x_2d, x_s)

    return scaled_mm_func(
        A=x_2d_q,
        B=w,
        out_dtype=out_dtype,
        As=x_s,
        Bs=w_s,
        bias=bias,
        output_shape=output_shape,
    )

View File

@ -12,10 +12,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError(
@ -42,9 +42,12 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT
# [out, in] (different than cutlass_scaled_mm)
weight = getattr(layer, self.w_q_name)
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
self.layer_param_names
)
weight = getattr(layer, w_q_name)
replace_parameter(
layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
layer, w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
)
# WEIGHT SCALE
@ -52,7 +55,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
weight_scale = getattr(layer, w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
@ -60,14 +63,14 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
weight_scale = weight_scale.squeeze(-1)
replace_parameter(
layer,
self.w_s_name,
w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# Only support symmetric dynamic activation quantization.
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
setattr(layer, self.azp_adj_name, None)
setattr(layer, i_s_name, None)
setattr(layer, i_zp_name, None)
setattr(layer, azp_adj_name, None)
# Filter warning for cond usage in apply_weights. It is okay
# to specialize the graph since bias is not dynamic.
@ -88,7 +91,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, _, _, _ = self._get_weight_params(layer)
w_q, w_s, _, _, _ = self._get_layer_params(layer)
# Required to register custom ops.
import torch_xla.experimental.custom_kernel # noqa: F401

View File

@ -102,24 +102,27 @@ class QuarkW8A8Int8(QuarkScheme):
layer.register_parameter("weight_zero_point", weight_zero_point)
# INPUT SCALE
input_zero_point=None
input_scale=None
if self.is_static_input_scheme:
input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
)
layer.register_parameter("input_scale", input_scale)
input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
)
layer.register_parameter("input_zero_point", input_zero_point)
layer.register_parameter("input_scale", input_scale)
layer.register_parameter("input_zero_point", input_zero_point)
if not hasattr(layer, "azp_adj"):
layer.register_parameter("azp_adj", None)
layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
self.kernel = kernel_type(
c=scaled_mm_linear_kernel_config,
w_q_param_name="weight",
w_s_param_name="weight_scale",
i_s_param_name="input_scale",
i_zp_param_name="input_zero_point",
azp_adj_param_name="azp_adj",
layer_param_names = layer_param_names
)
# Checkpoints are serialized in quark format, which is