mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-05 09:24:34 +08:00
update quark fp8 path; format
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
c05027f67a
commit
c089ea5753
@ -6,8 +6,8 @@ from collections.abc import Callable
|
|||||||
import torch
|
import torch
|
||||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
|
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
from vllm.logger import init_logger
|
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||||
CompressedTensorsScheme,
|
CompressedTensorsScheme,
|
||||||
)
|
)
|
||||||
@ -15,10 +15,9 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
|||||||
_POSSIBLE_FP8_KERNELS,
|
_POSSIBLE_FP8_KERNELS,
|
||||||
choose_scaled_mm_linear_kernel,
|
choose_scaled_mm_linear_kernel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||||
FP8ScaledMMLinearLayerConfig,
|
|
||||||
ScaledMMLinearQuantStrategy,
|
|
||||||
QUANT_STRATEGY_MAP,
|
QUANT_STRATEGY_MAP,
|
||||||
|
FP8ScaledMMLinearLayerConfig,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
W8A8BlockFp8LinearOp,
|
W8A8BlockFp8LinearOp,
|
||||||
@ -53,6 +52,7 @@ strategy_to_parameter_type = {
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||||
_kernel_backends_being_used: set[str] = set()
|
_kernel_backends_being_used: set[str] = set()
|
||||||
|
|
||||||
@ -92,17 +92,20 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
activation_group_shape=self.act_q_group_shape,
|
activation_group_shape=self.act_q_group_shape,
|
||||||
out_dtype=self.out_dtype,
|
out_dtype=self.out_dtype,
|
||||||
)
|
)
|
||||||
kernel = choose_scaled_mm_linear_kernel(
|
kernel_type = choose_scaled_mm_linear_kernel(
|
||||||
scaled_mm_linear_kernel_config,
|
scaled_mm_linear_kernel_config,
|
||||||
_POSSIBLE_FP8_KERNELS,
|
_POSSIBLE_FP8_KERNELS,
|
||||||
)
|
)
|
||||||
self.fp8_linear = kernel(
|
|
||||||
scaled_mm_linear_kernel_config, layer_param_names = layer_param_names
|
|
||||||
)
|
|
||||||
|
|
||||||
if kernel.__name__ not in self._kernel_backends_being_used:
|
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||||
logger.info("Using %s for CompressedTensorsW8A8FP8", kernel.__name__)
|
logger.info(
|
||||||
self._kernel_backends_being_used.add(kernel.__name__)
|
"Using %s for CompressedTensorsW8A8FP8", kernel_type.__name__
|
||||||
|
)
|
||||||
|
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||||
|
|
||||||
|
self.kernel = kernel_type(
|
||||||
|
scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
@ -217,4 +220,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
return self.kernel.apply_weights(layer, x, bias)
|
||||||
|
|||||||
@ -11,8 +11,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
|||||||
CompressedTensorsScheme,
|
CompressedTensorsScheme,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||||
|
_POSSIBLE_INT8_KERNELS,
|
||||||
choose_scaled_mm_linear_kernel,
|
choose_scaled_mm_linear_kernel,
|
||||||
_POSSIBLE_INT8_KERNELS
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||||
|
Int8ScaledMMLinearLayerConfig,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parameter import (
|
from vllm.model_executor.parameter import (
|
||||||
BasevLLMParameter,
|
BasevLLMParameter,
|
||||||
@ -20,7 +23,6 @@ from vllm.model_executor.parameter import (
|
|||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
PerTensorScaleParameter,
|
PerTensorScaleParameter,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -58,8 +60,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
|||||||
)
|
)
|
||||||
|
|
||||||
kernel_type = choose_scaled_mm_linear_kernel(
|
kernel_type = choose_scaled_mm_linear_kernel(
|
||||||
scaled_mm_linear_kernel_config,
|
scaled_mm_linear_kernel_config, _POSSIBLE_INT8_KERNELS
|
||||||
_POSSIBLE_INT8_KERNELS
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||||
@ -94,8 +95,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
|||||||
layer.register_parameter("weight_scale", weight_scale)
|
layer.register_parameter("weight_scale", weight_scale)
|
||||||
|
|
||||||
# INPUT SCALE
|
# INPUT SCALE
|
||||||
input_zero_point=None
|
input_zero_point = None
|
||||||
input_scale=None
|
input_scale = None
|
||||||
if self.is_static_input_scheme:
|
if self.is_static_input_scheme:
|
||||||
input_scale = BasevLLMParameter(
|
input_scale = BasevLLMParameter(
|
||||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||||
@ -113,11 +114,16 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
|||||||
if not hasattr(layer, "azp_adj"):
|
if not hasattr(layer, "azp_adj"):
|
||||||
layer.register_parameter("azp_adj", None)
|
layer.register_parameter("azp_adj", None)
|
||||||
|
|
||||||
layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
|
layer_param_names = [
|
||||||
|
"weight",
|
||||||
|
"weight_scale",
|
||||||
|
"input_scale",
|
||||||
|
"input_zero_point",
|
||||||
|
"azp_adj",
|
||||||
|
]
|
||||||
|
|
||||||
self.kernel = kernel_type(
|
self.kernel = kernel_type(
|
||||||
c=scaled_mm_linear_kernel_config,
|
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
|
||||||
layer_param_names = layer_param_names
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Checkpoints are serialized in compressed-tensors format, which is
|
# Checkpoints are serialized in compressed-tensors format, which is
|
||||||
|
|||||||
@ -2,15 +2,16 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Generic, Sequence, TypeVar
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compressed_tensors.quantization import QuantizationStrategy
|
from compressed_tensors.quantization import QuantizationStrategy
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
|
||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
|
|
||||||
|
|
||||||
class ScaledMMLinearQuantStrategy(Enum):
|
class ScaledMMLinearQuantStrategy(Enum):
|
||||||
@ -18,21 +19,24 @@ class ScaledMMLinearQuantStrategy(Enum):
|
|||||||
CHANNEL = "channel"
|
CHANNEL = "channel"
|
||||||
BLOCK = "block"
|
BLOCK = "block"
|
||||||
|
|
||||||
|
|
||||||
QUANT_STRATEGY_MAP = {
|
QUANT_STRATEGY_MAP = {
|
||||||
QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR,
|
QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR,
|
||||||
QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL,
|
QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL,
|
||||||
QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.BLOCK,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ScaledMMLinearLayerConfig:
|
class ScaledMMLinearLayerConfig:
|
||||||
is_static_input_scheme: bool
|
is_static_input_scheme: bool
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||||
is_channelwise: bool
|
is_channelwise: bool
|
||||||
input_symmetric: bool
|
input_symmetric: bool
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||||
weight_quant_strategy: ScaledMMLinearQuantStrategy
|
weight_quant_strategy: ScaledMMLinearQuantStrategy
|
||||||
@ -40,22 +44,22 @@ class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
|||||||
out_dtype: torch.dtype
|
out_dtype: torch.dtype
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Int8ParamsT = tuple[
|
|
||||||
torch.Tensor, # weight
|
|
||||||
torch.Tensor, # weight_scale
|
|
||||||
torch.Tensor | None, # input_scale,
|
|
||||||
]
|
|
||||||
FP8ParamsT = tuple[
|
FP8ParamsT = tuple[
|
||||||
torch.Tensor, # weight
|
torch.Tensor, # weight
|
||||||
torch.Tensor, # weight_scale
|
torch.Tensor, # weight_scale
|
||||||
torch.Tensor | None, # input_scale,
|
torch.Tensor | None, # input_scale,
|
||||||
torch.Tensor | None, # input_zp
|
]
|
||||||
torch.Tensor | None, # azp_adj
|
Int8ParamsT = tuple[
|
||||||
]
|
torch.Tensor, # weight
|
||||||
|
torch.Tensor, # weight_scale
|
||||||
|
torch.Tensor | None, # input_scale,
|
||||||
|
torch.Tensor | None, # input_zp
|
||||||
|
torch.Tensor | None, # azp_adj
|
||||||
|
]
|
||||||
|
|
||||||
|
ParamsT = TypeVar("ParamsT", Int8ParamsT, FP8ParamsT)
|
||||||
|
ConfigT = TypeVar("ConfigT", bound=ScaledMMLinearLayerConfig)
|
||||||
|
|
||||||
ParamsT = TypeVar('ParamsT', Int8ParamsT, FP8ParamsT)
|
|
||||||
ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig)
|
|
||||||
|
|
||||||
class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
|
class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -68,9 +72,7 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
|
|||||||
def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
|
def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, c: ConfigT, layer_param_names: Sequence[str]) -> None:
|
||||||
self, c: ConfigT, layer_param_names: Sequence[str]
|
|
||||||
) -> None:
|
|
||||||
assert self.can_implement(c)
|
assert self.can_implement(c)
|
||||||
self.config = c
|
self.config = c
|
||||||
self.layer_param_names = layer_param_names
|
self.layer_param_names = layer_param_names
|
||||||
@ -87,16 +89,18 @@ class ScaledMMLinearKernel(Generic[ConfigT, ParamsT], ABC):
|
|||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
# return a covariant type in the subclass
|
# return a covariant type in the subclass
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _get_layer_params(self, layer) -> ParamsT:
|
def _get_layer_params(self, layer) -> ParamsT:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC):
|
class FP8ScaledMMLinearKernel(
|
||||||
|
ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8ParamsT], ABC
|
||||||
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, c: ConfigT, layer_param_names: Sequence[str]
|
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
self.quant_fp8 = QuantFP8(
|
self.quant_fp8 = QuantFP8(
|
||||||
static=c.is_static_input_scheme,
|
static=c.is_static_input_scheme,
|
||||||
@ -104,7 +108,7 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig,
|
|||||||
num_token_padding=self.get_ouput_padding(),
|
num_token_padding=self.get_ouput_padding(),
|
||||||
)
|
)
|
||||||
super().__init__(c, layer_param_names)
|
super().__init__(c, layer_param_names)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_ouput_padding(self) -> int | None:
|
def get_ouput_padding(self) -> int | None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -113,7 +117,7 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig,
|
|||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
# lovelace and up
|
# lovelace and up
|
||||||
return 89
|
return 89
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -126,7 +130,9 @@ class FP8ScaledMMLinearKernel(ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Int8ScaledMMLinearKernel(ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC):
|
class Int8ScaledMMLinearKernel(
|
||||||
|
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, Int8ParamsT], ABC
|
||||||
|
):
|
||||||
def _get_layer_params(self, layer) -> Int8ParamsT:
|
def _get_layer_params(self, layer) -> Int8ParamsT:
|
||||||
w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
|
w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
|
||||||
return (
|
return (
|
||||||
|
|||||||
@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
|
|||||||
ScaledMMLinearKernel,
|
ScaledMMLinearKernel,
|
||||||
ScaledMMLinearLayerConfig,
|
ScaledMMLinearLayerConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
|
||||||
ChannelWiseTorchScaledMMLinearKernel,
|
ChannelWiseTorchScaledMMLinearKernel,
|
||||||
PerTensorTorchScaledMMLinearKernel,
|
PerTensorTorchScaledMMLinearKernel,
|
||||||
|
|||||||
@ -14,7 +14,10 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import CpuArchEnum
|
from vllm.platforms.interface import CpuArchEnum
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
from .ScaledMMLinearKernel import (
|
||||||
|
Int8ScaledMMLinearKernel,
|
||||||
|
Int8ScaledMMLinearLayerConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||||
@ -49,9 +52,7 @@ class CPUScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
|||||||
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
|
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
|
||||||
# WEIGHT
|
# WEIGHT
|
||||||
# Transpose to [K, N] for convenience
|
# Transpose to [K, N] for convenience
|
||||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
|
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
|
||||||
self.layer_param_names
|
|
||||||
)
|
|
||||||
weight = getattr(layer, w_q_name)
|
weight = getattr(layer, w_q_name)
|
||||||
replace_parameter(
|
replace_parameter(
|
||||||
layer,
|
layer,
|
||||||
|
|||||||
@ -2,22 +2,24 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
|
||||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
convert_to_channelwise,
|
convert_to_channelwise,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearKernel, Int8ScaledMMLinearKernel
|
from .ScaledMMLinearKernel import (
|
||||||
|
FP8ScaledMMLinearKernel,
|
||||||
|
FP8ScaledMMLinearLayerConfig,
|
||||||
|
Int8ScaledMMLinearKernel,
|
||||||
|
Int8ScaledMMLinearLayerConfig,
|
||||||
|
)
|
||||||
from .utils import apply_weights_fp8
|
from .utils import apply_weights_fp8
|
||||||
|
|
||||||
|
|
||||||
def cutlass_w8a8_scaled_mm_fp8(
|
def cutlass_w8a8_scaled_mm_fp8(
|
||||||
*,
|
*,
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
@ -34,6 +36,7 @@ def cutlass_w8a8_scaled_mm_fp8(
|
|||||||
)
|
)
|
||||||
return output.view(*output_shape)
|
return output.view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
@ -47,9 +50,7 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
|||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
|
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
|
||||||
self.layer_param_names
|
|
||||||
)
|
|
||||||
config = self.config
|
config = self.config
|
||||||
# WEIGHT
|
# WEIGHT
|
||||||
# Cutlass kernels need transposed weight.
|
# Cutlass kernels need transposed weight.
|
||||||
@ -105,7 +106,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
|||||||
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||||
# It does not depend on scales or azp, so it is the same for
|
# It does not depend on scales or azp, so it is the same for
|
||||||
# static and dynamic quantization.
|
# static and dynamic quantization.
|
||||||
@ -124,7 +124,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
|||||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -161,7 +160,6 @@ class CutlassScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
|||||||
|
|
||||||
|
|
||||||
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||||
|
|
||||||
def get_ouput_padding(self) -> int | None:
|
def get_ouput_padding(self) -> int | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -191,5 +189,5 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
w_s,
|
w_s,
|
||||||
x_s,
|
x_s,
|
||||||
bias,
|
bias,
|
||||||
self.config.out_dtype
|
self.config.out_dtype,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,17 +1,14 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import (
|
from .ScaledMMLinearKernel import (
|
||||||
FP8ScaledMMLinearKernel,
|
FP8ScaledMMLinearKernel,
|
||||||
Int8ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
ScaledMMLinearQuantStrategy,
|
ScaledMMLinearQuantStrategy,
|
||||||
)
|
)
|
||||||
from .utils import apply_weights_fp8
|
from .utils import apply_weights_fp8
|
||||||
@ -32,7 +29,6 @@ def flashinfer_w8a8_scaled_mm(
|
|||||||
|
|
||||||
|
|
||||||
class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||||
|
|
||||||
def get_ouput_padding(self) -> int | None:
|
def get_ouput_padding(self) -> int | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -41,7 +37,7 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
return 100
|
return 100
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||||
per_tensor_weight_scales = (
|
per_tensor_weight_scales = (
|
||||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||||
@ -90,5 +86,5 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
w_s,
|
w_s,
|
||||||
x_s,
|
x_s,
|
||||||
bias,
|
bias,
|
||||||
self.config.out_dtype
|
self.config.out_dtype,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,13 +1,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
@ -18,6 +15,7 @@ from .ScaledMMLinearKernel import (
|
|||||||
)
|
)
|
||||||
from .utils import apply_weights_fp8
|
from .utils import apply_weights_fp8
|
||||||
|
|
||||||
|
|
||||||
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
B: torch.Tensor,
|
B: torch.Tensor,
|
||||||
@ -40,7 +38,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
|||||||
current_platform.get_cu_count(),
|
current_platform.get_cu_count(),
|
||||||
bias,
|
bias,
|
||||||
)
|
)
|
||||||
# Fallabck
|
# Fallback
|
||||||
else:
|
else:
|
||||||
output = torch._scaled_mm(
|
output = torch._scaled_mm(
|
||||||
A,
|
A,
|
||||||
@ -143,5 +141,5 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
w_s,
|
w_s,
|
||||||
x_s,
|
x_s,
|
||||||
bias,
|
bias,
|
||||||
self.config.out_dtype
|
self.config.out_dtype,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,13 +1,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from vllm.config import CompilationMode, get_current_vllm_config
|
from vllm.config import CompilationMode, get_current_vllm_config
|
||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import (
|
from .ScaledMMLinearKernel import (
|
||||||
@ -15,8 +12,8 @@ from .ScaledMMLinearKernel import (
|
|||||||
FP8ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
ScaledMMLinearQuantStrategy,
|
ScaledMMLinearQuantStrategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .utils import apply_weights_fp8
|
from .utils import apply_weights_fp8
|
||||||
|
|
||||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||||
TORCH_DEVICE_IDENTITY = None
|
TORCH_DEVICE_IDENTITY = None
|
||||||
@ -142,6 +139,7 @@ class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
output_padding = 17 if pad_output else None
|
output_padding = 17 if pad_output else None
|
||||||
return output_padding
|
return output_padding
|
||||||
|
|
||||||
|
|
||||||
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
@ -173,9 +171,10 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
w_s,
|
w_s,
|
||||||
x_s,
|
x_s,
|
||||||
bias,
|
bias,
|
||||||
self.config.out_dtype
|
self.config.out_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
@ -199,7 +198,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
return (
|
return (
|
||||||
False,
|
False,
|
||||||
"RowWiseTorchScaledMMLinearKernel is only supported "
|
"RowWiseTorchScaledMMLinearKernel is only supported "
|
||||||
+ "in ROCm platforms.",
|
+ "on ROCm platforms.",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not version.parse(torch.__version__) >= version.parse("2.7"):
|
if not version.parse(torch.__version__) >= version.parse("2.7"):
|
||||||
@ -225,7 +224,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
w_s,
|
w_s,
|
||||||
x_s,
|
x_s,
|
||||||
bias,
|
bias,
|
||||||
self.config.out_dtype
|
self.config.out_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -265,5 +264,5 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
w_s,
|
w_s,
|
||||||
x_s,
|
x_s,
|
||||||
bias,
|
bias,
|
||||||
self.config.out_dtype
|
self.config.out_dtype,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,20 +1,25 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
import torch
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
FP8ScaledMMCallBack = Callable[..., torch.Tensor]
|
FP8ScaledMMCallBack = Callable[..., torch.Tensor]
|
||||||
FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]]
|
FP8QuantCallback = Callable[..., tuple[torch.Tensor, torch.Tensor]]
|
||||||
|
|
||||||
|
|
||||||
def apply_weights_fp8(
|
def apply_weights_fp8(
|
||||||
scaled_mm_func: FP8ScaledMMCallBack,
|
scaled_mm_func: FP8ScaledMMCallBack,
|
||||||
quant_fp8_func: FP8QuantCallback,
|
quant_fp8_func: FP8QuantCallback,
|
||||||
w:torch.Tensor,
|
w: torch.Tensor,
|
||||||
x:torch.Tensor,
|
x: torch.Tensor,
|
||||||
w_s:torch.Tensor,
|
w_s: torch.Tensor,
|
||||||
x_s:torch.Tensor,
|
x_s: torch.Tensor,
|
||||||
bias:torch.Tensor,
|
bias: torch.Tensor,
|
||||||
maybe_out_dtype: torch.dtype | None,
|
maybe_out_dtype: torch.dtype | None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||||
# If dynamic, layer.input_scale is None and x_s computed from x.
|
# If dynamic, layer.input_scale is None and x_s computed from x.
|
||||||
# If static, layer.input_scale is scalar and x_s is input_scale.
|
# If static, layer.input_scale is scalar and x_s is input_scale.
|
||||||
|
|||||||
@ -12,7 +12,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
from .ScaledMMLinearKernel import (
|
||||||
|
Int8ScaledMMLinearKernel,
|
||||||
|
Int8ScaledMMLinearLayerConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||||
@ -42,9 +45,7 @@ class XLAScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
|||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
# WEIGHT
|
# WEIGHT
|
||||||
# [out, in] (different than cutlass_scaled_mm)
|
# [out, in] (different than cutlass_scaled_mm)
|
||||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = (
|
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
|
||||||
self.layer_param_names
|
|
||||||
)
|
|
||||||
weight = getattr(layer, w_q_name)
|
weight = getattr(layer, w_q_name)
|
||||||
replace_parameter(
|
replace_parameter(
|
||||||
layer, w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
|
layer, w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
|
||||||
|
|||||||
@ -7,10 +7,18 @@ from typing import Any, cast
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||||
|
_POSSIBLE_FP8_KERNELS,
|
||||||
|
choose_scaled_mm_linear_kernel,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||||
|
FP8ScaledMMLinearLayerConfig,
|
||||||
|
ScaledMMLinearQuantStrategy,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
Fp8LinearOp,
|
|
||||||
normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
@ -23,8 +31,17 @@ from vllm.platforms import current_platform
|
|||||||
|
|
||||||
__all__ = ["QuarkW8A8Fp8"]
|
__all__ = ["QuarkW8A8Fp8"]
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
QUANT_STRATEGY_MAP = {
|
||||||
|
"per_tensor": ScaledMMLinearQuantStrategy.TENSOR,
|
||||||
|
"per_channel": ScaledMMLinearQuantStrategy.CHANNEL,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class QuarkW8A8Fp8(QuarkScheme):
|
class QuarkW8A8Fp8(QuarkScheme):
|
||||||
|
_kernel_backends_being_used: set[str] = set()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, weight_config: dict[str, Any], input_config: dict[str, Any] | None
|
self, weight_config: dict[str, Any], input_config: dict[str, Any] | None
|
||||||
):
|
):
|
||||||
@ -41,10 +58,6 @@ class QuarkW8A8Fp8(QuarkScheme):
|
|||||||
self.act_quant_group_shape = (
|
self.act_quant_group_shape = (
|
||||||
GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
|
GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
|
||||||
)
|
)
|
||||||
self.fp8_linear = Fp8LinearOp(
|
|
||||||
act_quant_static=self.is_static_input_scheme,
|
|
||||||
act_quant_group_shape=self.act_quant_group_shape,
|
|
||||||
)
|
|
||||||
self.out_dtype = torch.get_default_dtype()
|
self.out_dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -163,17 +176,32 @@ class QuarkW8A8Fp8(QuarkScheme):
|
|||||||
input_scale[:] = torch.finfo(torch.float32).min
|
input_scale[:] = torch.finfo(torch.float32).min
|
||||||
layer.register_parameter("input_scale", input_scale)
|
layer.register_parameter("input_scale", input_scale)
|
||||||
|
|
||||||
|
layer_param_names = ["weight", "weight_scale", "input_scale"]
|
||||||
|
weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme]
|
||||||
|
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
||||||
|
is_static_input_scheme=self.is_static_input_scheme,
|
||||||
|
weight_quant_strategy=weight_quant_strategy,
|
||||||
|
activation_group_shape=self.act_quant_group_shape,
|
||||||
|
out_dtype=self.out_dtype,
|
||||||
|
)
|
||||||
|
kernel_type = choose_scaled_mm_linear_kernel(
|
||||||
|
scaled_mm_linear_kernel_config,
|
||||||
|
_POSSIBLE_FP8_KERNELS,
|
||||||
|
)
|
||||||
|
|
||||||
|
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||||
|
logger.info("Using %s for QuarkW8A8FP8", kernel_type.__name__)
|
||||||
|
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||||
|
|
||||||
|
layer_param_names = ["weight", "weight_scale", "input_scale"]
|
||||||
|
self.kernel = kernel_type(
|
||||||
|
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
|
||||||
|
)
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.fp8_linear.apply(
|
return self.kernel.apply_weights(layer, x, bias)
|
||||||
input=x,
|
|
||||||
weight=layer.weight,
|
|
||||||
weight_scale=layer.weight_scale,
|
|
||||||
out_dtype=self.out_dtype,
|
|
||||||
input_scale=layer.input_scale,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -7,9 +7,12 @@ import torch
|
|||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||||
|
_POSSIBLE_INT8_KERNELS,
|
||||||
choose_scaled_mm_linear_kernel,
|
choose_scaled_mm_linear_kernel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||||
|
Int8ScaledMMLinearLayerConfig,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||||
from vllm.model_executor.parameter import (
|
from vllm.model_executor.parameter import (
|
||||||
BasevLLMParameter,
|
BasevLLMParameter,
|
||||||
@ -56,7 +59,9 @@ class QuarkW8A8Int8(QuarkScheme):
|
|||||||
input_symmetric=(self.input_symmetric is True),
|
input_symmetric=(self.input_symmetric is True),
|
||||||
)
|
)
|
||||||
|
|
||||||
kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)
|
kernel_type = choose_scaled_mm_linear_kernel(
|
||||||
|
scaled_mm_linear_kernel_config, possible_kernels=_POSSIBLE_INT8_KERNELS
|
||||||
|
)
|
||||||
|
|
||||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||||
logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
|
logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
|
||||||
@ -102,8 +107,8 @@ class QuarkW8A8Int8(QuarkScheme):
|
|||||||
layer.register_parameter("weight_zero_point", weight_zero_point)
|
layer.register_parameter("weight_zero_point", weight_zero_point)
|
||||||
|
|
||||||
# INPUT SCALE
|
# INPUT SCALE
|
||||||
input_zero_point=None
|
input_zero_point = None
|
||||||
input_scale=None
|
input_scale = None
|
||||||
if self.is_static_input_scheme:
|
if self.is_static_input_scheme:
|
||||||
input_scale = BasevLLMParameter(
|
input_scale = BasevLLMParameter(
|
||||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||||
@ -117,12 +122,17 @@ class QuarkW8A8Int8(QuarkScheme):
|
|||||||
layer.register_parameter("input_zero_point", input_zero_point)
|
layer.register_parameter("input_zero_point", input_zero_point)
|
||||||
if not hasattr(layer, "azp_adj"):
|
if not hasattr(layer, "azp_adj"):
|
||||||
layer.register_parameter("azp_adj", None)
|
layer.register_parameter("azp_adj", None)
|
||||||
|
|
||||||
layer_param_names = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
|
layer_param_names = [
|
||||||
|
"weight",
|
||||||
|
"weight_scale",
|
||||||
|
"input_scale",
|
||||||
|
"input_zero_point",
|
||||||
|
"azp_adj",
|
||||||
|
]
|
||||||
|
|
||||||
self.kernel = kernel_type(
|
self.kernel = kernel_type(
|
||||||
c=scaled_mm_linear_kernel_config,
|
c=scaled_mm_linear_kernel_config, layer_param_names=layer_param_names
|
||||||
layer_param_names = layer_param_names
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Checkpoints are serialized in quark format, which is
|
# Checkpoints are serialized in quark format, which is
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user