[Perf][fp8] Use CustomOp abstraction for fp8 quant for better perf (#19830)

Signed-off-by: Luka Govedic <lgovedic@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Luka Govedič 2025-07-11 00:56:28 -04:00 committed by GitHub
parent 35514b682a
commit 31d5c1797f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 368 additions and 104 deletions

View File

@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Callable
import torch
from vllm import _custom_ops as ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
# TODO(luka): use standalone_compile utility
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
def inner(*args):
torch._dynamo.mark_dynamic(args[arg_index], dim_index)
return fn(*args)
return inner
torch._dynamo.config.recompile_limit = 8888
compilation_config = CompilationConfig(custom_ops=["none"])
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
torch_per_token_quant_fp8 = torch.compile(
QuantFP8(False, GroupShape.PER_TOKEN),
fullgraph=True,
dynamic=False, # recompile for different shapes
)
# First dim is explicitly dynamic to simulate vLLM usage
torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
def cuda_per_token_quant_fp8(
input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return ops.scaled_fp8_quant(input)
def calculate_diff(batch_size: int, seq_len: int):
"""Calculate difference between Triton and CUDA implementations."""
device = torch.device("cuda")
x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
torch_out, torch_scale = torch_per_token_quant_fp8(x)
cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
if torch.allclose(
cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
configs = list(itertools.product(batch_size_range, seq_len_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
x_vals=configs,
line_arg="provider",
line_vals=["torch", "cuda"],
line_names=["Torch", "CUDA"],
styles=[("blue", "-"), ("green", "-")],
ylabel="us",
plot_name="per-token-dynamic-quant-fp8-performance",
args={},
)
)
def benchmark_quantization(batch_size, seq_len, provider):
dtype = torch.float16
device = torch.device("cuda")
x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
quantiles = [0.5, 0.2, 0.8]
if provider == "torch":
fn = lambda: torch_per_token_quant_fp8(x.clone())
elif provider == "cuda":
fn = lambda: cuda_per_token_quant_fp8(x.clone())
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
calculate_diff(batch_size=4, seq_len=4096)
benchmark_quantization.run(print_data=True)

View File

@ -44,7 +44,9 @@ class TestModel(torch.nn.Module):
]
self.fp8_linear = Fp8LinearOp(
cutlass_fp8_supported=cutlass_fp8_enabled,
use_per_token_if_dynamic=True)
act_quant_static=static,
act_quant_group_shape=group_shape,
)
def forward(self, x):
resid = torch.sqrt(x)
@ -91,9 +93,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
vllm_config.compilation_config.pass_config = \
PassConfig(enable_fusion=True, enable_noop=True)
level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm", "+quant_fp8"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
))
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)

View File

@ -50,6 +50,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
# DYNAMO_ONCE does not properly propagate shapes.
level=CompilationLevel.DYNAMO_AS_IS,
backend="tests.compile.test_fusion_attn.backend_unfused",
custom_ops=["+quant_fp8"],
)
vllm_config = VllmConfig(compilation_config=compile_config)
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
@ -73,6 +74,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
# DYNAMO_ONCE does not properly propagate shapes.
level=CompilationLevel.DYNAMO_AS_IS,
backend="tests.compile.test_fusion_attn.backend",
custom_ops=["+quant_fp8"],
)
vllm_config = VllmConfig(compilation_config=compile_config)

View File

