mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 20:07:08 +08:00
fix int8 path
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
974e6820ce
commit
e54e572085
@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
|||||||
choose_scaled_mm_linear_kernel,
|
choose_scaled_mm_linear_kernel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
|
||||||
ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
ScaledMMLinearQuantStrategy,
|
ScaledMMLinearQuantStrategy,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
@ -91,10 +91,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
elif self.strategy == QuantizationStrategy.CHANNEL:
|
elif self.strategy == QuantizationStrategy.CHANNEL:
|
||||||
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
|
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
|
||||||
|
|
||||||
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
|
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
||||||
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
|
|
||||||
is_static_input_scheme=self.is_static_input_scheme,
|
|
||||||
input_symmetric=True,
|
|
||||||
weight_quant_strategy=weight_quant_strategy,
|
weight_quant_strategy=weight_quant_strategy,
|
||||||
activation_group_shape=self.act_q_group_shape,
|
activation_group_shape=self.act_q_group_shape,
|
||||||
out_dtype=self.out_dtype,
|
out_dtype=self.out_dtype,
|
||||||
|
|||||||
@ -11,8 +11,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
|||||||
CompressedTensorsScheme,
|
CompressedTensorsScheme,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||||
ScaledMMLinearLayerConfig,
|
|
||||||
choose_scaled_mm_linear_kernel,
|
choose_scaled_mm_linear_kernel,
|
||||||
|
_POSSIBLE_INT8_KERNELS
|
||||||
)
|
)
|
||||||
from vllm.model_executor.parameter import (
|
from vllm.model_executor.parameter import (
|
||||||
BasevLLMParameter,
|
BasevLLMParameter,
|
||||||
@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (
|
|||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
PerTensorScaleParameter,
|
PerTensorScaleParameter,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -50,13 +51,16 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
|||||||
):
|
):
|
||||||
layer.logical_widths = output_partition_sizes
|
layer.logical_widths = output_partition_sizes
|
||||||
|
|
||||||
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
|
scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig(
|
||||||
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
|
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
|
||||||
is_static_input_scheme=self.is_static_input_scheme,
|
is_static_input_scheme=self.is_static_input_scheme,
|
||||||
input_symmetric=self.input_symmetric,
|
input_symmetric=self.input_symmetric,
|
||||||
)
|
)
|
||||||
|
|
||||||
kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)
|
kernel_type = choose_scaled_mm_linear_kernel(
|
||||||
|
scaled_mm_linear_kernel_config,
|
||||||
|
_POSSIBLE_INT8_KERNELS
|
||||||
|
)
|
||||||
|
|
||||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||||
logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__)
|
logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__)
|
||||||
@ -90,12 +94,12 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
|||||||
layer.register_parameter("weight_scale", weight_scale)
|
layer.register_parameter("weight_scale", weight_scale)
|
||||||
|
|
||||||
# INPUT SCALE
|
# INPUT SCALE
|
||||||
|
input_zero_point=None
|
||||||
|
input_scale=None
|
||||||
if self.is_static_input_scheme:
|
if self.is_static_input_scheme:
|
||||||
input_scale = BasevLLMParameter(
|
input_scale = BasevLLMParameter(
|
||||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||||
)
|
)
|
||||||
layer.register_parameter("input_scale", input_scale)
|
|
||||||
|
|
||||||
if not self.input_symmetric:
|
if not self.input_symmetric:
|
||||||
# Note: compressed-tensors stores the zp using the same dtype
|
# Note: compressed-tensors stores the zp using the same dtype
|
||||||
# as the weights
|
# as the weights
|
||||||
@ -103,15 +107,21 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
|||||||
input_zero_point = BasevLLMParameter(
|
input_zero_point = BasevLLMParameter(
|
||||||
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
||||||
)
|
)
|
||||||
layer.register_parameter("input_zero_point", input_zero_point)
|
|
||||||
|
|
||||||
|
layer.register_parameter("input_zero_point", input_zero_point)
|
||||||
|
layer.register_parameter("input_scale", input_scale)
|
||||||
|
if not hasattr(layer, "azp_adj"):
|
||||||
|
layer.register_parameter("azp_adj", None)
|
||||||
|
|
||||||
|
param_name_list = ["weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj"]
|
||||||
|
|
||||||
|
layer_mapping_function = lambda layer: (
|
||||||
|
tuple(getattr(layer, param_name) for param_name in param_name_list),
|
||||||
|
param_name_list,
|
||||||
|
)
|
||||||
self.kernel = kernel_type(
|
self.kernel = kernel_type(
|
||||||
c=scaled_mm_linear_kernel_config,
|
c=scaled_mm_linear_kernel_config,
|
||||||
w_q_param_name="weight",
|
layer_mapping_function = layer_mapping_function
|
||||||
w_s_param_name="weight_scale",
|
|
||||||
i_s_param_name="input_scale",
|
|
||||||
i_zp_param_name="input_zero_point",
|
|
||||||
azp_adj_param_name="azp_adj",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Checkpoints are serialized in compressed-tensors format, which is
|
# Checkpoints are serialized in compressed-tensors format, which is
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Generic, TypeVar
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
@ -25,16 +25,25 @@ class ScaledMMLinearQuantStrategy(Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ScaledMMLinearLayerConfig:
|
class ScaledMMLinearLayerConfig:
|
||||||
# TODO: remove is channelwise
|
pass
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||||
is_channelwise: bool
|
is_channelwise: bool
|
||||||
is_static_input_scheme: bool
|
is_static_input_scheme: bool
|
||||||
input_symmetric: bool
|
input_symmetric: bool
|
||||||
out_dtype: torch.dtype | None
|
|
||||||
|
@dataclass
|
||||||
|
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||||
weight_quant_strategy: ScaledMMLinearQuantStrategy
|
weight_quant_strategy: ScaledMMLinearQuantStrategy
|
||||||
activation_group_shape: GroupShape | None = GroupShape.PER_TENSOR
|
activation_group_shape: GroupShape
|
||||||
|
out_dtype: torch.dtype
|
||||||
|
|
||||||
|
|
||||||
class ScaledMMLinearKernel(ABC):
|
ConfigT = TypeVar('ConfigT', bound=ScaledMMLinearLayerConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class ScaledMMLinearKernel(Generic[ConfigT], ABC):
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
@ -42,11 +51,11 @@ class ScaledMMLinearKernel(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: ConfigT) -> tuple[bool, str | None]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
self, c: ConfigT, layer_mapping_function: Callable
|
||||||
) -> None:
|
) -> None:
|
||||||
assert self.can_implement(c)
|
assert self.can_implement(c)
|
||||||
self.config = c
|
self.config = c
|
||||||
@ -64,20 +73,3 @@ class ScaledMMLinearKernel(ABC):
|
|||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
# def _get_weight_params(
|
|
||||||
# self, layer: torch.nn.Module
|
|
||||||
# ) -> tuple[
|
|
||||||
# torch.Tensor, # weight
|
|
||||||
# torch.Tensor, # weight_scale
|
|
||||||
# torch.Tensor | None, # input_scale,
|
|
||||||
# torch.Tensor | None, # input_zp
|
|
||||||
# torch.Tensor | None, # azp_adj
|
|
||||||
# ]:
|
|
||||||
# return (
|
|
||||||
# getattr(layer, self.w_q_name),
|
|
||||||
# getattr(layer, self.w_s_name),
|
|
||||||
# getattr(layer, self.i_s_name),
|
|
||||||
# getattr(layer, self.i_zp_name),
|
|
||||||
# getattr(layer, self.azp_adj_name),
|
|
||||||
# )
|
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
|
|||||||
ScaledMMLinearKernel,
|
ScaledMMLinearKernel,
|
||||||
ScaledMMLinearLayerConfig,
|
ScaledMMLinearLayerConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
|
||||||
ChannelWiseTorchScaledMMLinearKernel,
|
ChannelWiseTorchScaledMMLinearKernel,
|
||||||
PerTensorTorchScaledMMLinearKernel,
|
PerTensorTorchScaledMMLinearKernel,
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from vllm.platforms import current_platform
|
|||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
from .cutlass import process_weights_after_loading
|
from .cutlass import process_weights_after_loading
|
||||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
||||||
|
|
||||||
|
|
||||||
def rocm_aiter_gemm_w8a8_impl(
|
def rocm_aiter_gemm_w8a8_impl(
|
||||||
@ -58,7 +58,7 @@ class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
return 90
|
return 90
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
if not current_platform.is_rocm():
|
if not current_platform.is_rocm():
|
||||||
return (
|
return (
|
||||||
False,
|
False,
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import CpuArchEnum
|
from vllm.platforms.interface import CpuArchEnum
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
||||||
|
|
||||||
|
|
||||||
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||||
@ -23,7 +23,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
return 75
|
return 75
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
if not current_platform.is_cpu():
|
if not current_platform.is_cpu():
|
||||||
return False, "CPUScaledMM requires running on CPU."
|
return False, "CPUScaledMM requires running on CPU."
|
||||||
|
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig
|
||||||
|
|
||||||
|
|
||||||
def cutlass_w8a8_scaled_mm(
|
def cutlass_w8a8_scaled_mm(
|
||||||
@ -36,7 +36,7 @@ def cutlass_w8a8_scaled_mm(
|
|||||||
|
|
||||||
|
|
||||||
def process_weights_after_loading(
|
def process_weights_after_loading(
|
||||||
config: ScaledMMLinearLayerConfig,
|
config: Int8ScaledMMLinearLayerConfig,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
w_q_name: str,
|
w_q_name: str,
|
||||||
w_s_name: str,
|
w_s_name: str,
|
||||||
@ -98,9 +98,6 @@ def process_weights_after_loading(
|
|||||||
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
|
||||||
setattr(layer, i_s_name, None)
|
|
||||||
setattr(layer, i_zp_name, None)
|
|
||||||
|
|
||||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||||
# It does not depend on scales or azp, so it is the same for
|
# It does not depend on scales or azp, so it is the same for
|
||||||
@ -119,8 +116,6 @@ def process_weights_after_loading(
|
|||||||
azp_adj_name,
|
azp_adj_name,
|
||||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
setattr(layer, azp_adj_name, None)
|
|
||||||
|
|
||||||
|
|
||||||
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||||
@ -129,7 +124,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
return 75
|
return 75
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
if not current_platform.is_cuda():
|
if not current_platform.is_cuda():
|
||||||
return False, "CutlassScaledMM requires running on CUDA."
|
return False, "CutlassScaledMM requires running on CUDA."
|
||||||
|
|
||||||
@ -177,7 +172,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
|
|
||||||
class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
|
class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||||
) -> None:
|
) -> None:
|
||||||
self.quant_fp8 = QuantFP8(
|
self.quant_fp8 = QuantFP8(
|
||||||
static=c.is_static_input_scheme,
|
static=c.is_static_input_scheme,
|
||||||
@ -192,7 +187,7 @@ class CutlassFP8ScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
return 89
|
return 89
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
if not current_platform.is_cuda():
|
if not current_platform.is_cuda():
|
||||||
return (
|
return (
|
||||||
False,
|
False,
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
|||||||
|
|
||||||
from .ScaledMMLinearKernel import (
|
from .ScaledMMLinearKernel import (
|
||||||
ScaledMMLinearKernel,
|
ScaledMMLinearKernel,
|
||||||
ScaledMMLinearLayerConfig,
|
Int8ScaledMMLinearLayerConfig,
|
||||||
ScaledMMLinearQuantStrategy,
|
ScaledMMLinearQuantStrategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -32,7 +32,7 @@ def flashinfer_w8a8_scaled_mm(
|
|||||||
|
|
||||||
class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
|
class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
self, c: Int8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||||
) -> None:
|
) -> None:
|
||||||
self.quant_fp8 = QuantFP8(
|
self.quant_fp8 = QuantFP8(
|
||||||
static=c.is_static_input_scheme,
|
static=c.is_static_input_scheme,
|
||||||
@ -46,7 +46,7 @@ class FlashInferScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
return 100
|
return 100
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||||
per_tensor_weight_scales = (
|
per_tensor_weight_scales = (
|
||||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from vllm.utils.torch_utils import direct_register_custom_op
|
|||||||
|
|
||||||
from .ScaledMMLinearKernel import (
|
from .ScaledMMLinearKernel import (
|
||||||
ScaledMMLinearKernel,
|
ScaledMMLinearKernel,
|
||||||
ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
ScaledMMLinearQuantStrategy,
|
ScaledMMLinearQuantStrategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -90,7 +90,7 @@ if current_platform.is_rocm():
|
|||||||
|
|
||||||
class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
|
class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||||
) -> None:
|
) -> None:
|
||||||
self.quant_fp8 = QuantFP8(
|
self.quant_fp8 = QuantFP8(
|
||||||
static=c.is_static_input_scheme,
|
static=c.is_static_input_scheme,
|
||||||
@ -104,7 +104,7 @@ class ROCmScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
return 90
|
return 90
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
# TODO: check if this causes an issue on non-ROCM platforms
|
# TODO: check if this causes an issue on non-ROCM platforms
|
||||||
from vllm.platforms.rocm import on_mi3xx
|
from vllm.platforms.rocm import on_mi3xx
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from vllm.platforms import current_platform
|
|||||||
|
|
||||||
from .ScaledMMLinearKernel import (
|
from .ScaledMMLinearKernel import (
|
||||||
ScaledMMLinearKernel,
|
ScaledMMLinearKernel,
|
||||||
ScaledMMLinearLayerConfig,
|
FP8ScaledMMLinearLayerConfig,
|
||||||
ScaledMMLinearQuantStrategy,
|
ScaledMMLinearQuantStrategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -136,7 +136,7 @@ def torch_channelwise_w8a8_scaled_mm(
|
|||||||
|
|
||||||
class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
|
class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, c: ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
self, c: FP8ScaledMMLinearLayerConfig, layer_mapping_function: Callable
|
||||||
) -> None:
|
) -> None:
|
||||||
vllm_config = get_current_vllm_config().compilation_config
|
vllm_config = get_current_vllm_config().compilation_config
|
||||||
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
|
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
|
||||||
@ -161,7 +161,7 @@ class TorchScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
|
|
||||||
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
assert c.activation_group_shape is not None
|
assert c.activation_group_shape is not None
|
||||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||||
per_tensor_weight_scales = (
|
per_tensor_weight_scales = (
|
||||||
@ -218,7 +218,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
return 94
|
return 94
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
assert c.activation_group_shape is not None
|
assert c.activation_group_shape is not None
|
||||||
|
|
||||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||||
@ -290,7 +290,7 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
return 94
|
return 94
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
assert c.activation_group_shape is not None
|
assert c.activation_group_shape is not None
|
||||||
|
|
||||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import torch
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .cutlass import CutlassScaledMMLinearKernel
|
from .cutlass import CutlassScaledMMLinearKernel
|
||||||
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
|
||||||
|
|
||||||
|
|
||||||
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||||
@ -16,7 +16,7 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
|||||||
return 75
|
return 75
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
if current_platform.is_cpu():
|
if current_platform.is_cpu():
|
||||||
return (
|
return (
|
||||||
False,
|
False,
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
from .ScaledMMLinearKernel import ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig
|
||||||
|
|
||||||
|
|
||||||
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||||
@ -24,7 +24,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
if not current_platform.is_tpu():
|
if not current_platform.is_tpu():
|
||||||
return False, "ScaledMMXLA requires running on TPU."
|
return False, "ScaledMMXLA requires running on TPU."
|
||||||
|
|
||||||
|
|||||||
@ -7,9 +7,9 @@ import torch
|
|||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||||
ScaledMMLinearLayerConfig,
|
|
||||||
choose_scaled_mm_linear_kernel,
|
choose_scaled_mm_linear_kernel,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
|
||||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||||
from vllm.model_executor.parameter import (
|
from vllm.model_executor.parameter import (
|
||||||
BasevLLMParameter,
|
BasevLLMParameter,
|
||||||
@ -50,7 +50,7 @@ class QuarkW8A8Int8(QuarkScheme):
|
|||||||
):
|
):
|
||||||
layer.logical_widths = output_partition_sizes
|
layer.logical_widths = output_partition_sizes
|
||||||
|
|
||||||
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
|
scaled_mm_linear_kernel_config = Int8ScaledMMLinearLayerConfig(
|
||||||
is_channelwise=(self.qscheme == "per_channel"),
|
is_channelwise=(self.qscheme == "per_channel"),
|
||||||
is_static_input_scheme=(self.is_static_input_scheme is True),
|
is_static_input_scheme=(self.is_static_input_scheme is True),
|
||||||
input_symmetric=(self.input_symmetric is True),
|
input_symmetric=(self.input_symmetric is True),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user