mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 07:07:05 +08:00
clean up; fix quark path
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
e54e572085
commit
c05027f67a
@ -6,6 +6,7 @@ from collections.abc import Callable
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
|
||||
from torch.nn import Parameter
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
@ -17,6 +18,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
QUANT_STRATEGY_MAP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
@ -49,8 +51,11 @@ strategy_to_parameter_type = {
|
||||
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
|
||||
}
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
|
||||
self.weight_quant = weight_quant
|
||||
self.strategy = weight_quant.strategy
|
||||
@ -79,19 +84,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
else:
|
||||
param_name_list = ["weight", "weight_scale", "input_scale"]
|
||||
layer_mapping_function = lambda layer: (
|
||||
tuple(getattr(layer, param_name) for param_name in param_name_list),
|
||||
param_name_list,
|
||||
)
|
||||
|
||||
# TODO: clean up
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
|
||||
elif self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
|
||||
|
||||
layer_param_names = ["weight", "weight_scale", "input_scale"]
|
||||
weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy]
|
||||
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
||||
is_static_input_scheme=self.is_static_input_scheme,
|
||||
weight_quant_strategy=weight_quant_strategy,
|
||||
activation_group_shape=self.act_q_group_shape,
|
||||
out_dtype=self.out_dtype,
|
||||
@ -101,9 +97,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
_POSSIBLE_FP8_KERNELS,
|
||||
)
|
||||
self.fp8_linear = kernel(
|
||||
scaled_mm_linear_kernel_config, layer_mapping_function
|
||||
scaled_mm_linear_kernel_config, layer_param_names = layer_param_names
|
||||
)
|
||||
|
||||
if kernel.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for CompressedTensorsW8A8FP8", kernel.__name__)
|
||||
self._kernel_backends_being_used.add(kernel.__name__)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
|
||||
@ -113,15 +113,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
if not hasattr(layer, "azp_adj"):
|
||||
layer.register_parameter("azp_adj", None)
|
||||
|
||||
param_name_list = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
|
||||
layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
|
||||
|
||||
layer_mapping_function = lambda layer: (
|
||||
tuple(getattr(layer, param_name) for param_name in param_name_list),
|
||||
param_name_list,
|
||||
)
|
||||
self.kernel = kernel_type(
|
||||
c=scaled_mm_linear_kernel_config,
|
||||
layer_mapping_function = layer_mapping_function
|
||||
layer_param_names = layer_param_names
|
||||
)
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
|
||||
@ -5,10 +5,12 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Generic, TypeVar
|
||||
from typing import Generic, Sequence, TypeVar
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
|
||||
|
||||
class ScaledMMLinearQuantStrategy(Enum):
|
||||
@ -16,21 +18,19 @@ class ScaledMMLinearQuantStrategy(Enum):
|
||||
CHANNEL = "channel"
|
||||
BLOCK = "block"
|
||||
|
||||
def is_per_token(self) -> bool:
|
||||
return self.row == 1 and self.col == -1
|
||||
|
||||
def is_per_group(self) -> bool:
|
||||
return self.row == 1 and self.col >= 1
|
||||
|
||||
QUANT_STRATEGY_MAP = {
|
||||
QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR,
|
||||
QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL,
|
||||
QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.BLOCK,
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class ScaledMMLinearLayerConfig:
|
||||
pass
|
||||
is_static_input_scheme: bool
|
||||
|
||||
@dataclass
|
||||
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
is_channelwise: bool
|
||||
is_static_input_scheme: bool
|
||||
input_symmetric: bool
|
||||
|
||||
@dataclass
|
||||
@ -40,10 +40,24 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
out_dtype: torch.dtype
|
||||
|
||||
|
||||
|
||||
Int8ParamsT = tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
]
|
||||
FP8ParamsT = tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
torch.Tensor | None, # input_zp
|
||||
torch.Tensor | None, # azp_adj
|
||||
]
|
||||
|
||||
ParamsT = TypeVar('ParamsT', Int8ParamsT, FP8ParamsT)
|
||||
ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig)
|
||||
|
||||
|
||||
class ScaledMMLinearKernel(Generic[ConfigT], ABC):
|
||||
class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -55,11 +69,11 @@ class ScaledMMLinearKernel(Generic[ConfigT], ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(
|
||||
self, c: ConfigT, layer_mapping_function: Callable
|
||||
self, c: ConfigT, layer_param_names: Sequence[str]
|
||||
) -> None:
|
||||
assert self.can_implement(c)
|
||||
self.config = c
|
||||
self.layer_mapping_function = layer_mapping_function
|
||||
self.layer_param_names = layer_param_names
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
@ -72,4 +86,53 @@ class ScaledMMLinearKernel(Generic[ConfigT], ABC):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError
|
||||
|
||||
# return a covariant type in the subclass
|
||||
@abstractmethod
|
||||
def _get_layer_params(self, layer) -> ParamsT:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC):
|
||||
def __init__(
|
||||
self, c: ConfigT, layer_param_names: Sequence[str]
|
||||
) -> None:
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=c.is_static_input_scheme,
|
||||
group_shape=c.activation_group_shape,
|
||||
num_token_padding=self.get_ouput_padding(),
|
||||
)
|
||||
super().__init__(c, layer_param_names)
|
||||
|
||||
@abstractmethod
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def _get_layer_params(self, layer) -> FP8ParamsT:
|
||||
w, w_s, x_s = self.layer_param_names
|
||||
return (
|
||||
getattr(layer, w),
|
||||
getattr(layer, w_s),
|
||||
getattr(layer, x_s),
|
||||
)
|
||||
|
||||
|
||||
class Int8ScaledMMLinearKernel(ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC):
|
||||
def _get_layer_params(self, layer) -> Int8ParamsT:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
|
||||
return (
|
||||
getattr(layer, w_q),
|
||||
getattr(layer, w_s),
|
||||
getattr(layer, i_s),
|
||||
getattr(layer, i_zp),
|
||||
getattr(layer, azp_adj),
|
||||
)
|
||||
|
||||
@ -9,8 +9,8 @@ from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .cutlass import process_weights_after_loading
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
||||
from .cutlass import CutlassScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_impl(
|
||||
@ -52,7 +52,7 @@ if current_platform.is_rocm():
|
||||
)
|
||||
|
||||
|
||||
class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 90
|
||||
@ -91,11 +91,6 @@ class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
)
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
_, param_names = self.layer_mapping_function(layer)
|
||||
|
||||
process_weights_after_loading(self.config, layer, *param_names)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -112,7 +107,7 @@ class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
|
||||
ATIER block scaled GEMM and mix-precision GEMM.
|
||||
"""
|
||||
(w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer)
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
|
||||
@ -14,10 +14,10 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
||||
from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
@ -30,7 +30,8 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
w_q_name, _, _, _, _ = self.layer_param_names
|
||||
weight = getattr(layer, w_q_name)
|
||||
dtype = weight.dtype
|
||||
N, K = weight.size()
|
||||
if (
|
||||
@ -48,10 +49,13 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# Transpose to [K, N] for convenience
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
|
||||
self.layer_param_names
|
||||
)
|
||||
weight = getattr(layer, w_q_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_q_name,
|
||||
w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||
)
|
||||
|
||||
@ -60,28 +64,27 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_s_name,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, self.i_s_name)
|
||||
input_scale = getattr(layer, i_s_name)
|
||||
|
||||
if self.config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.i_s_name,
|
||||
i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||
)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
else:
|
||||
input_zero_point = getattr(layer, self.i_zp_name)
|
||||
input_zero_point = getattr(layer, i_zp_name)
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
@ -91,20 +94,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||
replace_parameter(
|
||||
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||
)
|
||||
|
||||
azp = (
|
||||
(int8_traits.min - range_min / scale).round().to(dtype=torch.int32)
|
||||
)
|
||||
replace_parameter(
|
||||
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||
)
|
||||
|
||||
else:
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
|
||||
# Different from cutlass, oneDNN kernels only need the AZP adjustment
|
||||
# term for dynamic quantization. And s_b should be folded into the
|
||||
# term. Such as:
|
||||
@ -112,38 +111,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
|
||||
# s_a * GEMM_output - s_a * zp_a * adj + bias
|
||||
if not (self.config.input_symmetric and self.config.is_static_input_scheme):
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
weight = getattr(layer, w_q_name)
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
|
||||
azp_adj = azp_adj * weight_scale.squeeze()
|
||||
setattr(
|
||||
layer,
|
||||
self.azp_adj_name,
|
||||
azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||
)
|
||||
else:
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
weight = getattr(layer, w_q_name)
|
||||
self.dnnl_handler = ops.create_onednn_scaled_mm(
|
||||
weight,
|
||||
getattr(layer, self.w_s_name),
|
||||
getattr(layer, w_s_name),
|
||||
torch.get_default_dtype(),
|
||||
getattr(layer, self.i_s_name) is None,
|
||||
getattr(layer, i_s_name) is None,
|
||||
not self.config.input_symmetric,
|
||||
32,
|
||||
)
|
||||
# weight is prepacked and maintained by the dnnl_handler,
|
||||
# release the original weight
|
||||
setattr(layer, self.w_q_name, None)
|
||||
setattr(layer, w_q_name, None)
|
||||
del weight
|
||||
|
||||
def process_weights_for_sgl(self, layer: torch.nn.Module) -> None:
|
||||
w_q_name, w_s_name, _, _, _ = self.layer_param_names
|
||||
# WEIGHT
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
weight = getattr(layer, w_q_name)
|
||||
packed_weight = torch.ops._C.convert_weight_packed(weight)
|
||||
replace_parameter(
|
||||
layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
|
||||
layer, w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
|
||||
)
|
||||
|
||||
if layer.bias is not None:
|
||||
@ -155,19 +153,15 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
# WEIGHT SCALE
|
||||
# CPU SGL kernels only support per-channel.
|
||||
# For per-tensor quant, convert to the per-channel case.
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
if not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_s_name,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -186,7 +180,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
@ -208,7 +202,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, _, _, _ = self._get_weight_params(layer)
|
||||
w_q, w_s, _, _, _ = self._get_layer_params(layer)
|
||||
return torch.ops._C.int8_scaled_mm_with_quant(
|
||||
x,
|
||||
w_q,
|
||||
|
||||
@ -15,10 +15,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig
|
||||
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearKernel, Int8ScaledMMLinearKernel
|
||||
from .utils import apply_weights_fp8
|
||||
|
||||
|
||||
def cutlass_w8a8_scaled_mm(
|
||||
def cutlass_w8a8_scaled_mm_fp8(
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
@ -34,91 +34,7 @@ def cutlass_w8a8_scaled_mm(
|
||||
)
|
||||
return output.view(*output_shape)
|
||||
|
||||
|
||||
def process_weights_after_loading(
|
||||
config: Int8ScaledMMLinearLayerConfig,
|
||||
layer: torch.nn.Module,
|
||||
w_q_name: str,
|
||||
w_s_name: str,
|
||||
i_s_name: str,
|
||||
i_zp_name: str,
|
||||
azp_adj_name: str,
|
||||
):
|
||||
# WEIGHT
|
||||
# Cutlass kernels need transposed weight.
|
||||
weight = getattr(layer, w_q_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||
)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# Cutlass kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
if is_fused_module and not config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
# INPUT SCALE
|
||||
if config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, i_s_name)
|
||||
|
||||
if config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer,
|
||||
i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||
)
|
||||
setattr(layer, i_zp_name, None)
|
||||
else:
|
||||
input_zero_point = getattr(layer, i_zp_name)
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
azps = input_zero_point.to(dtype=torch.int32)
|
||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||
replace_parameter(
|
||||
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||
)
|
||||
|
||||
# AZP loaded as int8 but used as int32
|
||||
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
||||
replace_parameter(
|
||||
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||
)
|
||||
|
||||
|
||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||
# It does not depend on scales or azp, so it is the same for
|
||||
# static and dynamic quantization.
|
||||
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||
if not config.input_symmetric:
|
||||
weight = getattr(layer, w_q_name)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
if config.is_static_input_scheme:
|
||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
||||
# in the per-tensor case
|
||||
azp_adj = getattr(layer, i_zp_name) * azp_adj
|
||||
setattr(
|
||||
layer,
|
||||
azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||
)
|
||||
|
||||
|
||||
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
@ -131,9 +47,83 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
_, param_names = self.layer_mapping_function(layer)
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
|
||||
self.layer_param_names
|
||||
)
|
||||
config = self.config
|
||||
# WEIGHT
|
||||
# Cutlass kernels need transposed weight.
|
||||
weight = getattr(layer, w_q_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||
)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# Cutlass kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
if is_fused_module and not config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
# INPUT SCALE
|
||||
if config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, i_s_name)
|
||||
|
||||
if config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer,
|
||||
i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||
)
|
||||
setattr(layer, i_zp_name, None)
|
||||
else:
|
||||
input_zero_point = getattr(layer, i_zp_name)
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
azps = input_zero_point.to(dtype=torch.int32)
|
||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||
replace_parameter(
|
||||
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||
)
|
||||
|
||||
# AZP loaded as int8 but used as int32
|
||||
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
||||
replace_parameter(
|
||||
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||
)
|
||||
|
||||
|
||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||
# It does not depend on scales or azp, so it is the same for
|
||||
# static and dynamic quantization.
|
||||
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||
if not config.input_symmetric:
|
||||
weight = getattr(layer, w_q_name)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
if config.is_static_input_scheme:
|
||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
||||
# in the per-tensor case
|
||||
azp_adj = getattr(layer, i_zp_name) * azp_adj
|
||||
setattr(
|
||||
layer,
|
||||
azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||
)
|
||||
|
||||
process_weights_after_loading(self.config, layer, *param_names)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
@ -141,7 +131,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
(w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer)
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
@ -170,21 +160,10 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
)
|
||||
|
||||
|
||||
class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
def __init__(
|
||||
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||
) -> None:
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=c.is_static_input_scheme,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
num_token_padding=None,
|
||||
)
|
||||
super().__init__(c, layer_mapping_function)
|
||||
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
@ -197,41 +176,20 @@ class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||
# View input as 2D matrix for fp8 methods
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
out_dtype = self.config.out_dtype
|
||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
x_2d_q = x_2d
|
||||
if x.dtype != current_platform.fp8_dtype():
|
||||
x_2d_q, x_s = self.quant_fp8(
|
||||
x_2d,
|
||||
x_s,
|
||||
)
|
||||
|
||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||
|
||||
return cutlass_w8a8_scaled_mm(
|
||||
A=x_2d_q,
|
||||
B=w,
|
||||
out_dtype=out_dtype,
|
||||
As=x_s,
|
||||
Bs=w_s,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
)
|
||||
w, w_s, x_s = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
cutlass_w8a8_scaled_mm_fp8,
|
||||
self.quant_fp8,
|
||||
w,
|
||||
x,
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
self.config.out_dtype
|
||||
)
|
||||
@ -10,10 +10,11 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from .utils import apply_weights_fp8
|
||||
|
||||
|
||||
def flashinfer_w8a8_scaled_mm(
|
||||
@ -30,16 +31,10 @@ def flashinfer_w8a8_scaled_mm(
|
||||
)
|
||||
|
||||
|
||||
class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
def __init__(
|
||||
self, c: Int8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||
) -> None:
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=c.is_static_input_scheme,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
num_token_padding=None,
|
||||
)
|
||||
super().__init__(c, layer_mapping_function)
|
||||
class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -80,41 +75,20 @@ class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
)
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||
# View input as 2D matrix for fp8 methods
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
out_dtype = self.config.out_dtype
|
||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
x_2d_q = x_2d
|
||||
if x.dtype != current_platform.fp8_dtype():
|
||||
x_2d_q, x_s = self.quant_fp8(
|
||||
x_2d,
|
||||
x_s,
|
||||
)
|
||||
|
||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||
|
||||
return flashinfer_w8a8_scaled_mm(
|
||||
A=x_2d_q,
|
||||
B=w,
|
||||
out_dtype=out_dtype,
|
||||
As=x_s,
|
||||
Bs=w_s,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
)
|
||||
w, w_s, x_s = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
flashinfer_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
w,
|
||||
x,
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
self.config.out_dtype
|
||||
)
|
||||
@ -12,11 +12,11 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
|
||||
from .utils import apply_weights_fp8
|
||||
|
||||
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||
A: torch.Tensor,
|
||||
@ -88,20 +88,9 @@ if current_platform.is_rocm():
|
||||
)
|
||||
|
||||
|
||||
class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
def __init__(
|
||||
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||
) -> None:
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=c.is_static_input_scheme,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
num_token_padding=None,
|
||||
)
|
||||
super().__init__(c, layer_mapping_function)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 90
|
||||
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
@ -128,7 +117,7 @@ class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
return (
|
||||
False,
|
||||
"VLLM_ROCM_USE_SKINNY_GEMM must be enabled "
|
||||
+ "to use ROCmScaledMMLinearKernel ",
|
||||
+ "to use ROCmScaledMMLinearKernel.",
|
||||
)
|
||||
|
||||
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
||||
@ -139,41 +128,20 @@ class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
)
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||
# View input as 2D matrix for fp8 methods
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
out_dtype = self.config.out_dtype
|
||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
x_2d_q = x_2d
|
||||
if x.dtype != current_platform.fp8_dtype():
|
||||
x_2d_q, x_s = self.quant_fp8(
|
||||
x_2d,
|
||||
x_s,
|
||||
)
|
||||
|
||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||
|
||||
return rocm_per_tensor_float_w8a8_scaled_mm(
|
||||
A=x_2d_q,
|
||||
B=w,
|
||||
out_dtype=out_dtype,
|
||||
As=x_s,
|
||||
Bs=w_s,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
w, w_s, x_s = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
rocm_per_tensor_float_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
w,
|
||||
x,
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
self.config.out_dtype
|
||||
)
|
||||
|
||||
@ -11,11 +11,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
|
||||
from .utils import apply_weights_fp8
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||
TORCH_DEVICE_IDENTITY = None
|
||||
@ -134,35 +135,16 @@ def torch_channelwise_w8a8_scaled_mm(
|
||||
return output.to(out_dtype).view(*output_shape)
|
||||
|
||||
|
||||
class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
def __init__(
|
||||
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||
) -> None:
|
||||
class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
def get_ouput_padding(self) -> int | None:
|
||||
vllm_config = get_current_vllm_config().compilation_config
|
||||
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
|
||||
|
||||
output_padding = 17 if pad_output else None
|
||||
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=c.is_static_input_scheme,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
num_token_padding=output_padding,
|
||||
)
|
||||
super().__init__(c, layer_mapping_function)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
return
|
||||
|
||||
return output_padding
|
||||
|
||||
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
assert c.activation_group_shape is not None
|
||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||
per_tensor_weight_scales = (
|
||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||
@ -182,36 +164,18 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||
# View input as 2D matrix for fp8 methods
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
out_dtype = self.config.out_dtype
|
||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
x_2d_q = x_2d
|
||||
if x.dtype != current_platform.fp8_dtype():
|
||||
x_2d_q, x_s = self.quant_fp8(
|
||||
x_2d,
|
||||
x_s,
|
||||
)
|
||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||
return torch_per_tensor_w8a8_scaled_mm(
|
||||
A=x_2d_q,
|
||||
B=w,
|
||||
out_dtype=out_dtype,
|
||||
As=x_s,
|
||||
Bs=w_s,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
w, w_s, x_s = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
torch_per_tensor_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
w,
|
||||
x,
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
self.config.out_dtype
|
||||
)
|
||||
|
||||
|
||||
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -219,14 +183,12 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
assert c.activation_group_shape is not None
|
||||
|
||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||
per_tensor_weight_scales = (
|
||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||
)
|
||||
|
||||
if per_tensor_activation_scales and per_tensor_weight_scales:
|
||||
if per_tensor_activation_scales or per_tensor_weight_scales:
|
||||
return (
|
||||
False,
|
||||
"RowWiseTorchScaledMMLinearKernel cannot be used with "
|
||||
@ -254,33 +216,16 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||
# View input as 2D matrix for fp8 methods
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
out_dtype = self.config.out_dtype
|
||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
x_2d_q = x_2d
|
||||
if x.dtype != current_platform.fp8_dtype():
|
||||
x_2d_q, x_s = self.quant_fp8(
|
||||
x_2d,
|
||||
x_s,
|
||||
)
|
||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||
return torch_row_wise_w8a8_scaled_mm(
|
||||
A=x_2d_q,
|
||||
B=w,
|
||||
out_dtype=out_dtype,
|
||||
As=x_s,
|
||||
Bs=w_s,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
w, w_s, x_s = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
torch_row_wise_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
w,
|
||||
x,
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
self.config.out_dtype
|
||||
)
|
||||
|
||||
|
||||
@ -291,8 +236,6 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
assert c.activation_group_shape is not None
|
||||
|
||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||
per_tensor_weight_scales = (
|
||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||
@ -313,31 +256,14 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
||||
# View input as 2D matrix for fp8 methods
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
out_dtype = self.config.out_dtype
|
||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
||||
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
x_2d_q = x_2d
|
||||
if x.dtype != current_platform.fp8_dtype():
|
||||
x_2d_q, x_s = self.quant_fp8(
|
||||
x_2d,
|
||||
x_s,
|
||||
)
|
||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
||||
return torch_channelwise_w8a8_scaled_mm(
|
||||
A=x_2d_q,
|
||||
B=w,
|
||||
out_dtype=out_dtype,
|
||||
As=x_s,
|
||||
Bs=w_s,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
w, w_s, x_s = self._get_layer_params(layer)
|
||||
return apply_weights_fp8(
|
||||
torch_channelwise_w8a8_scaled_mm,
|
||||
self.quant_fp8,
|
||||
w,
|
||||
x,
|
||||
w_s,
|
||||
x_s,
|
||||
bias,
|
||||
self.config.out_dtype
|
||||
)
|
||||
|
||||
@ -0,0 +1,44 @@
|
||||
from collections.abc import Callable
|
||||
import torch
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
FP8ScaledMMCallBack = Callable[..., torch.Tensor]
|
||||
FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]]
|
||||
|
||||
def apply_weights_fp8(
|
||||
scaled_mm_func: FP8ScaledMMCallBack,
|
||||
quant_fp8_func: FP8QuantCallback,
|
||||
w:torch.Tensor,
|
||||
x:torch.Tensor,
|
||||
w_s:torch.Tensor,
|
||||
x_s:torch.Tensor,
|
||||
bias:torch.Tensor,
|
||||
maybe_out_dtype: torch.dtype | None,
|
||||
) -> torch.Tensor:
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_s computed from x.
|
||||
# If static, layer.input_scale is scalar and x_s is input_scale.
|
||||
# View input as 2D matrix for fp8 methods
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
output_shape = [*x.shape[:-1], w.shape[1]]
|
||||
|
||||
out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
|
||||
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
x_2d_q = x_2d
|
||||
if x.dtype != current_platform.fp8_dtype():
|
||||
x_2d_q, x_s = quant_fp8_func(
|
||||
x_2d,
|
||||
x_s,
|
||||
)
|
||||
|
||||
return scaled_mm_func(
|
||||
A=x_2d_q,
|
||||
B=w,
|
||||
out_dtype=out_dtype,
|
||||
As=x_s,
|
||||
Bs=w_s,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
)
|
||||
@ -12,10 +12,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
||||
from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError(
|
||||
@ -42,9 +42,12 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# [out, in] (different than cutlass_scaled_mm)
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
|
||||
self.layer_param_names
|
||||
)
|
||||
weight = getattr(layer, w_q_name)
|
||||
replace_parameter(
|
||||
layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
|
||||
layer, w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
|
||||
)
|
||||
|
||||
# WEIGHT SCALE
|
||||
@ -52,7 +55,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, self.w_s_name)
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
|
||||
@ -60,14 +63,14 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
weight_scale = weight_scale.squeeze(-1)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_s_name,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
# Only support symmetric dynamic activation quantization.
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
setattr(layer, i_s_name, None)
|
||||
setattr(layer, i_zp_name, None)
|
||||
setattr(layer, azp_adj_name, None)
|
||||
|
||||
# Filter warning for cond usage in apply_weights. It is okay
|
||||
# to specialize the graph since bias is not dynamic.
|
||||
@ -88,7 +91,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, _, _, _ = self._get_weight_params(layer)
|
||||
w_q, w_s, _, _, _ = self._get_layer_params(layer)
|
||||
|
||||
# Required to register custom ops.
|
||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||
|
||||
@ -102,24 +102,27 @@ class QuarkW8A8Int8(QuarkScheme):
|
||||
layer.register_parameter("weight_zero_point", weight_zero_point)
|
||||
|
||||
# INPUT SCALE
|
||||
input_zero_point=None
|
||||
input_scale=None
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||
)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
input_zero_point = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
||||
)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
if not hasattr(layer, "azp_adj"):
|
||||
layer.register_parameter("azp_adj", None)
|
||||
|
||||
layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
|
||||
|
||||
self.kernel = kernel_type(
|
||||
c=scaled_mm_linear_kernel_config,
|
||||
w_q_param_name="weight",
|
||||
w_s_param_name="weight_scale",
|
||||
i_s_param_name="input_scale",
|
||||
i_zp_param_name="input_zero_point",
|
||||
azp_adj_param_name="azp_adj",
|
||||
layer_param_names = layer_param_names
|
||||
)
|
||||
|
||||
# Checkpoints are serialized in quark format, which is
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user