update quark fp8 path; format

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-10-30 14:24:19 +00:00
parent c05027f67a
commit c089ea5753
13 changed files with 172 additions and 122 deletions

View File

@ -6,8 +6,8 @@ from collections.abc import Callable
import torch import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from torch.nn import Parameter from torch.nn import Parameter
from vllm.logger import init_logger
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
@ -15,10 +15,9 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
_POSSIBLE_FP8_KERNELS, _POSSIBLE_FP8_KERNELS,
choose_scaled_mm_linear_kernel, choose_scaled_mm_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
QUANT_STRATEGY_MAP, QUANT_STRATEGY_MAP,
FP8ScaledMMLinearLayerConfig,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
@ -53,6 +52,7 @@ strategy_to_parameter_type = {
logger = init_logger(__name__) logger = init_logger(__name__)
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
_kernel_backends_being_used: set[str] = set() _kernel_backends_being_used: set[str] = set()
@ -92,17 +92,20 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
activation_group_shape=self.act_q_group_shape, activation_group_shape=self.act_q_group_shape,
out_dtype=self.out_dtype, out_dtype=self.out_dtype,
) )
kernel = choose_scaled_mm_linear_kernel( kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config, scaled_mm_linear_kernel_config,
_POSSIBLE_FP8_KERNELS, _POSSIBLE_FP8_KERNELS,
) )
self.fp8_linear = kernel(
scaled_mm_linear_kernel_config, layer_param_names = layer_param_names
)
if kernel.__name__ not in self._kernel_backends_being_used: if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW8A8FP8", kernel.__name__) logger.info(
self._kernel_backends_being_used.add(kernel.__name__) "Using %s for CompressedTensorsW8A8FP8", kernel_type.__name__
)
self._kernel_backends_being_used.add(kernel_type.__name__)
self.kernel = kernel_type(
scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
)
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
@ -217,4 +220,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
bias=bias, bias=bias,
) )
return self.fp8_linear.apply_weights(layer, x, bias) return self.kernel.apply_weights(layer, x, bias)

View File

@ -11,8 +11,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
_POSSIBLE_INT8_KERNELS,
choose_scaled_mm_linear_kernel, choose_scaled_mm_linear_kernel,
_POSSIBLE_INT8_KERNELS )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
Int8ScaledMMLinearLayerConfig,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
BasevLLMParameter, BasevLLMParameter,
@ -20,7 +23,6 @@ from vllm.model_executor.parameter import (
ModelWeightParameter, ModelWeightParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
logger = init_logger(__name__) logger = init_logger(__name__)
@ -58,8 +60,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
) )
kernel_type = choose_scaled_mm_linear_kernel( kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config, scaled_mm_linear_kernel_config, _POSSIBLE_INT8_KERNELS
_POSSIBLE_INT8_KERNELS
) )
if kernel_type.__name__ not in self._kernel_backends_being_used: if kernel_type.__name__ not in self._kernel_backends_being_used:
@ -94,8 +95,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE # INPUT SCALE
input_zero_point=None input_zero_point = None
input_scale=None input_scale = None
if self.is_static_input_scheme: if self.is_static_input_scheme:
input_scale = BasevLLMParameter( input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
@ -113,11 +114,16 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
if not hasattr(layer, "azp_adj"): if not hasattr(layer, "azp_adj"):
layer.register_parameter("azp_adj", None) layer.register_parameter("azp_adj", None)
layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"] layer_param_names = [
"weight",
"weight_scale",
"input_scale",
"input_zero_point",
"azp_adj",
]
self.kernel = kernel_type( self.kernel = kernel_type(
c=scaled_mm_linear_kernel_config, c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
layer_param_names = layer_param_names
) )
# Checkpoints are serialized in compressed-tensors format, which is # Checkpoints are serialized in compressed-tensors format, which is

View File