@ -4,33 +4,56 @@ import pytest
import torch
import vllm.envs as envs
from vllm._custom_ops import scaled_fp8_quant
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
from vllm.platforms import current_platform
from .backend import TestBackend
class TestModel(torch.nn.Module):
def __init__(self, *args, **kwargs):
def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
**kwargs):
super().__init__(*args, **kwargs)
self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32)
self.w = (torch.rand(
hidden_size,
hidden_size).to(dtype=current_platform.fp8_dtype()).t())
self.fp8_linear = Fp8LinearOp(
cutlass_fp8_supported=cutlass_fp8_enabled,
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
def forward(self, x):
y = self.silu_and_mul(x)
x2 = scaled_fp8_quant(y, self.scale)
x2 = self.fp8_linear.apply(y,
self.w,
self.wscale,
input_scale=self.wscale)
return x2
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("cutlass_fp8_enabled",
[True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
cutlass_fp8_enabled):
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)
@ -40,11 +63,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(fusion_pass)
model = TestModel()
backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
model = TestModel(hidden_size, cutlass_fp8_enabled)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
x = torch.rand(num_tokens, hidden_size * 2)
torch._dynamo.mark_dynamic(x, 0)
result = model(x)

View File

@ -9,6 +9,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.multimodal import MultiModalPlaceholderMap
if TYPE_CHECKING:
@ -289,7 +291,7 @@ class AttentionImpl(ABC, Generic[T]):
raise NotImplementedError
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: tuple[int, int]):
group_shape: GroupShape):
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
@ -298,7 +300,7 @@ class AttentionImpl(ABC, Generic[T]):
TODO(luka) merge parameters into QuantDescriptor
:param dtype: quantized dtype
:param static: static or dynamic quantization
:param group_shape: quant group shape. (-1, -1) for per-tensor.
:param group_shape: quant group shape.
:return: is fusion supported for this type of quantization
"""
return False

View File

@ -19,6 +19,8 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention
@ -598,10 +600,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim))
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: tuple[int, int]):
group_shape: GroupShape):
if self.use_triton_flash_attn:
return dtype == current_platform.fp8_dtype(
) and static and group_shape == (-1, -1) # per-tensor
) and static and group_shape == GroupShape.PER_TENSOR
# Only supported in the Triton backend
return False

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, ClassVar, NamedTuple, Optional
from typing import Callable, NamedTuple, Optional
import torch
import torch._inductor.pattern_matcher as pm
@ -11,6 +11,8 @@ from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
from .fx_utils import find_getitem_maybe
@ -33,27 +35,6 @@ RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
row: int
col: int
class GroupShape(_GroupShape):
"""
This class describes the quantization group shape.
It includes static members for common shapes (per-tensor, per-token).
"""
# Aliases for common quantization group shapes
PER_TENSOR: ClassVar['GroupShape']
PER_TOKEN: ClassVar['GroupShape']
GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
class QuantKey(NamedTuple):
"""
Named tuple for identifying the type of quantization.

View File

@ -111,6 +111,8 @@ def _fp8_quantize(
is provided, the output will be blocked.
"""
if block_shape is None:
# TODO(luka): use QuantFP8 custom op
# https://github.com/vllm-project/vllm/issues/20711
A, A_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_act_token)
else:

View File

@ -15,6 +15,9 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, sparse_cutlass_supported)
from vllm.model_executor.parameter import (BasevLLMParameter,
@ -24,6 +27,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
__all__ = ["CompressedTensors24"]
from vllm.platforms import current_platform
class CompressedTensors24(CompressedTensorsScheme):
@ -45,6 +50,12 @@ class CompressedTensors24(CompressedTensorsScheme):
and self.model_compressor.sparsity_config.format
== CompressionFormat.sparse_24_bitmask.value)
if quantized and input_quant is not None and \
self._get_quant_dtype() == current_platform.fp8_dtype():
static = not input_quant.dynamic
g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
self.quant_fp8 = QuantFP8(static, g_shape)
@classmethod
def get_min_capability(cls) -> int:
# Only cutlass 3.x kernels are implemented so far
@ -232,9 +243,7 @@ class CompressedTensors24(CompressedTensorsScheme):
:return: The output tensor of the layer
"""
if self.quantized:
scale = None
if hasattr(layer, "input_scale"):
scale = layer.input_scale
scale = getattr(layer, 'input_scale', None)
if self.weights_dtype == torch.int8:
ops_output = ops.scaled_int8_quant(x, scale=scale)
@ -242,11 +251,7 @@ class CompressedTensors24(CompressedTensorsScheme):
input_scale = ops_output[1]
else:
assert self.weights_dtype == torch.float8_e4m3fn
if scale is not None:
q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
else:
q_input, input_scale = ops.scaled_fp8_quant(
x, use_per_token_if_dynamic=True)
q_input, input_scale = self.quant_fp8(x, scale=scale)
else:
# Not quantized, nothing to do with the input_scales, use as is
@ -269,7 +274,10 @@ class CompressedTensors24(CompressedTensorsScheme):
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
if not self.quantized:
return params_dtype
return self._get_quant_dtype()
def _get_quant_dtype(self) -> torch.dtype:
assert self.quantized
assert self.weight_quant is not None
assert self.input_quant is not None

View File

@ -9,6 +9,8 @@ from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale)
@ -26,7 +28,11 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.strategy = strategy
self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
self.act_q_group_shape = GroupShape.PER_TENSOR \
if is_static_input_scheme else GroupShape.PER_TOKEN
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape)
@classmethod
def get_min_capability(cls) -> int:

