mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 18:07:09 +08:00
update quark fp8 path; format
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
c05027f67a
commit
c089ea5753
@ -6,8 +6,8 @@ from collections.abc import Callable
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
|
||||
from torch.nn import Parameter
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
@ -15,10 +15,9 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
_POSSIBLE_FP8_KERNELS,
|
||||
choose_scaled_mm_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
QUANT_STRATEGY_MAP,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
@ -53,6 +52,7 @@ strategy_to_parameter_type = {
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
@ -92,17 +92,20 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
activation_group_shape=self.act_q_group_shape,
|
||||
out_dtype=self.out_dtype,
|
||||
)
|
||||
kernel = choose_scaled_mm_linear_kernel(
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config,
|
||||
_POSSIBLE_FP8_KERNELS,
|
||||
)
|
||||
self.fp8_linear = kernel(
|
||||
scaled_mm_linear_kernel_config, layer_param_names = layer_param_names
|
||||
)
|
||||
|
||||
if kernel.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for CompressedTensorsW8A8FP8", kernel.__name__)
|
||||
self._kernel_backends_being_used.add(kernel.__name__)
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info(
|
||||
"Using %s for CompressedTensorsW8A8FP8", kernel_type.__name__
|
||||
)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
self.kernel = kernel_type(
|
||||
scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -217,4 +220,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
|
||||
@ -11,8 +11,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
_POSSIBLE_INT8_KERNELS,
|
||||
choose_scaled_mm_linear_kernel,
|
||||
_POSSIBLE_INT8_KERNELS
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
@ -20,7 +23,6 @@ from vllm.model_executor.parameter import (
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -58,8 +60,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
)
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config,
|
||||
_POSSIBLE_INT8_KERNELS
|
||||
scaled_mm_linear_kernel_config, _POSSIBLE_INT8_KERNELS
|
||||
)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
@ -94,8 +95,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
input_zero_point=None
|
||||
input_scale=None
|
||||
input_zero_point = None
|
||||
input_scale = None
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||
@ -113,11 +114,16 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
if not hasattr(layer, "azp_adj"):
|
||||
layer.register_parameter("azp_adj", None)
|
||||
|
||||
layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
|
||||
layer_param_names = [
|
||||
"weight",
|
||||
"weight_scale",
|
||||
"input_scale",
|
||||
"input_zero_point",
|
||||
"azp_adj",
|
||||
]
|
||||
|
||||
self.kernel = kernel_type(
|
||||
c=scaled_mm_linear_kernel_config,
|
||||
layer_param_names = layer_param_names
|
||||
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
|
||||
)
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
|
||||
@ -2,15 +2,16 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Generic, Sequence, TypeVar
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
|
||||
|
||||
class ScaledMMLinearQuantStrategy(Enum):
|
||||
@ -18,21 +19,24 @@ class ScaledMMLinearQuantStrategy(Enum):
|
||||
CHANNEL = "channel"
|
||||
BLOCK = "block"
|
||||
|
||||
|
||||
QUANT_STRATEGY_MAP = {
|
||||
QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR,
|
||||
QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL,
|
||||
QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.BLOCK,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScaledMMLinearLayerConfig:
|
||||
is_static_input_scheme: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
is_channelwise: bool
|
||||
input_symmetric: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
weight_quant_strategy: ScaledMMLinearQuantStrategy
|
||||
@ -40,22 +44,22 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
out_dtype: torch.dtype
|
||||
|
||||
|
||||
|
||||
Int8ParamsT = tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
]
|
||||
FP8ParamsT = tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
torch.Tensor | None, # input_zp
|
||||
torch.Tensor | None, # azp_adj
|
||||
]
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
]
|
||||
Int8ParamsT = tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
torch.Tensor | None, # input_zp
|
||||
torch.Tensor | None, # azp_adj
|
||||
]
|
||||
|
||||
ParamsT = TypeVar("ParamsT", Int8ParamsT, FP8ParamsT)
|
||||
ConfigT = TypeVar("ConfigT", bound=ScaledMMLinearLayerConfig)
|
||||
|
||||
ParamsT = TypeVar('ParamsT', Int8ParamsT, FP8ParamsT)
|
||||
ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig)
|
||||
|
||||
class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
|
||||
@classmethod
|
||||
@ -68,9 +72,7 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
|
||||
def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
self, c: ConfigT, layer_param_names: Sequence[str]
|
||||
) -> None:
|
||||
def __init__(self, c: ConfigT, layer_param_names: Sequence[str]) -> None:
|
||||
assert self.can_implement(c)
|
||||
self.config = c
|
||||
self.layer_param_names = layer_param_names
|
||||
@ -87,16 +89,18 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# return a covariant type in the subclass
|
||||
@abstractmethod
|
||||
def _get_layer_params(self, layer) -> ParamsT:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC):
|
||||
class FP8ScaledMMLinearKernel(
|
||||
ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC
|
||||
):
|
||||
def __init__(
|
||||
self, c: ConfigT, layer_param_names: Sequence[str]
|
||||
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
|
||||
) -> None:
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=c.is_static_input_scheme,
|
||||
@ -104,7 +108,7 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig,
|
||||
num_token_padding=self.get_ouput_padding(),
|
||||
)
|
||||
super().__init__(c, layer_param_names)
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
raise NotImplementedError
|
||||
@ -113,7 +117,7 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig,
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
@ -126,7 +130,9 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
class Int8ScaledMMLinearKernel(ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC):
|
||||
class Int8ScaledMMLinearKernel(
|
||||
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC
|
||||
):
|
||||
def _get_layer_params(self, layer) -> Int8ParamsT:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
|
||||
return (
|
||||
|
||||
@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
|
||||
ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
|
||||
ChannelWiseTorchScaledMMLinearKernel,
|
||||
PerTensorTorchScaledMMLinearKernel,
|
||||
|
||||
@ -14,7 +14,10 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
|
||||
from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
||||
from .ScaledMMLinearKernel import (
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@ -49,9 +52,7 @@ class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# Transpose to [K, N] for convenience
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
|
||||
self.layer_param_names
|
||||
)
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
|
||||
weight = getattr(layer, w_q_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
|
||||
@ -2,22 +2,24 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearKernel, Int8ScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from .utils import apply_weights_fp8
|
||||
|
||||
|
||||
def cutlass_w8a8_scaled_mm_fp8(
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
@ -34,6 +36,7 @@ def cutlass_w8a8_scaled_mm_fp8(
|
||||
)
|
||||
return output.view(*output_shape)
|
||||
|
||||
|
||||
class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -47,9 +50,7 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
|
||||
self.layer_param_names
|
||||
)
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
|
||||
config = self.config
|
||||
# WEIGHT
|
||||
# Cutlass kernels need transposed weight.
|
||||
@ -105,7 +106,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||
)
|
||||
|
||||
|
||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||
# It does not depend on scales or azp, so it is the same for
|
||||
# static and dynamic quantization.
|
||||
@ -124,7 +124,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||
)
|
||||
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -161,7 +160,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
|
||||
|
||||
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
return None
|
||||
|
||||
@ -191,5 +189,5 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
self.config.out_dtype
|
||||
)
|
||||
self.config.out_dtype,
|
||||
)
|
||||
|
||||
@ -1,17 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from .utils import apply_weights_fp8
|
||||
@ -32,7 +29,6 @@ def flashinfer_w8a8_scaled_mm(
|
||||
|
||||
|
||||
class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
return None
|
||||
|
||||
@ -41,7 +37,7 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
return 100
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||
per_tensor_weight_scales = (
|
||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||
@ -90,5 +86,5 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
self.config.out_dtype
|
||||
)
|
||||
self.config.out_dtype,
|
||||
)
|
||||
|
||||
@ -1,13 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
@ -18,6 +15,7 @@ from .ScaledMMLinearKernel import (
|
||||
)
|
||||
from .utils import apply_weights_fp8
|
||||
|
||||
|
||||
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
@ -40,7 +38,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||
current_platform.get_cu_count(),
|
||||
bias,
|
||||
)
|
||||
# Fallabck
|
||||
# Fallback
|
||||
else:
|
||||
output = torch._scaled_mm(
|
||||
A,
|
||||
@ -143,5 +141,5 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
self.config.out_dtype
|
||||
self.config.out_dtype,
|
||||
)
|
||||
|
||||
@ -1,13 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.config import CompilationMode, get_current_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
@ -15,8 +12,8 @@ from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
|
||||
from .utils import apply_weights_fp8
|
||||
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||
TORCH_DEVICE_IDENTITY = None
|
||||
@ -142,6 +139,7 @@ class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
output_padding = 17 if pad_output else None
|
||||
return output_padding
|
||||
|
||||
|
||||
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
@ -173,9 +171,10 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
self.config.out_dtype
|
||||
self.config.out_dtype,
|
||||
)
|
||||
|
||||
|
||||
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -199,7 +198,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
return (
|
||||
False,
|
||||
"RowWiseTorchScaledMMLinearKernel is only supported "
|
||||
+ "in ROCm platforms.",
|
||||
+ "on ROCm platforms.",
|
||||
)
|
||||
|
||||
if not version.parse(torch.__version__) >= version.parse("2.7"):
|
||||
@ -225,7 +224,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
self.config.out_dtype
|
||||
self.config.out_dtype,
|
||||
)
|
||||
|
||||
|
||||
@ -265,5 +264,5 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
self.config.out_dtype
|
||||
self.config.out_dtype,
|
||||
)
|
||||
|
||||
@ -1,20 +1,25 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
import torch
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
FP8ScaledMMCallBack = Callable[..., torch.Tensor]
|
||||
FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]]
|
||||
|
||||
|
||||
def apply_weights_fp8(
|
||||
scaled_mm_func: FP8ScaledMMCallBack,
|
||||
quant_fp8_func: FP8QuantCallback,
|
||||
w:torch.Tensor,
|
||||
x:torch.Tensor,
|
||||
w_s:torch.Tensor,
|
||||
x_s:torch.Tensor,
|
||||
bias:torch.Tensor,
|
||||
maybe_out_dtype: torch.dtype | None,
|
||||
) -> torch.Tensor:
|
||||
scaled_mm_func: FP8ScaledMMCallBack,
|
||||
quant_fp8_func: FP8QuantCallback,
|
||||
w: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_s: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
maybe_out_dtype: torch.dtype | None,
|
||||
) -> torch.Tensor:
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_s computed from x.
|
||||
# If static, layer.input_scale is scalar and x_s is input_scale.
|
||||
|
||||
@ -12,7 +12,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
||||
from .ScaledMMLinearKernel import (
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@ -42,9 +45,7 @@ class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# [out, in] (different than cutlass_scaled_mm)
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
|
||||
self.layer_param_names
|
||||
)
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
|
||||
weight = getattr(layer, w_q_name)
|
||||
replace_parameter(
|
||||
layer, w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
|
||||
|
||||
@ -7,10 +7,18 @@ from typing import Any, cast
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
_POSSIBLE_FP8_KERNELS,
|
||||
choose_scaled_mm_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
@ -23,8 +31,17 @@ from vllm.platforms import current_platform
|
||||
|
||||
__all__ = ["QuarkW8A8Fp8"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
QUANT_STRATEGY_MAP = {
|
||||
"per_tensor": ScaledMMLinearQuantStrategy.TENSOR,
|
||||
"per_channel": ScaledMMLinearQuantStrategy.CHANNEL,
|
||||
}
|
||||
|
||||
|
||||
class QuarkW8A8Fp8(QuarkScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(
|
||||
self, weight_config: dict[str, Any], input_config: dict[str, Any] | None
|
||||
):
|
||||
@ -41,10 +58,6 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
self.act_quant_group_shape = (
|
||||
GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
|
||||
)
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.is_static_input_scheme,
|
||||
act_quant_group_shape=self.act_quant_group_shape,
|
||||
)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
@classmethod
|
||||
@ -163,17 +176,32 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
input_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
layer_param_names = ["weight", "weight_scale", "input_scale"]
|
||||
weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme]
|
||||
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
||||
is_static_input_scheme=self.is_static_input_scheme,
|
||||
weight_quant_strategy=weight_quant_strategy,
|
||||
activation_group_shape=self.act_quant_group_shape,
|
||||
out_dtype=self.out_dtype,
|
||||
)
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config,
|
||||
_POSSIBLE_FP8_KERNELS,
|
||||
)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for QuarkW8A8FP8", kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
layer_param_names = ["weight", "weight_scale", "input_scale"]
|
||||
self.kernel = kernel_type(
|
||||
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
|
||||
)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
|
||||
@ -7,9 +7,12 @@ import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
_POSSIBLE_INT8_KERNELS,
|
||||
choose_scaled_mm_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
@ -56,7 +59,9 @@ class QuarkW8A8Int8(QuarkScheme):
|
||||
input_symmetric=(self.input_symmetric is True),
|
||||
)
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config, possible_kernels=_POSSIBLE_INT8_KERNELS
|
||||
)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
|
||||
@ -102,8 +107,8 @@ class QuarkW8A8Int8(QuarkScheme):
|
||||
layer.register_parameter("weight_zero_point", weight_zero_point)
|
||||
|
||||
# INPUT SCALE
|
||||
input_zero_point=None
|
||||
input_scale=None
|
||||
input_zero_point = None
|
||||
input_scale = None
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||
@ -117,12 +122,17 @@ class QuarkW8A8Int8(QuarkScheme):
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
if not hasattr(layer, "azp_adj"):
|
||||
layer.register_parameter("azp_adj", None)
|
||||
|
||||
layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
|
||||
|
||||
layer_param_names = [
|
||||
"weight",
|
||||
"weight_scale",
|
||||
"input_scale",
|
||||
"input_zero_point",
|
||||
"azp_adj",
|
||||
]
|
||||
|
||||
self.kernel = kernel_type(
|
||||
c=scaled_mm_linear_kernel_config,
|
||||
layer_param_names = layer_param_names
|
||||
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
|
||||
)
|
||||
|
||||
# Checkpoints are serialized in quark format, which is
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user