@ -2,15 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Generic, Sequence, TypeVar from typing import Generic, TypeVar
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import QuantizationStrategy
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
class ScaledMMLinearQuantStrategy(Enum): class ScaledMMLinearQuantStrategy(Enum):
@ -18,21 +19,24 @@ class ScaledMMLinearQuantStrategy(Enum):
CHANNEL = "channel" CHANNEL = "channel"
BLOCK = "block" BLOCK = "block"
QUANT_STRATEGY_MAP = { QUANT_STRATEGY_MAP = {
QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR, QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR,
QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL, QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL,
QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.BLOCK,
} }
@dataclass @dataclass
class ScaledMMLinearLayerConfig: class ScaledMMLinearLayerConfig:
is_static_input_scheme: bool is_static_input_scheme: bool
@dataclass @dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
is_channelwise: bool is_channelwise: bool
input_symmetric: bool input_symmetric: bool
@dataclass @dataclass
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
weight_quant_strategy: ScaledMMLinearQuantStrategy weight_quant_strategy: ScaledMMLinearQuantStrategy
@ -40,22 +44,22 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
out_dtype: torch.dtype out_dtype: torch.dtype
Int8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
]
FP8ParamsT = tuple[ FP8ParamsT = tuple[
torch.Tensor, # weight torch.Tensor, # weight
torch.Tensor, # weight_scale torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale, torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_zp ]
torch.Tensor | None, # azp_adj Int8ParamsT = tuple[
] torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_zp
torch.Tensor | None, # azp_adj
]
ParamsT = TypeVar("ParamsT", Int8ParamsT, FP8ParamsT)
ConfigT = TypeVar("ConfigT", bound=ScaledMMLinearLayerConfig)
ParamsT = TypeVar('ParamsT', Int8ParamsT, FP8ParamsT)
ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig)
class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC): class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
@classmethod @classmethod
@ -68,9 +72,7 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]: def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
raise NotImplementedError raise NotImplementedError
def __init__( def __init__(self, c: ConfigT, layer_param_names: Sequence[str]) -> None:
self, c: ConfigT, layer_param_names: Sequence[str]
) -> None:
assert self.can_implement(c) assert self.can_implement(c)
self.config = c self.config = c
self.layer_param_names = layer_param_names self.layer_param_names = layer_param_names
@ -87,16 +89,18 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
# return a covariant type in the subclass # return a covariant type in the subclass
@abstractmethod @abstractmethod
def _get_layer_params(self, layer) -> ParamsT: def _get_layer_params(self, layer) -> ParamsT:
raise NotImplementedError raise NotImplementedError
class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC): class FP8ScaledMMLinearKernel(
ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC
):
def __init__( def __init__(
self, c: ConfigT, layer_param_names: Sequence[str] self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
) -> None: ) -> None:
self.quant_fp8 = QuantFP8( self.quant_fp8 = QuantFP8(
static=c.is_static_input_scheme, static=c.is_static_input_scheme,
@ -104,7 +108,7 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig,
num_token_padding=self.get_ouput_padding(), num_token_padding=self.get_ouput_padding(),
) )
super().__init__(c, layer_param_names) super().__init__(c, layer_param_names)
@abstractmethod @abstractmethod
def get_ouput_padding(self) -> int | None: def get_ouput_padding(self) -> int | None:
raise NotImplementedError raise NotImplementedError
@ -113,7 +117,7 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig,
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
# lovelace and up # lovelace and up
return 89 return 89
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass pass
@ -126,7 +130,9 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig,
) )
class Int8ScaledMMLinearKernel(ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC): class Int8ScaledMMLinearKernel(
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC
):
def _get_layer_params(self, layer) -> Int8ParamsT: def _get_layer_params(self, layer) -> Int8ParamsT:
w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
return ( return (

View File

@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
ScaledMMLinearKernel, ScaledMMLinearKernel,
ScaledMMLinearLayerConfig, ScaledMMLinearLayerConfig,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
ChannelWiseTorchScaledMMLinearKernel, ChannelWiseTorchScaledMMLinearKernel,
PerTensorTorchScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel,

View File

@ -14,7 +14,10 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import (
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
)
class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel): class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@ -49,9 +52,7 @@ class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel):
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
# WEIGHT # WEIGHT
# Transpose to [K, N] for convenience # Transpose to [K, N] for convenience
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = ( w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
self.layer_param_names
)
weight = getattr(layer, w_q_name) weight = getattr(layer, w_q_name)
replace_parameter( replace_parameter(
layer, layer,

View File

@ -2,22 +2,24 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, convert_to_channelwise,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearKernel, Int8ScaledMMLinearKernel from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
)
from .utils import apply_weights_fp8 from .utils import apply_weights_fp8
def cutlass_w8a8_scaled_mm_fp8( def cutlass_w8a8_scaled_mm_fp8(
*, *,
A: torch.Tensor, A: torch.Tensor,
@ -34,6 +36,7 @@ def cutlass_w8a8_scaled_mm_fp8(
) )
return output.view(*output_shape) return output.view(*output_shape)
class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel): class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
@ -47,9 +50,7 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
return True, None return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = ( w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
self.layer_param_names
)
config = self.config config = self.config
# WEIGHT # WEIGHT
# Cutlass kernels need transposed weight. # Cutlass kernels need transposed weight.
@ -105,7 +106,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False) layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
) )
# azp_adj is the AZP adjustment term, used to account for weights. # azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for # It does not depend on scales or azp, so it is the same for
# static and dynamic quantization. # static and dynamic quantization.
@ -124,7 +124,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
torch.nn.Parameter(azp_adj, requires_grad=False), torch.nn.Parameter(azp_adj, requires_grad=False),
) )
def apply_weights( def apply_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
@ -161,7 +160,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
def get_ouput_padding(self) -> int | None: def get_ouput_padding(self) -> int | None:
return None return None
@ -191,5 +189,5 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
w_s, w_s,
x_s, x_s,
bias, bias,
self.config.out_dtype self.config.out_dtype,
) )