View File

@ -16,7 +16,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
GroupShape, is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
@ -37,7 +37,6 @@ class FBGEMMFp8Config(QuantizationConfig):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = not current_platform.has_device_capability(89)
self.fp8_linear = Fp8LinearOp()
@classmethod
def get_name(cls) -> QuantizationMethods:
@ -76,7 +75,8 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
self.fp8_linear = Fp8LinearOp(
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)
self.out_dtype = torch.get_default_dtype()
def create_weights(

View File

@ -29,7 +29,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
prepare_moe_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
GroupShape, is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
cutlass_fp8_supported, maybe_create_device_identity,
@ -202,9 +202,17 @@ class Fp8LinearMethod(LinearMethodBase):
and current_platform.is_fp8_fnuz())
self.block_quant = self.quant_config.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static"
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
self.fp8_linear = Fp8LinearOp(
# Default to using per_token quantization if cutlass is supported
use_per_token_if_dynamic=cutlass_fp8_supported())
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape,
cutlass_fp8_supported=cutlass_fp8_supported())
def create_weights(
self,

View File

@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm.
_FP8_DTYPE = current_platform.fp8_dtype()
_FP8_FINFO = torch.finfo(_FP8_DTYPE)
_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max
_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
@CustomOp.register("quant_fp8")
class QuantFP8(CustomOp):
"""
Quantize input tensor to per-tensor or per-token FP8.
This CustomOp supports both static and dynamic quantization.
"""
def __init__(self,
static: bool,
group_shape: GroupShape,
num_token_padding: Optional[int] = None):
"""
:param static: static or dynamic quantization
:param group_shape: quantization group shape (PER_TOKEN or PER_TENSOR)
:param num_token_padding: Pad the token dimension of output to this size
"""
super().__init__()
self.num_token_padding = num_token_padding
assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
assert not static or group_shape == GroupShape.PER_TENSOR, \
"Only per-tensor scales supported for static quantization."
self.static = static
self.group_shape = group_shape
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
def forward_cuda(
self,
x: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_ub: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert (scale is not None) == self.static
assert scale_ub is None or (not self.static and self.group_shape
== GroupShape.PER_TOKEN
and scale_ub.numel() == 1)
return ops.scaled_fp8_quant(
x,
scale,
num_token_padding=self.num_token_padding,
scale_ub=scale_ub,
use_per_token_if_dynamic=self.use_per_token_if_dynamic)
def forward_native(
self,
x: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_ub: Optional[torch.Tensor] = None,
):
assert (scale is not None) == self.static
assert scale_ub is None or (not self.static and self.group_shape
== GroupShape.PER_TOKEN
and scale_ub.numel() == 1)
if scale is None:
if self.group_shape == GroupShape.PER_TOKEN:
x_max, _ = x.abs().max(dim=-1)
x_max = x_max.unsqueeze(-1).to(torch.float32)
if scale_ub is not None:
x_max = x_max.clamp(max=scale_ub)
else:
x_max = x.abs().max().unsqueeze(-1).to(torch.float32)
scale = x_max / _FP8_MAX
scale = scale.clamp(min=_FP8_MIN_SCALING_FACTOR)
# Even for dynamic per-token scales,
# reciprocal performs slightly better than division
out = x.to(torch.float32) * scale.reciprocal()
out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
# This currently generates an extra Triton kernel in compilation.
# Fortunately, we don't use padding if compiling.
# TODO(luka): benchmark torch._scaled_mm to hopefully remove padding
# in general.
if self.num_token_padding is not None:
padding = max(self.num_token_padding - out.size(0), 0)
out = F.pad(out, (0, 0, 0, padding), "constant", 0.0)
return out, scale

View File

@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear, is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
GroupShape, is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
@ -102,7 +102,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptFp8Config):
self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp()
self.fp8_linear = Fp8LinearOp(
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
def create_weights(
self,

View File

@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
Fp8KVCacheMethod,
Fp8LinearMethod)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
GroupShape, is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
from vllm.platforms import current_platform
@ -95,8 +95,10 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
super().__init__(quant_config=quant_config)
# Force weight quantization
self.quant_config.is_checkpoint_fp8_serialized = False
self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=False,
use_per_token_if_dynamic=True)
self.fp8_linear = Fp8LinearOp(
act_quant_static=False,
cutlass_fp8_supported=False,
act_quant_group_shape=GroupShape.PER_TOKEN)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data,

View File

@ -7,6 +7,8 @@ import torch
from torch.nn import Parameter
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
@ -28,10 +30,14 @@ class QuarkW8A8Fp8(QuarkScheme):
self.is_static_input_scheme = not cast(
bool, input_config.get("is_dynamic"))
self.input_qscheme = cast(str, input_config.get("qscheme"))
self.use_per_token_if_dynamic = (not self.is_static_input_scheme \
and self.input_qscheme == "per_channel")
per_token = (not self.is_static_input_scheme
and self.input_qscheme == "per_channel")
self.act_quant_group_shape = GroupShape.PER_TOKEN \
if per_token else GroupShape.PER_TENSOR
self.fp8_linear = Fp8LinearOp(
use_per_token_if_dynamic=self.use_per_token_if_dynamic)
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_quant_group_shape)
self.out_dtype = torch.get_default_dtype()
@classmethod
@ -44,7 +50,7 @@ class QuarkW8A8Fp8(QuarkScheme):
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor
if self.weight_qscheme == "per_tensor":
if current_platform.is_rocm():
if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None)
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
@ -82,7 +88,7 @@ class QuarkW8A8Fp8(QuarkScheme):
requires_grad=False)
else:
weight_scale = layer.weight_scale.data
if self.use_per_token_if_dynamic:
if self.act_quant_group_shape == GroupShape.PER_TOKEN:
weight_scale = weight_scale.view(-1, 1)
layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter

