mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 03:37:03 +08:00
prefer QuantKey over ScaledMMLinearQuantStrategy
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
a8010c7b1c
commit
f5e6cd9695
@ -23,10 +23,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -37,6 +36,8 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class TestSiluMul(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, hidden_size: int = 128):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
@ -46,9 +47,8 @@ class TestSiluMul(torch.nn.Module):
|
||||
if TEST_FP8:
|
||||
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=True,
|
||||
act_q_group_shape=GroupShape.PER_TENSOR,
|
||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
@ -74,6 +74,8 @@ class TestSiluMul(torch.nn.Module):
|
||||
|
||||
|
||||
class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -89,9 +91,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
|
||||
if TEST_FP8:
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=True,
|
||||
act_q_group_shape=GroupShape.PER_TENSOR,
|
||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
@ -21,9 +21,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
@ -59,10 +56,16 @@ class TestModel(torch.nn.Module):
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
|
||||
|
||||
quant_scale = ScaleDesc(torch.float32, static, group_shape)
|
||||
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
|
||||
act_quant_scale = ScaleDesc(torch.float32, static, group_shape)
|
||||
w_quant_scale = ScaleDesc(torch.float32, True, group_shape)
|
||||
self.activation_quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=act_quant_scale, symmetric=True
|
||||
)
|
||||
self.weight_quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=w_quant_scale, symmetric=True
|
||||
)
|
||||
|
||||
if static:
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
else:
|
||||
@ -74,9 +77,8 @@ class TestModel(torch.nn.Module):
|
||||
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=static,
|
||||
act_q_group_shape=group_shape,
|
||||
weight_quant_strategy=weight_quant_strategy,
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
@ -110,13 +112,13 @@ class TestModel(torch.nn.Module):
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [
|
||||
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
|
||||
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
|
||||
FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)],
|
||||
FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)],
|
||||
]
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return (
|
||||
[QUANT_OPS[self.quant_key]]
|
||||
[QUANT_OPS[self.activation_quant_key]]
|
||||
if self.enable_quant_fp8_custom_op
|
||||
else [torch.ops.aten.reciprocal]
|
||||
)
|
||||
|
||||
@ -29,11 +29,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
GroupShape,
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
@ -80,6 +77,8 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -95,9 +94,8 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
]
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=True,
|
||||
act_q_group_shape=GroupShape.PER_TENSOR,
|
||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
@ -31,9 +31,6 @@ from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
@ -177,18 +174,13 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if self.quant_key.scale.group_shape.is_per_tensor():
|
||||
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
|
||||
else:
|
||||
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=self.quant_key.scale.static,
|
||||
act_q_group_shape=self.quant_key.scale.group_shape,
|
||||
weight_quant_strategy=weight_quant_strategy,
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
hidden_size = self.num_qo_heads * self.head_size
|
||||
self.w = kwargs.get(
|
||||
"w",
|
||||
|
||||
@ -30,10 +30,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
|
||||
@ -101,6 +100,8 @@ class TestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestQuantModel(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -114,9 +115,8 @@ class TestQuantModel(torch.nn.Module):
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=True,
|
||||
act_q_group_shape=GroupShape.PER_TENSOR,
|
||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
@ -27,11 +27,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
@ -52,6 +48,8 @@ def is_nvfp4_supported():
|
||||
|
||||
|
||||
class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
@ -62,13 +60,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=True,
|
||||
act_q_group_shape=GroupShape.PER_TENSOR,
|
||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||
activation_quant_key=self.quant_key,
|
||||
weight_quant_key=self.quant_key,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
||||
|
||||
|
||||
@ -14,9 +14,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
QUANT_STRATEGY_MAP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
check_aiter_fp8_linear_support,
|
||||
@ -29,7 +26,11 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
process_fp8_weight_tensor_strategy,
|
||||
validate_fp8_block_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_block_fp8_supported,
|
||||
maybe_create_device_identity,
|
||||
@ -48,6 +49,12 @@ strategy_to_parameter_type = {
|
||||
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
|
||||
}
|
||||
|
||||
STATIC_QUANT = True
|
||||
DYNAMIC_QUANT = False
|
||||
quant_keys = {
|
||||
STATIC_QUANT: (kFp8StaticTensorSym, kFp8StaticTensorSym),
|
||||
DYNAMIC_QUANT: (kFp8DynamicTokenSym, kFp8StaticTensorSym),
|
||||
}
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -57,22 +64,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
self.strategy = weight_quant.strategy
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
|
||||
self.weight_block_size = self.weight_quant.block_structure
|
||||
if self.weight_block_size is not None:
|
||||
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
|
||||
else:
|
||||
self.act_q_group_shape = (
|
||||
GroupShape.PER_TENSOR
|
||||
if is_static_input_scheme
|
||||
else GroupShape.PER_TOKEN
|
||||
)
|
||||
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
|
||||
|
||||
if self.weight_block_size is not None:
|
||||
assert not self.is_static_input_scheme
|
||||
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
@ -80,12 +78,11 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
else:
|
||||
weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy]
|
||||
self.fp8_linear_kernel = init_fp8_linear_kernel(
|
||||
act_q_static=self.is_static_input_scheme,
|
||||
act_q_group_shape=self.act_q_group_shape,
|
||||
weight_quant_strategy=weight_quant_strategy,
|
||||
out_dtype=self.out_dtype,
|
||||
activation_quant_key, weight_quant_key = quant_keys[is_static_input_scheme]
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=activation_quant_key,
|
||||
weight_quant_key=weight_quant_key,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
@ -204,4 +201,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return self.fp8_linear_kernel.apply_weights(layer, x, bias)
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
@ -21,16 +21,13 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
prepare_fp8_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
is_layer_skipped,
|
||||
kFp8DynamicTokenSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
maybe_create_device_identity,
|
||||
@ -97,12 +94,10 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: FBGEMMFp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
self.fp8_linear_kernel = init_fp8_linear_kernel(
|
||||
act_q_static=False,
|
||||
act_q_group_shape=GroupShape.PER_TOKEN,
|
||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL,
|
||||
out_dtype=self.out_dtype,
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8DynamicTokenSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
@ -194,4 +189,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return self.fp8_linear_kernel.apply_weights(layer, x, bias)
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
@ -45,10 +45,6 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
@ -82,6 +78,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
is_layer_skipped,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8StaticTensorSym,
|
||||
kFp8StaticTokenSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d,
|
||||
@ -380,8 +379,10 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
# Use per-token quantization for better perf if dynamic and cutlass
|
||||
if not self.act_q_static and cutlass_fp8_supported():
|
||||
self.act_q_group_shape = GroupShape.PER_TOKEN
|
||||
self.activation_quant_key = kFp8StaticTokenSym
|
||||
else:
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR
|
||||
self.activation_quant_key = kFp8DynamicTensorSym
|
||||
|
||||
if self.block_quant:
|
||||
assert not self.act_q_static
|
||||
@ -393,11 +394,10 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
else:
|
||||
self.fp8_linear_kernel = init_fp8_linear_kernel(
|
||||
act_q_static=self.act_q_static,
|
||||
act_q_group_shape=self.act_q_group_shape,
|
||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||
out_dtype=self.out_dtype,
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=kFp8StaticTensorSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
@ -684,7 +684,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return self.fp8_linear_kernel.apply_weights(layer, x, bias)
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
@ -4,43 +4,32 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
|
||||
|
||||
class ScaledMMLinearQuantStrategy(Enum):
|
||||
TENSOR = "tensor"
|
||||
CHANNEL = "channel"
|
||||
BLOCK = "block"
|
||||
|
||||
|
||||
QUANT_STRATEGY_MAP = {
|
||||
QuantizationStrategy.TENSOR: ScaledMMLinearQuantStrategy.TENSOR,
|
||||
QuantizationStrategy.CHANNEL: ScaledMMLinearQuantStrategy.CHANNEL,
|
||||
}
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScaledMMLinearLayerConfig:
|
||||
is_static_input_scheme: bool
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
is_static_input_scheme: bool
|
||||
is_channelwise: bool
|
||||
input_symmetric: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
weight_quant_strategy: ScaledMMLinearQuantStrategy
|
||||
activation_group_shape: GroupShape
|
||||
weight_quant_key: QuantKey
|
||||
activation_quant_key: QuantKey
|
||||
out_dtype: torch.dtype | None
|
||||
|
||||
|
||||
@ -103,9 +92,10 @@ class FP8ScaledMMLinearKernel(
|
||||
def __init__(
|
||||
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
|
||||
) -> None:
|
||||
act_scale_descriptor = c.activation_quant_key.scale
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=c.is_static_input_scheme,
|
||||
group_shape=c.activation_group_shape,
|
||||
static=act_scale_descriptor.static,
|
||||
group_shape=act_scale_descriptor.group_shape,
|
||||
num_token_padding=self.get_ouput_padding(),
|
||||
)
|
||||
super().__init__(c, layer_param_names)
|
||||
|
||||
@ -32,7 +32,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearKernel,
|
||||
ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
||||
TritonScaledMMLinearKernel,
|
||||
@ -40,7 +39,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
|
||||
XLAScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
from vllm.platforms import PlatformEnum, current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -137,16 +136,14 @@ def choose_scaled_mm_linear_kernel(
|
||||
|
||||
|
||||
def init_fp8_linear_kernel(
|
||||
act_q_static: bool,
|
||||
act_q_group_shape: GroupShape,
|
||||
weight_quant_strategy: ScaledMMLinearQuantStrategy,
|
||||
activation_quant_key: QuantKey,
|
||||
weight_quant_key: QuantKey,
|
||||
out_dtype: torch.dtype,
|
||||
module_name: str,
|
||||
) -> FP8ScaledMMLinearKernel:
|
||||
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
|
||||
is_static_input_scheme=act_q_static,
|
||||
weight_quant_strategy=weight_quant_strategy,
|
||||
activation_group_shape=act_q_group_shape,
|
||||
weight_quant_key=weight_quant_key,
|
||||
activation_quant_key=activation_quant_key,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
|
||||
@ -9,7 +9,6 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from .utils import apply_weights_fp8
|
||||
|
||||
@ -39,10 +38,10 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||
per_tensor_weight_scales = (
|
||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
return (
|
||||
|
||||
@ -10,7 +10,6 @@ from vllm.platforms import current_platform
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from .utils import apply_weights_fp8
|
||||
|
||||
@ -143,10 +142,10 @@ class TorchScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||
per_tensor_weight_scales = (
|
||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
||||
return (
|
||||
@ -183,10 +182,10 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||
per_tensor_weight_scales = (
|
||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if per_tensor_activation_scales or per_tensor_weight_scales:
|
||||
return (
|
||||
@ -237,10 +236,10 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||
per_tensor_weight_scales = (
|
||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if per_tensor_activation_scales and per_tensor_weight_scales:
|
||||
return (
|
||||
|
||||
@ -11,7 +11,6 @@ from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from .utils import apply_weights_fp8
|
||||
|
||||
@ -72,18 +71,18 @@ def rocm_per_tensor_float_w8a8_scaled_mm(
|
||||
bias: torch.Tensor,
|
||||
output_shape: list[int],
|
||||
) -> torch.Tensor:
|
||||
output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
|
||||
output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||
A, B, out_dtype, As, Bs, bias
|
||||
)
|
||||
return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
|
||||
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
|
||||
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
|
||||
)
|
||||
# if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
|
||||
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
|
||||
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
|
||||
)
|
||||
|
||||
|
||||
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
@ -95,10 +94,10 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
# TODO: check if this causes an issue on non-ROCM platforms
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
|
||||
per_tensor_activation_scales = c.activation_group_shape.is_per_tensor()
|
||||
per_tensor_weight_scales = (
|
||||
c.weight_quant_strategy == ScaledMMLinearQuantStrategy.TENSOR
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if not current_platform.is_rocm():
|
||||
return (
|
||||
|
||||
@ -40,9 +40,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
|
||||
@ -68,9 +65,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
prepare_moe_fp4_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
cutlass_fp4_supported,
|
||||
is_layer_skipped,
|
||||
kFp8StaticTensorSym,
|
||||
swizzle_blockscale,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
@ -260,10 +257,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: ModelOptFp8Config) -> None:
|
||||
self.quant_config = quant_config
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=True,
|
||||
act_q_group_shape=GroupShape.PER_TENSOR,
|
||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||
out_dtype=None,
|
||||
activation_quant_key=kFp8StaticTensorSym,
|
||||
weight_quant_key=kFp8StaticTensorSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
|
||||
@ -19,12 +19,9 @@ from vllm.model_executor.layers.quantization.fp8 import (
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
is_layer_skipped,
|
||||
kFp8DynamicTokenSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -103,11 +100,10 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
|
||||
)
|
||||
super().__init__(quant_config=quant_config)
|
||||
# Force weight quantization
|
||||
self.fp8_linear_kernel = init_fp8_linear_kernel(
|
||||
act_q_static=False,
|
||||
act_q_group_shape=GroupShape.PER_TOKEN,
|
||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL,
|
||||
out_dtype=self.out_dtype,
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8DynamicTokenSym,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
@ -135,4 +131,4 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.fp8_linear_kernel.apply_weights(layer, x, bias)
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
@ -11,11 +11,13 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kFp8StaticTokenSym,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
requantize_with_max_scale,
|
||||
@ -31,11 +33,6 @@ __all__ = ["QuarkW8A8Fp8"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
QUANT_STRATEGY_MAP = {
|
||||
"per_tensor": ScaledMMLinearQuantStrategy.TENSOR,
|
||||
"per_channel": ScaledMMLinearQuantStrategy.CHANNEL,
|
||||
}
|
||||
|
||||
|
||||
class QuarkW8A8Fp8(QuarkScheme):
|
||||
def __init__(
|
||||
@ -48,11 +45,16 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic"))
|
||||
self.input_qscheme = cast(str, input_config.get("qscheme"))
|
||||
|
||||
per_token = (
|
||||
per_token_activation = (
|
||||
not self.is_static_input_scheme and self.input_qscheme == "per_channel"
|
||||
)
|
||||
self.act_quant_group_shape = (
|
||||
GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
|
||||
per_token_weight = self.weight_qscheme == "per_channel"
|
||||
|
||||
self.activation_quant_key = (
|
||||
kFp8DynamicTokenSym if per_token_activation else kFp8StaticTensorSym
|
||||
)
|
||||
self.weight_quant_key = (
|
||||
kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym
|
||||
)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
@ -103,7 +105,7 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
layer.input_scale = Parameter(input_scale, requires_grad=False)
|
||||
else:
|
||||
weight_scale = layer.weight_scale.data
|
||||
if self.act_quant_group_shape == GroupShape.PER_TOKEN:
|
||||
if self.activation_quant_key.scale.group_shape == GroupShape.PER_TOKEN:
|
||||
weight_scale = weight_scale.view(-1, 1)
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
@ -174,12 +176,10 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
|
||||
layer.register_parameter("input_scale_ub", None)
|
||||
|
||||
weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme]
|
||||
self.fp8_linear_kernel = init_fp8_linear_kernel(
|
||||
act_q_static=self.is_static_input_scheme,
|
||||
act_q_group_shape=self.act_quant_group_shape,
|
||||
weight_quant_strategy=weight_quant_strategy,
|
||||
out_dtype=self.out_dtype,
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
activation_quant_key=self.activation_quant_key,
|
||||
weight_quant_key=self.weight_quant_key,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
@ -189,4 +189,4 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.fp8_linear_kernel.apply_weights(layer, x, bias)
|
||||
return self.fp8_linear.apply_weights(layer, x, bias)
|
||||
|
||||
@ -109,6 +109,9 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True)
|
||||
kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR)
|
||||
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)
|
||||
|
||||
kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN)
|
||||
kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True)
|
||||
|
||||
kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
|
||||
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user