View File

@ -1,17 +1,14 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from .ScaledMMLinearKernel import ( from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel, FP8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy, ScaledMMLinearQuantStrategy,
) )
from .utils import apply_weights_fp8 from .utils import apply_weights_fp8
@ -32,7 +29,6 @@ def flashinfer_w8a8_scaled_mm(
class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel): class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
def get_ouput_padding(self) -> int | None: def get_ouput_padding(self) -> int | None:
return None return None
@ -41,7 +37,7 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
return 100 return 100
@classmethod @classmethod
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor() per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
per_tensor_weight_scales = ( per_tensor_weight_scales = (
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
@ -90,5 +86,5 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
w_s, w_s,
x_s, x_s,
bias, bias,
self.config.out_dtype self.config.out_dtype,
) )

View File

@ -1,13 +1,10 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
@ -18,6 +15,7 @@ from .ScaledMMLinearKernel import (
) )
from .utils import apply_weights_fp8 from .utils import apply_weights_fp8
def rocm_per_tensor_float_w8a8_scaled_mm_impl( def rocm_per_tensor_float_w8a8_scaled_mm_impl(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
@ -40,7 +38,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
current_platform.get_cu_count(), current_platform.get_cu_count(),
bias, bias,
) )
# Fallabck # Fallback
else: else:
output = torch._scaled_mm( output = torch._scaled_mm(
A, A,
@ -143,5 +141,5 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
w_s, w_s,
x_s, x_s,
bias, bias,
self.config.out_dtype self.config.out_dtype,
) )

View File

@ -1,13 +1,10 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
from packaging import version from packaging import version
from vllm.config import CompilationMode, get_current_vllm_config from vllm.config import CompilationMode, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ( from .ScaledMMLinearKernel import (
@ -15,8 +12,8 @@ from .ScaledMMLinearKernel import (
FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy, ScaledMMLinearQuantStrategy,
) )
from .utils import apply_weights_fp8 from .utils import apply_weights_fp8
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None TORCH_DEVICE_IDENTITY = None
@ -142,6 +139,7 @@ class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel):
output_padding = 17 if pad_output else None output_padding = 17 if pad_output else None
return output_padding return output_padding
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod @classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
@ -173,9 +171,10 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
w_s, w_s,
x_s, x_s,
bias, bias,
self.config.out_dtype self.config.out_dtype,
) )
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
@ -199,7 +198,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
return ( return (
False, False,
"RowWiseTorchScaledMMLinearKernel is only supported " "RowWiseTorchScaledMMLinearKernel is only supported "
+ "in ROCm platforms.", + "on ROCm platforms.",
) )
if not version.parse(torch.__version__) >= version.parse("2.7"): if not version.parse(torch.__version__) >= version.parse("2.7"):
@ -225,7 +224,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
w_s, w_s,
x_s, x_s,
bias, bias,
self.config.out_dtype self.config.out_dtype,
) )
@ -265,5 +264,5 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
w_s, w_s,
x_s, x_s,
bias, bias,
self.config.out_dtype self.config.out_dtype,
) )

View File

@ -1,20 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable from collections.abc import Callable
import torch
import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
FP8ScaledMMCallBack = Callable[..., torch.Tensor] FP8ScaledMMCallBack = Callable[..., torch.Tensor]
FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]] FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]]
def apply_weights_fp8( def apply_weights_fp8(
scaled_mm_func: FP8ScaledMMCallBack, scaled_mm_func: FP8ScaledMMCallBack,
quant_fp8_func: FP8QuantCallback, quant_fp8_func: FP8QuantCallback,
w:torch.Tensor, w: torch.Tensor,
x:torch.Tensor, x: torch.Tensor,
w_s:torch.Tensor, w_s: torch.Tensor,
x_s:torch.Tensor, x_s: torch.Tensor,
bias:torch.Tensor, bias: torch.Tensor,
maybe_out_dtype: torch.dtype | None, maybe_out_dtype: torch.dtype | None,
) -> torch.Tensor: ) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant. # ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_s computed from x. # If dynamic, layer.input_scale is None and x_s computed from x.
# If static, layer.input_scale is scalar and x_s is input_scale. # If static, layer.input_scale is scalar and x_s is input_scale.