View File

@ -3,7 +3,7 @@
"""This file is used for /tests and /benchmarks"""
from collections.abc import Mapping
from types import MappingProxyType
from typing import Optional
from typing import ClassVar, NamedTuple, Optional
import numpy
import torch
@ -12,13 +12,30 @@ from vllm.model_executor.layers.quantization.qqq import (
MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.scalar_type import ScalarType, scalar_types
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
row: int
col: int
class GroupShape(_GroupShape):
"""
This class describes the quantization group shape.
It includes static members for common shapes (per-tensor, per-token).
"""
# Aliases for common quantization group shapes
PER_TENSOR: ClassVar['GroupShape']
PER_TOKEN: ClassVar['GroupShape']
GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
int]):
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
# -1 means full extent
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
group_shape[1] if group_shape[1] > 0 else x.shape[-1])
@ -58,7 +75,7 @@ def group_broadcast(t, shape):
# (i.e. per-token-per-group)
def scaled_quantize(
x: torch.Tensor,
group_shape: tuple[int, int],
group_shape: GroupShape,
quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
group_shape = _normalize_quant_group_shape(x, group_shape)
@ -99,7 +116,7 @@ def scaled_quantize(
def scaled_dequantize(
x_q: torch.Tensor,
x_s: torch.Tensor,
group_shape: Optional[tuple[int, int]] = None,
group_shape: Optional[GroupShape] = None,
out_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
if group_shape is not None:
@ -332,6 +349,10 @@ def quantize_weights(w: torch.Tensor,
)
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
def gptq_quantize_weights(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,

View File

@ -8,6 +8,9 @@ import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
# Input scaling factors are no longer optional in _scaled_mm starting
@ -271,20 +274,21 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
def dispatch_w8a8_scaled_mm(
cutlass_fp8_supported: bool, per_tensor_weights: bool,
per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool]
) -> Callable[..., torch.Tensor]:
per_tensor_activations: bool) -> Callable[..., torch.Tensor]:
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if cutlass_fp8_supported:
return cutlass_w8a8_scaled_mm
if per_tensor_weights and per_tensor_activations:
if current_platform.is_rocm():
return rocm_per_tensor_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
if (use_per_token_if_dynamic and not per_tensor_weights
and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM):
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
if not per_tensor_weights and not per_tensor_activations \
and USE_ROWWISE_TORCH_SCALED_MM:
return torch_per_token_w8a8_scaled_mm
# Normally, torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
return torch_channelwise_w8a8_scaled_mm
@ -299,11 +303,11 @@ class Fp8LinearOp:
"""
def __init__(self,
act_quant_static: bool,
cutlass_fp8_supported: bool = cutlass_fp8_supported(),
use_per_token_if_dynamic: bool = False,
act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
pad_output: Optional[bool] = None):
self.cutlass_fp8_supported = cutlass_fp8_supported
self.use_per_token_if_dynamic = use_per_token_if_dynamic
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
@ -312,9 +316,16 @@ class Fp8LinearOp:
# as it breaks with dynamic shapes.
if pad_output is None:
config = get_current_vllm_config().compilation_config
pad_output = config.level < CompilationLevel.PIECEWISE
self.output_padding = 17 if (
pad_output and not current_platform.is_rocm()) else None
pad_output = config.level < CompilationLevel.PIECEWISE and \
not cutlass_fp8_supported and \
not current_platform.is_rocm()
self.output_padding = 17 if pad_output else None
self.act_quant_static = act_quant_static
self.act_quant_group_shape = act_quant_group_shape
self.quant_fp8 = QuantFP8(static=act_quant_static,
group_shape=act_quant_group_shape,
num_token_padding=self.output_padding)
def apply(
self,
@ -325,8 +336,6 @@ class Fp8LinearOp:
input_scale: Optional[torch.Tensor] = None,
input_scale_ub: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
# TODO(luka) remove this parameter in favor of __init__
use_per_token_if_dynamic: Optional[bool] = None
) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
@ -336,40 +345,27 @@ class Fp8LinearOp:
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[1]]
# TODO(luka) this is here because currently MLA only decides this
# during the forward method instead of in __init__.
if use_per_token_if_dynamic is None:
use_per_token_if_dynamic = self.use_per_token_if_dynamic
if out_dtype is None:
out_dtype = input.dtype
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if self.cutlass_fp8_supported:
assert input.dtype != current_platform.fp8_dtype(
), "FP8 input to cutlass is not currently implemented"
qinput, x_scale = ops.scaled_fp8_quant(
# If input not quantized
# TODO(luka) remove this path if not used anymore
if input.dtype != current_platform.fp8_dtype():
qinput, x_scale = self.quant_fp8(
input_2d,
input_scale,
scale_ub=input_scale_ub,
use_per_token_if_dynamic=use_per_token_if_dynamic)
input_scale_ub,
)
else:
if input.dtype != current_platform.fp8_dtype():
# Maybe apply padding to output, see comment in __init__
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=self.output_padding,
use_per_token_if_dynamic=use_per_token_if_dynamic)
else:
qinput, x_scale = input_2d, input_scale
qinput, x_scale = input_2d, input_scale
per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1)
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
self.cutlass_fp8_supported, per_tensor_weights,
per_tensor_activations, use_per_token_if_dynamic)
per_tensor_activations)
return w8a8_scaled_mm_func(qinput=qinput,
weight=weight,