mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 20:01:20 +08:00
clean up; fix quark path
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
e54e572085
commit
c05027f67a
@ -6,6 +6,7 @@ from collections.abc import Callable
|
|||||||
import torch
|
import torch
|
||||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
|
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||||
CompressedTensorsScheme,
|
CompressedTensorsScheme,
|
||||||
@ -17,6 +18,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
|||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
|
||||||
FP8ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
ScaledMMLinearQuantStrategy,
|
ScaledMMLinearQuantStrategy,
|
||||||
|
QUANT_STRATEGY_MAP,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
W8A8BlockFp8LinearOp,
|
W8A8BlockFp8LinearOp,
|
||||||
@ -49,8 +51,11 @@ strategy_to_parameter_type = {
|
|||||||
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
|
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||||
|
_kernel_backends_being_used: set[str] = set()
|
||||||
|
|
||||||
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
|
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
|
||||||
self.weight_quant = weight_quant
|
self.weight_quant = weight_quant
|
||||||
self.strategy = weight_quant.strategy
|
self.strategy = weight_quant.strategy
|
||||||
@ -79,19 +84,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
param_name_list = ["weight", "weight_scale", "input_scale"]
|
layer_param_names = ["weight", "weight_scale", "input_scale"]
|
||||||
layer_mapping_function = lambda layer: (
|
weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy]
|
||||||
tuple(getattr(layer, param_name) for param_name in param_name_list),
|
|
||||||
param_name_list,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: clean up
|
|
||||||
if self.strategy == QuantizationStrategy.TENSOR:
|
|
||||||
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
|
|
||||||
elif self.strategy == QuantizationStrategy.CHANNEL:
|
|
||||||
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
|
|
||||||
|
|
||||||
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
||||||
|
is_static_input_scheme=self.is_static_input_scheme,
|
||||||
weight_quant_strategy=weight_quant_strategy,
|
weight_quant_strategy=weight_quant_strategy,
|
||||||
activation_group_shape=self.act_q_group_shape,
|
activation_group_shape=self.act_q_group_shape,
|
||||||
out_dtype=self.out_dtype,
|
out_dtype=self.out_dtype,
|
||||||
@ -101,9 +97,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
_POSSIBLE_FP8_KERNELS,
|
_POSSIBLE_FP8_KERNELS,
|
||||||
)
|
)
|
||||||
self.fp8_linear = kernel(
|
self.fp8_linear = kernel(
|
||||||
scaled_mm_linear_kernel_config, layer_mapping_function
|
scaled_mm_linear_kernel_config, layer_param_names = layer_param_names
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if kernel.__name__ not in self._kernel_backends_being_used:
|
||||||
|
logger.info("Using %s for CompressedTensorsW8A8FP8", kernel.__name__)
|
||||||
|
self._kernel_backends_being_used.add(kernel.__name__)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
# lovelace and up
|
# lovelace and up
|
||||||
|
|||||||
@ -113,15 +113,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
|||||||
if not hasattr(layer, "azp_adj"):
|
if not hasattr(layer, "azp_adj"):
|
||||||
layer.register_parameter("azp_adj", None)
|
layer.register_parameter("azp_adj", None)
|
||||||
|
|
||||||
param_name_list = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
|
layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
|
||||||
|
|
||||||
layer_mapping_function = lambda layer: (
|
|
||||||
tuple(getattr(layer, param_name) for param_name in param_name_list),
|
|
||||||
param_name_list,
|
|
||||||
)
|
|
||||||
self.kernel = kernel_type(
|
self.kernel = kernel_type(
|
||||||
c=scaled_mm_linear_kernel_config,
|
c=scaled_mm_linear_kernel_config,
|
||||||
layer_mapping_function = layer_mapping_function
|
layer_param_names = layer_param_names
|
||||||
)
|
)
|
||||||
|
|
||||||
# Checkpoints are serialized in compressed-tensors format, which is
|
# Checkpoints are serialized in compressed-tensors format, which is
|
||||||
|
|||||||
@ -5,10 +5,12 @@ from abc import ABC, abstractmethod
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Generic, TypeVar
|
from typing import Generic, Sequence, TypeVar
|
||||||
import torch
|
import torch
|
||||||
|
from compressed_tensors.quantization import QuantizationStrategy
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
|
|
||||||
|
|
||||||
class ScaledMMLinearQuantStrategy(Enum):
|
class ScaledMMLinearQuantStrategy(Enum):
|
||||||
@ -16,21 +18,19 @@ class ScaledMMLinearQuantStrategy(Enum):
|
|||||||
CHANNEL = "channel"
|
CHANNEL = "channel"
|
||||||
BLOCK = "block"
|
BLOCK = "block"
|
||||||
|
|
||||||
def is_per_token(self) -> bool:
|
QUANT_STRATEGY_MAP = {
|
||||||
return self.row == 1 and self.col == -1
|
QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR,
|
||||||
|
QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL,
|
||||||
def is_per_group(self) -> bool:
|
QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.BLOCK,
|
||||||
return self.row == 1 and self.col >= 1
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ScaledMMLinearLayerConfig:
|
class ScaledMMLinearLayerConfig:
|
||||||
pass
|
is_static_input_scheme: bool
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||||
is_channelwise: bool
|
is_channelwise: bool
|
||||||
is_static_input_scheme: bool
|
|
||||||
input_symmetric: bool
|
input_symmetric: bool
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -40,10 +40,24 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
|||||||
out_dtype: torch.dtype
|
out_dtype: torch.dtype
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Int8ParamsT = tuple[
|
||||||
|
torch.Tensor, # weight
|
||||||
|
torch.Tensor, # weight_scale
|
||||||
|
torch.Tensor | None, # input_scale,
|
||||||
|
]
|
||||||
|
FP8ParamsT = tuple[
|
||||||
|
torch.Tensor, # weight
|
||||||
|
torch.Tensor, # weight_scale
|
||||||
|
torch.Tensor | None, # input_scale,
|
||||||
|
torch.Tensor | None, # input_zp
|
||||||
|
torch.Tensor | None, # azp_adj
|
||||||
|
]
|
||||||
|
|
||||||
|
ParamsT = TypeVar('ParamsT', Int8ParamsT, FP8ParamsT)
|
||||||
ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig)
|
ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig)
|
||||||
|
|
||||||
|
class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
|
||||||
class ScaledMMLinearKernel(Generic[ConfigT], ABC):
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
@ -55,11 +69,11 @@ class ScaledMMLinearKernel(Generic[ConfigT], ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, c: ConfigT, layer_mapping_function: Callable
|
self, c: ConfigT, layer_param_names: Sequence[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
assert self.can_implement(c)
|
assert self.can_implement(c)
|
||||||
self.config = c
|
self.config = c
|
||||||
self.layer_mapping_function = layer_mapping_function
|
self.layer_param_names = layer_param_names
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
@ -73,3 +87,52 @@ class ScaledMMLinearKernel(Generic[ConfigT], ABC):
|
|||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# return a covariant type in the subclass
|
||||||
|
@abstractmethod
|
||||||
|
def _get_layer_params(self, layer) -> ParamsT:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC):
|
||||||
|
def __init__(
|
||||||
|
self, c: ConfigT, layer_param_names: Sequence[str]
|
||||||
|
) -> None:
|
||||||
|
self.quant_fp8 = QuantFP8(
|
||||||
|
static=c.is_static_input_scheme,
|
||||||
|
group_shape=c.activation_group_shape,
|
||||||
|
num_token_padding=self.get_ouput_padding(),
|
||||||
|
)
|
||||||
|
super().__init__(c, layer_param_names)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_ouput_padding(self) -> int | None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
# lovelace and up
|
||||||
|
return 89
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_layer_params(self, layer) -> FP8ParamsT:
|
||||||
|
w, w_s, x_s = self.layer_param_names
|
||||||
|
return (
|
||||||
|
getattr(layer, w),
|
||||||
|
getattr(layer, w_s),
|
||||||
|
getattr(layer, x_s),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Int8ScaledMMLinearKernel(ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC):
|
||||||
|
def _get_layer_params(self, layer) -> Int8ParamsT:
|
||||||
|
w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
|
||||||
|
return (
|
||||||
|
getattr(layer, w_q),
|
||||||
|
getattr(layer, w_s),
|
||||||
|
getattr(layer, i_s),
|
||||||
|
getattr(layer, i_zp),
|
||||||
|
getattr(layer, azp_adj),
|
||||||
|
)
|
||||||
|
|||||||
@ -9,8 +9,8 @@ from vllm import _custom_ops as ops
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
from .cutlass import process_weights_after_loading
|
from .cutlass import CutlassScaledMMLinearKernel
|
||||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
|
||||||
|
|
||||||
|
|
||||||
def rocm_aiter_gemm_w8a8_impl(
|
def rocm_aiter_gemm_w8a8_impl(
|
||||||
@ -52,7 +52,7 @@ if current_platform.is_rocm():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
|
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
return 90
|
return 90
|
||||||
@ -91,11 +91,6 @@ class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
)
|
)
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
|
||||||
_, param_names = self.layer_mapping_function(layer)
|
|
||||||
|
|
||||||
process_weights_after_loading(self.config, layer, *param_names)
|
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -112,7 +107,7 @@ class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
|
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
|
||||||
ATIER block scaled GEMM and mix-precision GEMM.
|
ATIER block scaled GEMM and mix-precision GEMM.
|
||||||
"""
|
"""
|
||||||
(w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer)
|
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||||
|
|
||||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||||
# * dynamic, i_s is None and x_s computed from x.
|
# * dynamic, i_s is None and x_s computed from x.
|
||||||
|
|||||||
@ -14,10 +14,10 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import CpuArchEnum
|
from vllm.platforms.interface import CpuArchEnum
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
||||||
|
|
||||||
|
|
||||||
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
return 75
|
return 75
|
||||||
@ -30,7 +30,8 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
weight = getattr(layer, self.w_q_name)
|
w_q_name, _, _, _, _ = self.layer_param_names
|
||||||
|
weight = getattr(layer, w_q_name)
|
||||||
dtype = weight.dtype
|
dtype = weight.dtype
|
||||||
N, K = weight.size()
|
N, K = weight.size()
|
||||||
if (
|
if (
|
||||||
@ -48,10 +49,13 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
|
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
|
||||||
# WEIGHT
|
# WEIGHT
|
||||||
# Transpose to [K, N] for convenience
|
# Transpose to [K, N] for convenience
|
||||||
weight = getattr(layer, self.w_q_name)
|
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
|
||||||
|
self.layer_param_names
|
||||||
|
)
|
||||||
|
weight = getattr(layer, w_q_name)
|
||||||
replace_parameter(
|
replace_parameter(
|
||||||
layer,
|
layer,
|
||||||
self.w_q_name,
|
w_q_name,
|
||||||
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -60,28 +64,27 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||||
# scales being passed to the kernel), convert to the per-channel case.
|
# scales being passed to the kernel), convert to the per-channel case.
|
||||||
is_fused_module = len(layer.logical_widths) > 1
|
is_fused_module = len(layer.logical_widths) > 1
|
||||||
weight_scale = getattr(layer, self.w_s_name)
|
weight_scale = getattr(layer, w_s_name)
|
||||||
if is_fused_module and not self.config.is_channelwise:
|
if is_fused_module and not self.config.is_channelwise:
|
||||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||||
replace_parameter(
|
replace_parameter(
|
||||||
layer,
|
layer,
|
||||||
self.w_s_name,
|
w_s_name,
|
||||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
# INPUT SCALE
|
# INPUT SCALE
|
||||||
if self.config.is_static_input_scheme:
|
if self.config.is_static_input_scheme:
|
||||||
input_scale = getattr(layer, self.i_s_name)
|
input_scale = getattr(layer, i_s_name)
|
||||||
|
|
||||||
if self.config.input_symmetric:
|
if self.config.input_symmetric:
|
||||||
replace_parameter(
|
replace_parameter(
|
||||||
layer,
|
layer,
|
||||||
self.i_s_name,
|
i_s_name,
|
||||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||||
)
|
)
|
||||||
setattr(layer, self.i_zp_name, None)
|
|
||||||
else:
|
else:
|
||||||
input_zero_point = getattr(layer, self.i_zp_name)
|
input_zero_point = getattr(layer, i_zp_name)
|
||||||
|
|
||||||
# reconstruct the ranges
|
# reconstruct the ranges
|
||||||
int8_traits = torch.iinfo(torch.int8)
|
int8_traits = torch.iinfo(torch.int8)
|
||||||
@ -91,20 +94,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
|
|
||||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||||
replace_parameter(
|
replace_parameter(
|
||||||
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
azp = (
|
azp = (
|
||||||
(int8_traits.min - range_min / scale).round().to(dtype=torch.int32)
|
(int8_traits.min - range_min / scale).round().to(dtype=torch.int32)
|
||||||
)
|
)
|
||||||
replace_parameter(
|
replace_parameter(
|
||||||
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
|
||||||
setattr(layer, self.i_s_name, None)
|
|
||||||
setattr(layer, self.i_zp_name, None)
|
|
||||||
|
|
||||||
# Different from cutlass, oneDNN kernels only need the AZP adjustment
|
# Different from cutlass, oneDNN kernels only need the AZP adjustment
|
||||||
# term for dynamic quantization. And s_b should be folded into the
|
# term for dynamic quantization. And s_b should be folded into the
|
||||||
# term. Such as:
|
# term. Such as:
|
||||||
@ -112,38 +111,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
|
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
|
||||||
# s_a * GEMM_output - s_a * zp_a * adj + bias
|
# s_a * GEMM_output - s_a * zp_a * adj + bias
|
||||||
if not (self.config.input_symmetric and self.config.is_static_input_scheme):
|
if not (self.config.input_symmetric and self.config.is_static_input_scheme):
|
||||||
weight = getattr(layer, self.w_q_name)
|
weight = getattr(layer, w_q_name)
|
||||||
weight_scale = getattr(layer, self.w_s_name)
|
weight_scale = getattr(layer, w_s_name)
|
||||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
|
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
|
||||||
azp_adj = azp_adj * weight_scale.squeeze()
|
azp_adj = azp_adj * weight_scale.squeeze()
|
||||||
setattr(
|
setattr(
|
||||||
layer,
|
layer,
|
||||||
self.azp_adj_name,
|
azp_adj_name,
|
||||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
setattr(layer, self.azp_adj_name, None)
|
|
||||||
|
|
||||||
weight = getattr(layer, self.w_q_name)
|
weight = getattr(layer, w_q_name)
|
||||||
self.dnnl_handler = ops.create_onednn_scaled_mm(
|
self.dnnl_handler = ops.create_onednn_scaled_mm(
|
||||||
weight,
|
weight,
|
||||||
getattr(layer, self.w_s_name),
|
getattr(layer, w_s_name),
|
||||||
torch.get_default_dtype(),
|
torch.get_default_dtype(),
|
||||||
getattr(layer, self.i_s_name) is None,
|
getattr(layer, i_s_name) is None,
|
||||||
not self.config.input_symmetric,
|
not self.config.input_symmetric,
|
||||||
32,
|
32,
|
||||||
)
|
)
|
||||||
# weight is prepacked and maintained by the dnnl_handler,
|
# weight is prepacked and maintained by the dnnl_handler,
|
||||||
# release the original weight
|
# release the original weight
|
||||||
setattr(layer, self.w_q_name, None)
|
setattr(layer, w_q_name, None)
|
||||||
del weight
|
del weight
|
||||||
|
|
||||||
def process_weights_for_sgl(self, layer: torch.nn.Module) -> None:
|
def process_weights_for_sgl(self, layer: torch.nn.Module) -> None:
|
||||||
|
w_q_name, w_s_name, _, _, _ = self.layer_param_names
|
||||||
# WEIGHT
|
# WEIGHT
|
||||||
weight = getattr(layer, self.w_q_name)
|
weight = getattr(layer, w_q_name)
|
||||||
packed_weight = torch.ops._C.convert_weight_packed(weight)
|
packed_weight = torch.ops._C.convert_weight_packed(weight)
|
||||||
replace_parameter(
|
replace_parameter(
|
||||||
layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
|
layer, w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
if layer.bias is not None:
|
if layer.bias is not None:
|
||||||
@ -155,19 +153,15 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
# WEIGHT SCALE
|
# WEIGHT SCALE
|
||||||
# CPU SGL kernels only support per-channel.
|
# CPU SGL kernels only support per-channel.
|
||||||
# For per-tensor quant, convert to the per-channel case.
|
# For per-tensor quant, convert to the per-channel case.
|
||||||
weight_scale = getattr(layer, self.w_s_name)
|
weight_scale = getattr(layer, w_s_name)
|
||||||
if not self.config.is_channelwise:
|
if not self.config.is_channelwise:
|
||||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||||
replace_parameter(
|
replace_parameter(
|
||||||
layer,
|
layer,
|
||||||
self.w_s_name,
|
w_s_name,
|
||||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
setattr(layer, self.i_s_name, None)
|
|
||||||
setattr(layer, self.i_zp_name, None)
|
|
||||||
setattr(layer, self.azp_adj_name, None)
|
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -186,7 +180,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||||
|
|
||||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||||
# * dynamic, i_s is None and x_s computed from x.
|
# * dynamic, i_s is None and x_s computed from x.
|
||||||
@ -208,7 +202,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
w_q, w_s, _, _, _ = self._get_weight_params(layer)
|
w_q, w_s, _, _, _ = self._get_layer_params(layer)
|
||||||
return torch.ops._C.int8_scaled_mm_with_quant(
|
return torch.ops._C.int8_scaled_mm_with_quant(
|
||||||
x,
|
x,
|
||||||
w_q,
|
w_q,
|
||||||
|
|||||||
@ -15,10 +15,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig
|
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearKernel, Int8ScaledMMLinearKernel
|
||||||
|
from .utils import apply_weights_fp8
|
||||||
|
|
||||||
|
def cutlass_w8a8_scaled_mm_fp8(
|
||||||
def cutlass_w8a8_scaled_mm(
|
|
||||||
*,
|
*,
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
B: torch.Tensor,
|
B: torch.Tensor,
|
||||||
@ -34,91 +34,7 @@ def cutlass_w8a8_scaled_mm(
|
|||||||
)
|
)
|
||||||
return output.view(*output_shape)
|
return output.view(*output_shape)
|
||||||
|
|
||||||
|
class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||||
def process_weights_after_loading(
|
|
||||||
config: Int8ScaledMMLinearLayerConfig,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
w_q_name: str,
|
|
||||||
w_s_name: str,
|
|
||||||
i_s_name: str,
|
|
||||||
i_zp_name: str,
|
|
||||||
azp_adj_name: str,
|
|
||||||
):
|
|
||||||
# WEIGHT
|
|
||||||
# Cutlass kernels need transposed weight.
|
|
||||||
weight = getattr(layer, w_q_name)
|
|
||||||
replace_parameter(
|
|
||||||
layer,
|
|
||||||
w_q_name,
|
|
||||||
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# WEIGHT SCALE
|
|
||||||
# Cutlass kernels support only per-tensor and per-channel.
|
|
||||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
|
||||||
# scales being passed to the kernel), convert to the per-channel case.
|
|
||||||
is_fused_module = len(layer.logical_widths) > 1
|
|
||||||
weight_scale = getattr(layer, w_s_name)
|
|
||||||
if is_fused_module and not config.is_channelwise:
|
|
||||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
|
||||||
replace_parameter(
|
|
||||||
layer,
|
|
||||||
w_s_name,
|
|
||||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
# INPUT SCALE
|
|
||||||
if config.is_static_input_scheme:
|
|
||||||
input_scale = getattr(layer, i_s_name)
|
|
||||||
|
|
||||||
if config.input_symmetric:
|
|
||||||
replace_parameter(
|
|
||||||
layer,
|
|
||||||
i_s_name,
|
|
||||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
|
||||||
)
|
|
||||||
setattr(layer, i_zp_name, None)
|
|
||||||
else:
|
|
||||||
input_zero_point = getattr(layer, i_zp_name)
|
|
||||||
|
|
||||||
# reconstruct the ranges
|
|
||||||
int8_traits = torch.iinfo(torch.int8)
|
|
||||||
azps = input_zero_point.to(dtype=torch.int32)
|
|
||||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
|
||||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
|
||||||
|
|
||||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
|
||||||
replace_parameter(
|
|
||||||
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
# AZP loaded as int8 but used as int32
|
|
||||||
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
|
||||||
replace_parameter(
|
|
||||||
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
|
||||||
# It does not depend on scales or azp, so it is the same for
|
|
||||||
# static and dynamic quantization.
|
|
||||||
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
|
|
||||||
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
|
|
||||||
if not config.input_symmetric:
|
|
||||||
weight = getattr(layer, w_q_name)
|
|
||||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
|
||||||
if config.is_static_input_scheme:
|
|
||||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
|
||||||
# in the per-tensor case
|
|
||||||
azp_adj = getattr(layer, i_zp_name) * azp_adj
|
|
||||||
setattr(
|
|
||||||
layer,
|
|
||||||
azp_adj_name,
|
|
||||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
return 75
|
return 75
|
||||||
@ -131,9 +47,83 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
_, param_names = self.layer_mapping_function(layer)
|
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
|
||||||
|
self.layer_param_names
|
||||||
|
)
|
||||||
|
config = self.config
|
||||||
|
# WEIGHT
|
||||||
|
# Cutlass kernels need transposed weight.
|
||||||
|
weight = getattr(layer, w_q_name)
|
||||||
|
replace_parameter(
|
||||||
|
layer,
|
||||||
|
w_q_name,
|
||||||
|
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# WEIGHT SCALE
|
||||||
|
# Cutlass kernels support only per-tensor and per-channel.
|
||||||
|
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||||
|
# scales being passed to the kernel), convert to the per-channel case.
|
||||||
|
is_fused_module = len(layer.logical_widths) > 1
|
||||||
|
weight_scale = getattr(layer, w_s_name)
|
||||||
|
if is_fused_module and not config.is_channelwise:
|
||||||
|
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||||
|
replace_parameter(
|
||||||
|
layer,
|
||||||
|
w_s_name,
|
||||||
|
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# INPUT SCALE
|
||||||
|
if config.is_static_input_scheme:
|
||||||
|
input_scale = getattr(layer, i_s_name)
|
||||||
|
|
||||||
|
if config.input_symmetric:
|
||||||
|
replace_parameter(
|
||||||
|
layer,
|
||||||
|
i_s_name,
|
||||||
|
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||||
|
)
|
||||||
|
setattr(layer, i_zp_name, None)
|
||||||
|
else:
|
||||||
|
input_zero_point = getattr(layer, i_zp_name)
|
||||||
|
|
||||||
|
# reconstruct the ranges
|
||||||
|
int8_traits = torch.iinfo(torch.int8)
|
||||||
|
azps = input_zero_point.to(dtype=torch.int32)
|
||||||
|
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||||
|
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||||
|
|
||||||
|
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||||
|
replace_parameter(
|
||||||
|
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# AZP loaded as int8 but used as int32
|
||||||
|
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
||||||
|
replace_parameter(
|
||||||
|
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||||
|
# It does not depend on scales or azp, so it is the same for
|
||||||
|
# static and dynamic quantization.
|
||||||
|
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||||
|
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||||
|
if not config.input_symmetric:
|
||||||
|
weight = getattr(layer, w_q_name)
|
||||||
|
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||||
|
if config.is_static_input_scheme:
|
||||||
|
# cutlass_w8a8 requires azp to be folded into azp_adj
|
||||||
|
# in the per-tensor case
|
||||||
|
azp_adj = getattr(layer, i_zp_name) * azp_adj
|
||||||
|
setattr(
|
||||||
|
layer,
|
||||||
|
azp_adj_name,
|
||||||
|
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||||
|
)
|
||||||
|
|
||||||
process_weights_after_loading(self.config, layer, *param_names)
|
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
self,
|
self,
|
||||||
@ -141,7 +131,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
(w_q, w_s, i_s, i_zp, azp_adj), _ = self.layer_mapping_function(layer)
|
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||||
|
|
||||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||||
# * dynamic, i_s is None and x_s computed from x.
|
# * dynamic, i_s is None and x_s computed from x.
|
||||||
@ -170,21 +160,10 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
|
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||||
def __init__(
|
|
||||||
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
|
||||||
) -> None:
|
|
||||||
self.quant_fp8 = QuantFP8(
|
|
||||||
static=c.is_static_input_scheme,
|
|
||||||
group_shape=GroupShape.PER_TENSOR,
|
|
||||||
num_token_padding=None,
|
|
||||||
)
|
|
||||||
super().__init__(c, layer_mapping_function)
|
|
||||||
|
|
||||||
@classmethod
|
def get_ouput_padding(self) -> int | None:
|
||||||
def get_min_capability(cls) -> int:
|
return None
|
||||||
# lovelace and up
|
|
||||||
return 89
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
@ -197,41 +176,20 @@ class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
|
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
w, w_s, x_s = self._get_layer_params(layer)
|
||||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
return apply_weights_fp8(
|
||||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
cutlass_w8a8_scaled_mm_fp8,
|
||||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
self.quant_fp8,
|
||||||
# View input as 2D matrix for fp8 methods
|
w,
|
||||||
x_2d = x.view(-1, x.shape[-1])
|
x,
|
||||||
|
w_s,
|
||||||
out_dtype = self.config.out_dtype
|
x_s,
|
||||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
bias,
|
||||||
# If input not quantized
|
self.config.out_dtype
|
||||||
# TODO(luka) remove this path if not used anymore
|
|
||||||
x_2d_q = x_2d
|
|
||||||
if x.dtype != current_platform.fp8_dtype():
|
|
||||||
x_2d_q, x_s = self.quant_fp8(
|
|
||||||
x_2d,
|
|
||||||
x_s,
|
|
||||||
)
|
|
||||||
|
|
||||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
|
||||||
|
|
||||||
return cutlass_w8a8_scaled_mm(
|
|
||||||
A=x_2d_q,
|
|
||||||
B=w,
|
|
||||||
out_dtype=out_dtype,
|
|
||||||
As=x_s,
|
|
||||||
Bs=w_s,
|
|
||||||
bias=bias,
|
|
||||||
output_shape=output_shape,
|
|
||||||
)
|
)
|
||||||
@ -10,10 +10,11 @@ from vllm.platforms import current_platform
|
|||||||
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import (
|
from .ScaledMMLinearKernel import (
|
||||||
ScaledMMLinearKernel,
|
FP8ScaledMMLinearKernel,
|
||||||
Int8ScaledMMLinearLayerConfig,
|
Int8ScaledMMLinearLayerConfig,
|
||||||
ScaledMMLinearQuantStrategy,
|
ScaledMMLinearQuantStrategy,
|
||||||
)
|
)
|
||||||
|
from .utils import apply_weights_fp8
|
||||||
|
|
||||||
|
|
||||||
def flashinfer_w8a8_scaled_mm(
|
def flashinfer_w8a8_scaled_mm(
|
||||||
@ -30,16 +31,10 @@ def flashinfer_w8a8_scaled_mm(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
|
class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||||
def __init__(
|
|
||||||
self, c: Int8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
def get_ouput_padding(self) -> int | None:
|
||||||
) -> None:
|
return None
|
||||||
self.quant_fp8 = QuantFP8(
|
|
||||||
static=c.is_static_input_scheme,
|
|
||||||
group_shape=GroupShape.PER_TENSOR,
|
|
||||||
num_token_padding=None,
|
|
||||||
)
|
|
||||||
super().__init__(c, layer_mapping_function)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
@ -80,41 +75,20 @@ class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
)
|
)
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
w, w_s, x_s = self._get_layer_params(layer)
|
||||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
return apply_weights_fp8(
|
||||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
flashinfer_w8a8_scaled_mm,
|
||||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
self.quant_fp8,
|
||||||
# View input as 2D matrix for fp8 methods
|
w,
|
||||||
x_2d = x.view(-1, x.shape[-1])
|
x,
|
||||||
|
w_s,
|
||||||
out_dtype = self.config.out_dtype
|
x_s,
|
||||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
bias,
|
||||||
# If input not quantized
|
self.config.out_dtype
|
||||||
# TODO(luka) remove this path if not used anymore
|
|
||||||
x_2d_q = x_2d
|
|
||||||
if x.dtype != current_platform.fp8_dtype():
|
|
||||||
x_2d_q, x_s = self.quant_fp8(
|
|
||||||
x_2d,
|
|
||||||
x_s,
|
|
||||||
)
|
|
||||||
|
|
||||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
|
||||||
|
|
||||||
return flashinfer_w8a8_scaled_mm(
|
|
||||||
A=x_2d_q,
|
|
||||||
B=w,
|
|
||||||
out_dtype=out_dtype,
|
|
||||||
As=x_s,
|
|
||||||
Bs=w_s,
|
|
||||||
bias=bias,
|
|
||||||
output_shape=output_shape,
|
|
||||||
)
|
)
|
||||||
@ -12,11 +12,11 @@ from vllm.platforms import current_platform
|
|||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import (
|
from .ScaledMMLinearKernel import (
|
||||||
ScaledMMLinearKernel,
|
FP8ScaledMMLinearKernel,
|
||||||
FP8ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
ScaledMMLinearQuantStrategy,
|
ScaledMMLinearQuantStrategy,
|
||||||
)
|
)
|
||||||
|
from .utils import apply_weights_fp8
|
||||||
|
|
||||||
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
@ -88,20 +88,9 @@ if current_platform.is_rocm():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
|
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||||
def __init__(
|
def get_ouput_padding(self) -> int | None:
|
||||||
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
return None
|
||||||
) -> None:
|
|
||||||
self.quant_fp8 = QuantFP8(
|
|
||||||
static=c.is_static_input_scheme,
|
|
||||||
group_shape=GroupShape.PER_TENSOR,
|
|
||||||
num_token_padding=None,
|
|
||||||
)
|
|
||||||
super().__init__(c, layer_mapping_function)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_min_capability(cls) -> int:
|
|
||||||
return 90
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
@ -128,7 +117,7 @@ class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
return (
|
return (
|
||||||
False,
|
False,
|
||||||
"VLLM_ROCM_USE_SKINNY_GEMM must be enabled "
|
"VLLM_ROCM_USE_SKINNY_GEMM must be enabled "
|
||||||
+ "to use ROCmScaledMMLinearKernel ",
|
+ "to use ROCmScaledMMLinearKernel.",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
||||||
@ -139,41 +128,20 @@ class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
)
|
)
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
w, w_s, x_s = self._get_layer_params(layer)
|
||||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
return apply_weights_fp8(
|
||||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
rocm_per_tensor_float_w8a8_scaled_mm,
|
||||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
self.quant_fp8,
|
||||||
# View input as 2D matrix for fp8 methods
|
w,
|
||||||
x_2d = x.view(-1, x.shape[-1])
|
x,
|
||||||
|
w_s,
|
||||||
out_dtype = self.config.out_dtype
|
x_s,
|
||||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
bias,
|
||||||
# If input not quantized
|
self.config.out_dtype
|
||||||
# TODO(luka) remove this path if not used anymore
|
|
||||||
x_2d_q = x_2d
|
|
||||||
if x.dtype != current_platform.fp8_dtype():
|
|
||||||
x_2d_q, x_s = self.quant_fp8(
|
|
||||||
x_2d,
|
|
||||||
x_s,
|
|
||||||
)
|
|
||||||
|
|
||||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
|
||||||
|
|
||||||
return rocm_per_tensor_float_w8a8_scaled_mm(
|
|
||||||
A=x_2d_q,
|
|
||||||
B=w,
|
|
||||||
out_dtype=out_dtype,
|
|
||||||
As=x_s,
|
|
||||||
Bs=w_s,
|
|
||||||
bias=bias,
|
|
||||||
output_shape=output_shape,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -11,11 +11,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import (
|
from .ScaledMMLinearKernel import (
|
||||||
ScaledMMLinearKernel,
|
FP8ScaledMMLinearKernel,
|
||||||
FP8ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
ScaledMMLinearQuantStrategy,
|
ScaledMMLinearQuantStrategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .utils import apply_weights_fp8
|
||||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||||
TORCH_DEVICE_IDENTITY = None
|
TORCH_DEVICE_IDENTITY = None
|
||||||
@ -134,35 +135,16 @@ def torch_channelwise_w8a8_scaled_mm(
|
|||||||
return output.to(out_dtype).view(*output_shape)
|
return output.to(out_dtype).view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
|
class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||||
def __init__(
|
def get_ouput_padding(self) -> int | None:
|
||||||
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
|
||||||
) -> None:
|
|
||||||
vllm_config = get_current_vllm_config().compilation_config
|
vllm_config = get_current_vllm_config().compilation_config
|
||||||
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
|
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
|
||||||
|
|
||||||
output_padding = 17 if pad_output else None
|
output_padding = 17 if pad_output else None
|
||||||
|
return output_padding
|
||||||
self.quant_fp8 = QuantFP8(
|
|
||||||
static=c.is_static_input_scheme,
|
|
||||||
group_shape=GroupShape.PER_TENSOR,
|
|
||||||
num_token_padding=output_padding,
|
|
||||||
)
|
|
||||||
super().__init__(c, layer_mapping_function)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_min_capability(cls) -> int:
|
|
||||||
# lovelace and up
|
|
||||||
return 89
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
assert c.activation_group_shape is not None
|
|
||||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||||
per_tensor_weight_scales = (
|
per_tensor_weight_scales = (
|
||||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||||
@ -182,36 +164,18 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
w, w_s, x_s = self._get_layer_params(layer)
|
||||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
return apply_weights_fp8(
|
||||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
torch_per_tensor_w8a8_scaled_mm,
|
||||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
self.quant_fp8,
|
||||||
# View input as 2D matrix for fp8 methods
|
w,
|
||||||
x_2d = x.view(-1, x.shape[-1])
|
x,
|
||||||
|
w_s,
|
||||||
out_dtype = self.config.out_dtype
|
x_s,
|
||||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
bias,
|
||||||
|
self.config.out_dtype
|
||||||
# If input not quantized
|
|
||||||
# TODO(luka) remove this path if not used anymore
|
|
||||||
x_2d_q = x_2d
|
|
||||||
if x.dtype != current_platform.fp8_dtype():
|
|
||||||
x_2d_q, x_s = self.quant_fp8(
|
|
||||||
x_2d,
|
|
||||||
x_s,
|
|
||||||
)
|
|
||||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
|
||||||
return torch_per_tensor_w8a8_scaled_mm(
|
|
||||||
A=x_2d_q,
|
|
||||||
B=w,
|
|
||||||
out_dtype=out_dtype,
|
|
||||||
As=x_s,
|
|
||||||
Bs=w_s,
|
|
||||||
bias=bias,
|
|
||||||
output_shape=output_shape,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
@ -219,14 +183,12 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
assert c.activation_group_shape is not None
|
|
||||||
|
|
||||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||||
per_tensor_weight_scales = (
|
per_tensor_weight_scales = (
|
||||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||||
)
|
)
|
||||||
|
|
||||||
if per_tensor_activation_scales and per_tensor_weight_scales:
|
if per_tensor_activation_scales or per_tensor_weight_scales:
|
||||||
return (
|
return (
|
||||||
False,
|
False,
|
||||||
"RowWiseTorchScaledMMLinearKernel cannot be used with "
|
"RowWiseTorchScaledMMLinearKernel cannot be used with "
|
||||||
@ -254,33 +216,16 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
w, w_s, x_s = self._get_layer_params(layer)
|
||||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
return apply_weights_fp8(
|
||||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
torch_row_wise_w8a8_scaled_mm,
|
||||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
self.quant_fp8,
|
||||||
# View input as 2D matrix for fp8 methods
|
w,
|
||||||
x_2d = x.view(-1, x.shape[-1])
|
x,
|
||||||
|
w_s,
|
||||||
out_dtype = self.config.out_dtype
|
x_s,
|
||||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
bias,
|
||||||
|
self.config.out_dtype
|
||||||
# If input not quantized
|
|
||||||
# TODO(luka) remove this path if not used anymore
|
|
||||||
x_2d_q = x_2d
|
|
||||||
if x.dtype != current_platform.fp8_dtype():
|
|
||||||
x_2d_q, x_s = self.quant_fp8(
|
|
||||||
x_2d,
|
|
||||||
x_s,
|
|
||||||
)
|
|
||||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
|
||||||
return torch_row_wise_w8a8_scaled_mm(
|
|
||||||
A=x_2d_q,
|
|
||||||
B=w,
|
|
||||||
out_dtype=out_dtype,
|
|
||||||
As=x_s,
|
|
||||||
Bs=w_s,
|
|
||||||
bias=bias,
|
|
||||||
output_shape=output_shape,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -291,8 +236,6 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
assert c.activation_group_shape is not None
|
|
||||||
|
|
||||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||||
per_tensor_weight_scales = (
|
per_tensor_weight_scales = (
|
||||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||||
@ -313,31 +256,14 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
w, w_s, x_s = self._get_layer_params(layer)
|
||||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
return apply_weights_fp8(
|
||||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
torch_channelwise_w8a8_scaled_mm,
|
||||||
(w, w_s, x_s), _ = self.layer_mapping_function(layer)
|
self.quant_fp8,
|
||||||
# View input as 2D matrix for fp8 methods
|
w,
|
||||||
x_2d = x.view(-1, x.shape[-1])
|
x,
|
||||||
|
w_s,
|
||||||
out_dtype = self.config.out_dtype
|
x_s,
|
||||||
out_dtype = x.dtype if out_dtype is None else out_dtype
|
bias,
|
||||||
|
self.config.out_dtype
|
||||||
# If input not quantized
|
|
||||||
# TODO(luka) remove this path if not used anymore
|
|
||||||
x_2d_q = x_2d
|
|
||||||
if x.dtype != current_platform.fp8_dtype():
|
|
||||||
x_2d_q, x_s = self.quant_fp8(
|
|
||||||
x_2d,
|
|
||||||
x_s,
|
|
||||||
)
|
|
||||||
output_shape = [*x_2d_q.shape[:-1], w.shape[1]]
|
|
||||||
return torch_channelwise_w8a8_scaled_mm(
|
|
||||||
A=x_2d_q,
|
|
||||||
B=w,
|
|
||||||
out_dtype=out_dtype,
|
|
||||||
As=x_s,
|
|
||||||
Bs=w_s,
|
|
||||||
bias=bias,
|
|
||||||
output_shape=output_shape,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -0,0 +1,44 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
import torch
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
FP8ScaledMMCallBack = Callable[..., torch.Tensor]
|
||||||
|
FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]]
|
||||||
|
|
||||||
|
def apply_weights_fp8(
|
||||||
|
scaled_mm_func: FP8ScaledMMCallBack,
|
||||||
|
quant_fp8_func: FP8QuantCallback,
|
||||||
|
w:torch.Tensor,
|
||||||
|
x:torch.Tensor,
|
||||||
|
w_s:torch.Tensor,
|
||||||
|
x_s:torch.Tensor,
|
||||||
|
bias:torch.Tensor,
|
||||||
|
maybe_out_dtype: torch.dtype | None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||||
|
# If dynamic, layer.input_scale is None and x_s computed from x.
|
||||||
|
# If static, layer.input_scale is scalar and x_s is input_scale.
|
||||||
|
# View input as 2D matrix for fp8 methods
|
||||||
|
x_2d = x.view(-1, x.shape[-1])
|
||||||
|
output_shape = [*x.shape[:-1], w.shape[1]]
|
||||||
|
|
||||||
|
out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
|
||||||
|
|
||||||
|
# If input not quantized
|
||||||
|
# TODO(luka) remove this path if not used anymore
|
||||||
|
x_2d_q = x_2d
|
||||||
|
if x.dtype != current_platform.fp8_dtype():
|
||||||
|
x_2d_q, x_s = quant_fp8_func(
|
||||||
|
x_2d,
|
||||||
|
x_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
return scaled_mm_func(
|
||||||
|
A=x_2d_q,
|
||||||
|
B=w,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
As=x_s,
|
||||||
|
Bs=w_s,
|
||||||
|
bias=bias,
|
||||||
|
output_shape=output_shape,
|
||||||
|
)
|
||||||
@ -12,10 +12,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
||||||
|
|
||||||
|
|
||||||
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@ -42,9 +42,12 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
# WEIGHT
|
# WEIGHT
|
||||||
# [out, in] (different than cutlass_scaled_mm)
|
# [out, in] (different than cutlass_scaled_mm)
|
||||||
weight = getattr(layer, self.w_q_name)
|
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
|
||||||
|
self.layer_param_names
|
||||||
|
)
|
||||||
|
weight = getattr(layer, w_q_name)
|
||||||
replace_parameter(
|
replace_parameter(
|
||||||
layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
|
layer, w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# WEIGHT SCALE
|
# WEIGHT SCALE
|
||||||
@ -52,7 +55,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||||
# scales being passed to the kernel), convert to the per-channel case.
|
# scales being passed to the kernel), convert to the per-channel case.
|
||||||
is_fused_module = len(layer.logical_widths) > 1
|
is_fused_module = len(layer.logical_widths) > 1
|
||||||
weight_scale = getattr(layer, self.w_s_name)
|
weight_scale = getattr(layer, w_s_name)
|
||||||
if is_fused_module and not self.config.is_channelwise:
|
if is_fused_module and not self.config.is_channelwise:
|
||||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||||
|
|
||||||
@ -60,14 +63,14 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
weight_scale = weight_scale.squeeze(-1)
|
weight_scale = weight_scale.squeeze(-1)
|
||||||
replace_parameter(
|
replace_parameter(
|
||||||
layer,
|
layer,
|
||||||
self.w_s_name,
|
w_s_name,
|
||||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only support symmetric dynamic activation quantization.
|
# Only support symmetric dynamic activation quantization.
|
||||||
setattr(layer, self.i_s_name, None)
|
setattr(layer, i_s_name, None)
|
||||||
setattr(layer, self.i_zp_name, None)
|
setattr(layer, i_zp_name, None)
|
||||||
setattr(layer, self.azp_adj_name, None)
|
setattr(layer, azp_adj_name, None)
|
||||||
|
|
||||||
# Filter warning for cond usage in apply_weights. It is okay
|
# Filter warning for cond usage in apply_weights. It is okay
|
||||||
# to specialize the graph since bias is not dynamic.
|
# to specialize the graph since bias is not dynamic.
|
||||||
@ -88,7 +91,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
w_q, w_s, _, _, _ = self._get_weight_params(layer)
|
w_q, w_s, _, _, _ = self._get_layer_params(layer)
|
||||||
|
|
||||||
# Required to register custom ops.
|
# Required to register custom ops.
|
||||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||||
|
|||||||
@ -102,24 +102,27 @@ class QuarkW8A8Int8(QuarkScheme):
|
|||||||
layer.register_parameter("weight_zero_point", weight_zero_point)
|
layer.register_parameter("weight_zero_point", weight_zero_point)
|
||||||
|
|
||||||
# INPUT SCALE
|
# INPUT SCALE
|
||||||
|
input_zero_point=None
|
||||||
|
input_scale=None
|
||||||
if self.is_static_input_scheme:
|
if self.is_static_input_scheme:
|
||||||
input_scale = BasevLLMParameter(
|
input_scale = BasevLLMParameter(
|
||||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||||
)
|
)
|
||||||
layer.register_parameter("input_scale", input_scale)
|
|
||||||
|
|
||||||
input_zero_point = BasevLLMParameter(
|
input_zero_point = BasevLLMParameter(
|
||||||
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
||||||
)
|
)
|
||||||
layer.register_parameter("input_zero_point", input_zero_point)
|
|
||||||
|
layer.register_parameter("input_scale", input_scale)
|
||||||
|
layer.register_parameter("input_zero_point", input_zero_point)
|
||||||
|
if not hasattr(layer, "azp_adj"):
|
||||||
|
layer.register_parameter("azp_adj", None)
|
||||||
|
|
||||||
|
layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
|
||||||
|
|
||||||
self.kernel = kernel_type(
|
self.kernel = kernel_type(
|
||||||
c=scaled_mm_linear_kernel_config,
|
c=scaled_mm_linear_kernel_config,
|
||||||
w_q_param_name="weight",
|
layer_param_names = layer_param_names
|
||||||
w_s_param_name="weight_scale",
|
|
||||||
i_s_param_name="input_scale",
|
|
||||||
i_zp_param_name="input_zero_point",
|
|
||||||
azp_adj_param_name="azp_adj",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Checkpoints are serialized in quark format, which is
|
# Checkpoints are serialized in quark format, which is
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user