View File

@ -12,7 +12,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import (
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
)
class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel): class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@ -42,9 +45,7 @@ class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT # WEIGHT
# [out, in] (different than cutlass_scaled_mm) # [out, in] (different than cutlass_scaled_mm)
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = ( w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
self.layer_param_names
)
weight = getattr(layer, w_q_name) weight = getattr(layer, w_q_name)
replace_parameter( replace_parameter(
layer, w_q_name, torch.nn.Parameter(weight.data, requires_grad=False) layer, w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)

View File

@ -7,10 +7,18 @@ from typing import Any, cast
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
_POSSIBLE_FP8_KERNELS,
choose_scaled_mm_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale, requantize_with_max_scale,
) )
@ -23,8 +31,17 @@ from vllm.platforms import current_platform
__all__ = ["QuarkW8A8Fp8"] __all__ = ["QuarkW8A8Fp8"]
logger = init_logger(__name__)
QUANT_STRATEGY_MAP = {
"per_tensor": ScaledMMLinearQuantStrategy.TENSOR,
"per_channel": ScaledMMLinearQuantStrategy.CHANNEL,
}
class QuarkW8A8Fp8(QuarkScheme): class QuarkW8A8Fp8(QuarkScheme):
_kernel_backends_being_used: set[str] = set()
def __init__( def __init__(
self, weight_config: dict[str, Any], input_config: dict[str, Any] | None self, weight_config: dict[str, Any], input_config: dict[str, Any] | None
): ):
@ -41,10 +58,6 @@ class QuarkW8A8Fp8(QuarkScheme):
self.act_quant_group_shape = ( self.act_quant_group_shape = (
GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
) )
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_quant_group_shape,
)
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
@classmethod @classmethod
@ -163,17 +176,32 @@ class QuarkW8A8Fp8(QuarkScheme):
input_scale[:] = torch.finfo(torch.float32).min input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale) layer.register_parameter("input_scale", input_scale)
layer_param_names = ["weight", "weight_scale", "input_scale"]
weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme]
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
is_static_input_scheme=self.is_static_input_scheme,
weight_quant_strategy=weight_quant_strategy,
activation_group_shape=self.act_quant_group_shape,
out_dtype=self.out_dtype,
)
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config,
_POSSIBLE_FP8_KERNELS,
)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for QuarkW8A8FP8", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
layer_param_names = ["weight", "weight_scale", "input_scale"]
self.kernel = kernel_type(
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
)
def apply_weights( def apply_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.fp8_linear.apply( return self.kernel.apply_weights(layer, x, bias)
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
)

View File

@ -7,9 +7,12 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
_POSSIBLE_INT8_KERNELS,
choose_scaled_mm_linear_kernel, choose_scaled_mm_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
Int8ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
BasevLLMParameter, BasevLLMParameter,
@ -56,7 +59,9 @@ class QuarkW8A8Int8(QuarkScheme):
input_symmetric=(self.input_symmetric is True), input_symmetric=(self.input_symmetric is True),
) )
kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config) kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config, possible_kernels=_POSSIBLE_INT8_KERNELS
)
if kernel_type.__name__ not in self._kernel_backends_being_used: if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__) logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
@ -102,8 +107,8 @@ class QuarkW8A8Int8(QuarkScheme):
layer.register_parameter("weight_zero_point", weight_zero_point) layer.register_parameter("weight_zero_point", weight_zero_point)
# INPUT SCALE # INPUT SCALE
input_zero_point=None input_zero_point = None
input_scale=None input_scale = None
if self.is_static_input_scheme: if self.is_static_input_scheme:
input_scale = BasevLLMParameter( input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
@ -117,12 +122,17 @@ class QuarkW8A8Int8(QuarkScheme):
layer.register_parameter("input_zero_point", input_zero_point) layer.register_parameter("input_zero_point", input_zero_point)
if not hasattr(layer, "azp_adj"): if not hasattr(layer, "azp_adj"):
layer.register_parameter("azp_adj", None) layer.register_parameter("azp_adj", None)
layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"] layer_param_names = [
"weight",
"weight_scale",
"input_scale",
"input_zero_point",
"azp_adj",
]
self.kernel = kernel_type( self.kernel = kernel_type(
c=scaled_mm_linear_kernel_config, c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
layer_param_names = layer_param_names
) )
# Checkpoints are serialized in quark format, which is # Checkpoints are serialized in quark format, which is