mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 09:07:03 +08:00

commit e54e572085
parent 974e6820ce

fix int8 path

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
     choose_scaled_mm_linear_kernel,
 )
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
-    ScaledMMLinearLayerConfig,
+    FP8ScaledMMLinearLayerConfig,
+    ScaledMMLinearQuantStrategy,
 )
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (

@@ -91,10 +91,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         elif self.strategy == QuantizationStrategy.CHANNEL:
             weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL

-        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
-            is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
-            is_static_input_scheme=self.is_static_input_scheme,
-            input_symmetric=True,
+        scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
+            weight_quant_strategy=weight_quant_strategy,
+            activation_group_shape=self.act_q_group_shape,
+            out_dtype=self.out_dtype,
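The FP8 scheme now fills an FP8-specific config instead of the generic one. A minimal sketch of the new construction, assuming only the dataclass fields visible in this diff (the strategy, group shape, and dtype values below are illustrative, not taken from a real model):

import torch

from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
    FP8ScaledMMLinearLayerConfig,
    ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

# Per-channel FP8 weights with per-tensor activation scales, bf16 output.
config = FP8ScaledMMLinearLayerConfig(
    weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL,
    activation_group_shape=GroupShape.PER_TENSOR,
    out_dtype=torch.bfloat16,
)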
@@ -11,8 +11,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
 from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
-    ScaledMMLinearLayerConfig,
     choose_scaled_mm_linear_kernel,
+    _POSSIBLE_INT8_KERNELS
 )
 from vllm.model_executor.parameter import (
     BasevLLMParameter,

@@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (
     ModelWeightParameter,
     PerTensorScaleParameter,
 )
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig

 logger = init_logger(__name__)

@@ -50,13 +51,16 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
     ):
         layer.logical_widths = output_partition_sizes

-        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
+        scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig(
             is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
             is_static_input_scheme=self.is_static_input_scheme,
             input_symmetric=self.input_symmetric,
         )

-        kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)
+        kernel_type = choose_scaled_mm_linear_kernel(
+            scaled_mm_linear_kernel_config,
+            _POSSIBLE_INT8_KERNELS
+        )

         if kernel_type.__name__ not in self._kernel_backends_being_used:
             logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__)

@@ -90,12 +94,12 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         layer.register_parameter("weight_scale", weight_scale)

         # INPUT SCALE
+        input_zero_point = None
+        input_scale = None
         if self.is_static_input_scheme:
             input_scale = BasevLLMParameter(
                 data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
             )
-            layer.register_parameter("input_scale", input_scale)

             if not self.input_symmetric:
                 # Note: compressed-tensors stores the zp using the same dtype
                 # as the weights

@@ -103,15 +107,21 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
                 input_zero_point = BasevLLMParameter(
                     data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
                 )
-                layer.register_parameter("input_zero_point", input_zero_point)

+        layer.register_parameter("input_zero_point", input_zero_point)
+        layer.register_parameter("input_scale", input_scale)
+        if not hasattr(layer, "azp_adj"):
+            layer.register_parameter("azp_adj", None)
+
+        param_name_list = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
+
+        layer_mapping_function = lambda layer: (
+            tuple(getattr(layer, param_name) for param_name in param_name_list),
+            param_name_list,
+        )
         self.kernel = kernel_type(
             c=scaled_mm_linear_kernel_config,
-            w_q_param_name="weight",
-            w_s_param_name="weight_scale",
-            i_s_param_name="input_scale",
-            i_zp_param_name="input_zero_point",
-            azp_adj_param_name="azp_adj",
+            layer_mapping_function=layer_mapping_function,
         )

         # Checkpoints are serialized in compressed-tensors format, which is
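Kernel construction no longer threads one *_param_name keyword per tensor; the scheme passes a single layer_mapping_function that resolves a layer into an ordered parameter tuple plus the matching names, as in the lambda above. A self-contained sketch of that contract (DummyLayer is an illustrative stand-in, not a vllm class):

import torch

param_name_list = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]

layer_mapping_function = lambda layer: (
    tuple(getattr(layer, name) for name in param_name_list),
    param_name_list,
)

class DummyLayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Mirrors the registrations above: quantized weight plus its scale,
        # with the optional input-quantization parameters left as None.
        self.weight = torch.nn.Parameter(
            torch.zeros(8, 8, dtype=torch.int8), requires_grad=False
        )
        self.weight_scale = torch.nn.Parameter(torch.ones(8), requires_grad=False)
        self.input_scale = None
        self.input_zero_point = None
        self.azp_adj = None

params, names = layer_mapping_function(DummyLayer())
assert names == param_name_list and len(params) == len(param_name_list)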
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass
 from enum import Enum
+from typing import Generic, TypeVar

 import torch

 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

@@ -25,16 +25,25 @@ class ScaledMMLinearQuantStrategy(Enum):

 @dataclass
 class ScaledMMLinearLayerConfig:
+    # TODO: remove is_channelwise
+    pass
+
+
+@dataclass
+class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
     is_channelwise: bool
     is_static_input_scheme: bool
     input_symmetric: bool
+    out_dtype: torch.dtype | None


 @dataclass
 class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
     weight_quant_strategy: ScaledMMLinearQuantStrategy
-    activation_group_shape: GroupShape | None = GroupShape.PER_TENSOR
+    activation_group_shape: GroupShape
     out_dtype: torch.dtype


-class ScaledMMLinearKernel(ABC):
+ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig)
+
+
+class ScaledMMLinearKernel(Generic[ConfigT], ABC):
     @classmethod
     @abstractmethod
     def get_min_capability(cls) -> int:

@@ -42,11 +51,11 @@ class ScaledMMLinearKernel(ABC):

     @classmethod
     @abstractmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
         raise NotImplementedError

     def __init__(
-        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
+        self, c: ConfigT, layer_mapping_function: Callable
     ) -> None:
         assert self.can_implement(c)
         self.config = c

@@ -63,21 +72,4 @@ class ScaledMMLinearKernel(ABC):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
         raise NotImplementedError

-    # def _get_weight_params(
-    #     self, layer: torch.nn.Module
-    # ) -> tuple[
-    #     torch.Tensor,  # weight
-    #     torch.Tensor,  # weight_scale
-    #     torch.Tensor | None,  # input_scale,
-    #     torch.Tensor | None,  # input_zp
-    #     torch.Tensor | None,  # azp_adj
-    # ]:
-    #     return (
-    #         getattr(layer, self.w_q_name),
-    #         getattr(layer, self.w_s_name),
-    #         getattr(layer, self.i_s_name),
-    #         getattr(layer, self.i_zp_name),
-    #         getattr(layer, self.azp_adj_name),
-    #     )
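The base kernel becomes generic over its config type, so each backend's can_implement signature narrows to the config subclass it actually supports. A standalone sketch of the pattern with mimic names (not vllm's real classes):

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar


@dataclass
class BaseConfig:
    pass


@dataclass
class Int8Config(BaseConfig):
    input_symmetric: bool


ConfigT = TypeVar("ConfigT", bound=BaseConfig)


class Kernel(Generic[ConfigT], ABC):
    @classmethod
    @abstractmethod
    def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
        raise NotImplementedError

    def __init__(self, c: ConfigT) -> None:
        ok, reason = self.can_implement(c)
        assert ok, reason
        self.config = c


class Int8Kernel(Kernel[Int8Config]):
    @classmethod
    def can_implement(cls, c: Int8Config) -> tuple[bool, str | None]:
        # A backend checks only the fields its config subclass carries.
        if not c.input_symmetric:
            return False, "asymmetric input quantization not supported here"
        return True, None


kernel = Int8Kernel(Int8Config(input_symmetric=True))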
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
     ScaledMMLinearKernel,
     ScaledMMLinearLayerConfig,
 )
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
     ChannelWiseTorchScaledMMLinearKernel,
     PerTensorTorchScaledMMLinearKernel,
@@ -10,7 +10,7 @@ from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op

 from .cutlass import process_weights_after_loading
-from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig


 def rocm_aiter_gemm_w8a8_impl(

@@ -58,7 +58,7 @@ class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
         return 90

     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         if not current_platform.is_rocm():
             return (
                 False,
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum

-from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig


 class CPUScaledMMLinearKernel(ScaledMMLinearKernel):

@@ -23,7 +23,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
         return 75

     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         if not current_platform.is_cpu():
             return False, "CPUScaledMM requires running on CPU."
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.platforms import current_platform

-from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig


 def cutlass_w8a8_scaled_mm(

@@ -36,7 +36,7 @@ def cutlass_w8a8_scaled_mm(


 def process_weights_after_loading(
-    config: ScaledMMLinearLayerConfig,
+    config: Int8ScaledMMLinearLayerConfig,
     layer: torch.nn.Module,
     w_q_name: str,
     w_s_name: str,

@@ -98,9 +98,6 @@ def process_weights_after_loading(
                 layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
             )

-    else:
-        setattr(layer, i_s_name, None)
-        setattr(layer, i_zp_name, None)

     # azp_adj is the AZP adjustment term, used to account for weights.
     # It does not depend on scales or azp, so it is the same for

@@ -119,8 +116,6 @@ def process_weights_after_loading(
             azp_adj_name,
             torch.nn.Parameter(azp_adj, requires_grad=False),
         )
-    else:
-        setattr(layer, azp_adj_name, None)


 class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):

@@ -129,7 +124,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
         return 75

     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         if not current_platform.is_cuda():
             return False, "CutlassScaledMM requires running on CUDA."

@@ -177,7 +172,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):

 class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
     def __init__(
-        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
+        self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
     ) -> None:
         self.quant_fp8 = QuantFP8(
             static=c.is_static_input_scheme,

@@ -192,7 +187,7 @@ class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
         return 89

     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         if not current_platform.is_cuda():
             return (
                 False,
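For context on the azp_adj lines kept above: with an asymmetric activation zero point, (x_q - azp) @ W = x_q @ W - azp * colsum(W), so the weight-only column sums can be precomputed once after loading. A small numeric check of that identity (illustrative shapes, not vllm's kernel code):

import torch

w_q = torch.randint(-128, 128, (16, 8), dtype=torch.int64)  # [in, out] weights
x_q = torch.randint(0, 256, (4, 16), dtype=torch.int64)     # quantized activations
azp = 128                                                    # activation zero point

azp_adj = w_q.sum(dim=0)  # per-output-channel column sums, weight-only

direct = (x_q - azp) @ w_q
via_adj = x_q @ w_q - azp * azp_adj
assert torch.equal(direct, via_adj)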
@@ -11,7 +11,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer

 from .ScaledMMLinearKernel import (
     ScaledMMLinearKernel,
-    ScaledMMLinearLayerConfig,
+    Int8ScaledMMLinearLayerConfig,
     ScaledMMLinearQuantStrategy,
 )

@@ -32,7 +32,7 @@ def flashinfer_w8a8_scaled_mm(

 class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
     def __init__(
-        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
+        self, c: Int8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
     ) -> None:
         self.quant_fp8 = QuantFP8(
             static=c.is_static_input_scheme,

@@ -46,7 +46,7 @@ class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
         return 100

     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
         per_tensor_weight_scales = (
             c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
@@ -13,7 +13,7 @@ from vllm.utils.torch_utils import direct_register_custom_op

 from .ScaledMMLinearKernel import (
     ScaledMMLinearKernel,
-    ScaledMMLinearLayerConfig,
+    FP8ScaledMMLinearLayerConfig,
     ScaledMMLinearQuantStrategy,
 )

@@ -90,7 +90,7 @@ if current_platform.is_rocm():

 class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
     def __init__(
-        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
+        self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
     ) -> None:
         self.quant_fp8 = QuantFP8(
             static=c.is_static_input_scheme,

@@ -104,7 +104,7 @@ class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
         return 90

     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         # TODO: check if this causes an issue on non-ROCM platforms
         from vllm.platforms.rocm import on_mi3xx
@@ -12,7 +12,7 @@ from vllm.platforms import current_platform

 from .ScaledMMLinearKernel import (
     ScaledMMLinearKernel,
-    ScaledMMLinearLayerConfig,
+    FP8ScaledMMLinearLayerConfig,
     ScaledMMLinearQuantStrategy,
 )

@@ -136,7 +136,7 @@ def torch_channelwise_w8a8_scaled_mm(

 class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
     def __init__(
-        self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
+        self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
     ) -> None:
         vllm_config = get_current_vllm_config().compilation_config
         pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE

@@ -161,7 +161,7 @@ class TorchScaledMMLinearKernel(ScaledMMLinearKernel):

 class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         assert c.activation_group_shape is not None
         per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
         per_tensor_weight_scales = (

@@ -218,7 +218,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
         return 94

     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         assert c.activation_group_shape is not None

         per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()

@@ -290,7 +290,7 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
         return 94

     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         assert c.activation_group_shape is not None

         per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
@@ -7,7 +7,7 @@ import torch
 from vllm.platforms import current_platform

 from .cutlass import CutlassScaledMMLinearKernel
-from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig


 class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):

@@ -16,7 +16,7 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
         return 75

     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         if current_platform.is_cpu():
             return (
                 False,
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.platforms import current_platform

-from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig


 class XLAScaledMMLinearKernel(ScaledMMLinearKernel):

@@ -24,7 +24,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
         )

     @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         if not current_platform.is_tpu():
             return False, "ScaledMMXLA requires running on TPU."
@@ -7,9 +7,9 @@ import torch

 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
-    ScaledMMLinearLayerConfig,
     choose_scaled_mm_linear_kernel,
 )
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
 from vllm.model_executor.parameter import (
     BasevLLMParameter,

@@ -50,7 +50,7 @@ class QuarkW8A8Int8(QuarkScheme):
     ):
         layer.logical_widths = output_partition_sizes

-        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
+        scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig(
             is_channelwise=(self.qscheme == "per_channel"),
             is_static_input_scheme=(self.is_static_input_scheme is True),
             input_symmetric=(self.input_symmetric is True),
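Both int8 schemes now pass the candidate list (_POSSIBLE_INT8_KERNELS) into choose_scaled_mm_linear_kernel rather than relying on a single global kernel list. The chooser itself is not part of this diff, so the selection loop below is a hypothetical sketch of the two-argument shape it implies:

def choose_kernel(config, possible_kernels):
    # Hypothetical: try each candidate's can_implement gate in priority
    # order and report every rejection reason if none fits.
    failure_reasons = {}
    for kernel_type in possible_kernels:
        ok, reason = kernel_type.can_implement(config)
        if ok:
            return kernel_type
        failure_reasons[kernel_type.__name__] = reason
    raise ValueError(f"No compatible scaled_mm kernel: {failure_reasons}")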