mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-11 05:37:02 +08:00
[Perf][fp8] Use CustomOp abstraction for fp8 quant for better perf (#19830)
Signed-off-by: Luka Govedic <lgovedic@redhat.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
35514b682a
commit
31d5c1797f
98
benchmarks/kernels/bench_per_token_quant_fp8.py
Normal file
98
benchmarks/kernels/bench_per_token_quant_fp8.py
Normal file
@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
# TODO(luka): use standalone_compile utility
|
||||
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
|
||||
def inner(*args):
|
||||
torch._dynamo.mark_dynamic(args[arg_index], dim_index)
|
||||
return fn(*args)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
# Shapes vary per benchmark config and dynamic=False below forces a
# recompile per distinct shape, so raise the recompile limit far above
# the number of configs.
torch._dynamo.config.recompile_limit = 8888
# NOTE(review): custom_ops=["none"] presumably disables custom-op dispatch
# so torch.compile traces the native PyTorch path of QuantFP8 — confirm.
compilation_config = CompilationConfig(custom_ops=["none"])
# The vLLM config must be current while compiling so QuantFP8 sees the
# compilation settings above.
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
    torch_per_token_quant_fp8 = torch.compile(
        # static=False, per-token group shape (dynamic per-token quantization)
        QuantFP8(False, GroupShape.PER_TOKEN),
        fullgraph=True,
        dynamic=False,  # recompile for different shapes
    )

# First dim is explicitly dynamic to simulate vLLM usage
torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
|
||||
|
||||
|
||||
def cuda_per_token_quant_fp8(
    input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize *input* to fp8 through the vLLM CUDA custom op.

    :param input: tensor to quantize
    :return: (quantized tensor, scales) as produced by ops.scaled_fp8_quant
    """
    quantized, scales = ops.scaled_fp8_quant(input)
    return quantized, scales
|
||||
|
||||
|
||||
def calculate_diff(batch_size: int, seq_len: int):
    """Calculate difference between Triton and CUDA implementations.

    Runs both the torch-compiled and the CUDA custom-op quantization on the
    same random fp16 input and prints whether outputs and scales agree
    within tolerance.
    """
    device = torch.device("cuda")
    x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)

    torch_out, torch_scale = torch_per_token_quant_fp8(x)
    cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)

    # Compare quantized values in fp32 to avoid fp8 comparison quirks.
    outputs_match = torch.allclose(
        cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
    )
    scales_match = torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5)

    if outputs_match and scales_match:
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
|
||||
|
||||
|
||||
# Benchmark sweep: every (batch_size, seq_len) combination below is one
# benchmark configuration.
batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]

configs = list(itertools.product(batch_size_range, seq_len_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "cuda"],
        line_names=["Torch", "CUDA"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
    )
)
def benchmark_quantization(batch_size, seq_len, provider):
    """Benchmark one (batch_size, seq_len) configuration for *provider*.

    :param batch_size: batch dimension of the synthetic input
    :param seq_len: sequence length of the synthetic input
    :param provider: "torch" (compiled QuantFP8) or "cuda" (custom op)
    :return: (median, max, min) latency in microseconds
    :raises ValueError: if *provider* is not "torch" or "cuda"
    """
    dtype = torch.float16
    device = torch.device("cuda")

    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch":
        fn = lambda: torch_per_token_quant_fp8(x.clone())
    elif provider == "cuda":
        fn = lambda: cuda_per_token_quant_fp8(x.clone())
    else:
        # Fail fast with a clear message instead of an opaque NameError on
        # `fn` inside do_bench_cudagraph below.
        raise ValueError(f"Unknown provider: {provider}")

    # CUDA-graph based timing keeps CPU launch overhead out of the numbers.
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    # Convert ms -> us; triton's convention is (median, max, min).
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Sanity-check the two implementations agree once, then run the sweep.
    calculate_diff(batch_size=4, seq_len=4096)
    benchmark_quantization.run(print_data=True)
|
||||
@ -44,7 +44,9 @@ class TestModel(torch.nn.Module):
|
||||
]
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
cutlass_fp8_supported=cutlass_fp8_enabled,
|
||||
use_per_token_if_dynamic=True)
|
||||
act_quant_static=static,
|
||||
act_quant_group_shape=group_shape,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
resid = torch.sqrt(x)
|
||||
@ -91,9 +93,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
||||
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
|
||||
vllm_config.compilation_config.pass_config = \
|
||||
PassConfig(enable_fusion=True, enable_noop=True)
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
custom_ops=["+rms_norm", "+quant_fp8"],
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||
))
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
|
||||
@ -50,6 +50,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
||||
# DYNAMO_ONCE does not properly propagate shapes.
|
||||
level=CompilationLevel.DYNAMO_AS_IS,
|
||||
backend="tests.compile.test_fusion_attn.backend_unfused",
|
||||
custom_ops=["+quant_fp8"],
|
||||
)
|
||||
vllm_config = VllmConfig(compilation_config=compile_config)
|
||||
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
|
||||
@ -73,6 +74,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
||||
# DYNAMO_ONCE does not properly propagate shapes.
|
||||
level=CompilationLevel.DYNAMO_AS_IS,
|
||||
backend="tests.compile.test_fusion_attn.backend",
|
||||
custom_ops=["+quant_fp8"],
|
||||
)
|
||||
vllm_config = VllmConfig(compilation_config=compile_config)
|
||||
|
||||
|
||||
@ -4,33 +4,56 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .backend import TestBackend
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
self.w = (torch.rand(
|
||||
hidden_size,
|
||||
hidden_size).to(dtype=current_platform.fp8_dtype()).t())
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
cutlass_fp8_supported=cutlass_fp8_enabled,
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
x2 = scaled_fp8_quant(y, self.scale)
|
||||
x2 = self.fp8_linear.apply(y,
|
||||
self.w,
|
||||
self.wscale,
|
||||
input_scale=self.wscale)
|
||||
return x2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [256])
|
||||
@pytest.mark.parametrize("hidden_size", [64])
|
||||
@pytest.mark.parametrize("cutlass_fp8_enabled",
|
||||
[True, False] if CUTLASS_FP8_SUPPORTED else [False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
|
||||
reason="Only test on CUDA and ROCm")
|
||||
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
|
||||
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
|
||||
cutlass_fp8_enabled):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
@ -40,11 +63,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
|
||||
fusion_pass = ActivationQuantFusionPass(config)
|
||||
|
||||
backend = TestBackend(fusion_pass)
|
||||
model = TestModel()
|
||||
backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
|
||||
model = TestModel(hidden_size, cutlass_fp8_enabled)
|
||||
|
||||
# First dimension dynamic
|
||||
x = torch.rand(num_tokens, hidden_size)
|
||||
x = torch.rand(num_tokens, hidden_size * 2)
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
result = model(x)
|
||||
|
||||
@ -9,6 +9,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -289,7 +291,7 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
raise NotImplementedError
|
||||
|
||||
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
|
||||
group_shape: tuple[int, int]):
|
||||
group_shape: GroupShape):
|
||||
"""
|
||||
Does this attention implementation support fused output quantization.
|
||||
This is used by the AttnFusionPass to only fuse output quantization
|
||||
@ -298,7 +300,7 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
TODO(luka) merge parameters into QuantDescriptor
|
||||
:param dtype: quantized dtype
|
||||
:param static: static or dynamic quantization
|
||||
:param group_shape: quant group shape. (-1, -1) for per-tensor.
|
||||
:param group_shape: quant group shape.
|
||||
:return: is fusion supported for this type of quantization
|
||||
"""
|
||||
return False
|
||||
|
||||
@ -19,6 +19,8 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.rocm import use_rocm_custom_paged_attention
|
||||
|
||||
@ -598,10 +600,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
head_dim))
|
||||
|
||||
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
|
||||
group_shape: tuple[int, int]):
|
||||
group_shape: GroupShape):
|
||||
if self.use_triton_flash_attn:
|
||||
return dtype == current_platform.fp8_dtype(
|
||||
) and static and group_shape == (-1, -1) # per-tensor
|
||||
) and static and group_shape == GroupShape.PER_TENSOR
|
||||
|
||||
# Only supported in the Triton backend
|
||||
return False
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, ClassVar, NamedTuple, Optional
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
@ -11,6 +11,8 @@ from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fx_utils import find_getitem_maybe
|
||||
@ -33,27 +35,6 @@ RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
|
||||
# Use proxy as NamedTuple direct subclasses cannot have static members
|
||||
class _GroupShape(NamedTuple):
|
||||
row: int
|
||||
col: int
|
||||
|
||||
|
||||
class GroupShape(_GroupShape):
|
||||
"""
|
||||
This class describes the quantization group shape.
|
||||
It includes static members for common shapes (per-tensor, per-token).
|
||||
"""
|
||||
|
||||
# Aliases for common quantization group shapes
|
||||
PER_TENSOR: ClassVar['GroupShape']
|
||||
PER_TOKEN: ClassVar['GroupShape']
|
||||
|
||||
|
||||
GroupShape.PER_TENSOR = GroupShape(-1, -1)
|
||||
GroupShape.PER_TOKEN = GroupShape(1, -1)
|
||||
|
||||
|
||||
class QuantKey(NamedTuple):
|
||||
"""
|
||||
Named tuple for identifying the type of quantization.
|
||||
|
||||
@ -111,6 +111,8 @@ def _fp8_quantize(
|
||||
is provided, the output will be blocked.
|
||||
"""
|
||||
if block_shape is None:
|
||||
# TODO(luka): use QuantFP8 custom op
|
||||
# https://github.com/vllm-project/vllm/issues/20711
|
||||
A, A_scale = ops.scaled_fp8_quant(
|
||||
A, A_scale, use_per_token_if_dynamic=per_act_token)
|
||||
else:
|
||||
|
||||
@ -15,6 +15,9 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise, sparse_cutlass_supported)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
@ -24,6 +27,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
|
||||
__all__ = ["CompressedTensors24"]
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class CompressedTensors24(CompressedTensorsScheme):
|
||||
|
||||
@ -45,6 +50,12 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
and self.model_compressor.sparsity_config.format
|
||||
== CompressionFormat.sparse_24_bitmask.value)
|
||||
|
||||
if quantized and input_quant is not None and \
|
||||
self._get_quant_dtype() == current_platform.fp8_dtype():
|
||||
static = not input_quant.dynamic
|
||||
g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
self.quant_fp8 = QuantFP8(static, g_shape)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# Only cutlass 3.x kernels are implemented so far
|
||||
@ -232,9 +243,7 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
:return: The output tensor of the layer
|
||||
"""
|
||||
if self.quantized:
|
||||
scale = None
|
||||
if hasattr(layer, "input_scale"):
|
||||
scale = layer.input_scale
|
||||
scale = getattr(layer, 'input_scale', None)
|
||||
|
||||
if self.weights_dtype == torch.int8:
|
||||
ops_output = ops.scaled_int8_quant(x, scale=scale)
|
||||
@ -242,11 +251,7 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
input_scale = ops_output[1]
|
||||
else:
|
||||
assert self.weights_dtype == torch.float8_e4m3fn
|
||||
if scale is not None:
|
||||
q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
|
||||
else:
|
||||
q_input, input_scale = ops.scaled_fp8_quant(
|
||||
x, use_per_token_if_dynamic=True)
|
||||
q_input, input_scale = self.quant_fp8(x, scale=scale)
|
||||
|
||||
else:
|
||||
# Not quantized, nothing to do with the input_scales, use as is
|
||||
@ -269,7 +274,10 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
|
||||
if not self.quantized:
|
||||
return params_dtype
|
||||
return self._get_quant_dtype()
|
||||
|
||||
def _get_quant_dtype(self) -> torch.dtype:
|
||||
assert self.quantized
|
||||
assert self.weight_quant is not None
|
||||
assert self.input_quant is not None
|
||||
|
||||
|
||||
@ -9,6 +9,8 @@ from torch.nn import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
|
||||
requantize_with_max_scale)
|
||||
@ -26,7 +28,11 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
self.strategy = strategy
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR \
|
||||
if is_static_input_scheme else GroupShape.PER_TOKEN
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.is_static_input_scheme,
|
||||
act_quant_group_shape=self.act_q_group_shape)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
|
||||
@ -16,7 +16,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
GroupShape, is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
@ -37,7 +37,6 @@ class FBGEMMFp8Config(QuantizationConfig):
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = not current_platform.has_device_capability(89)
|
||||
self.fp8_linear = Fp8LinearOp()
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
@ -76,7 +75,8 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quant_config: FBGEMMFp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
def create_weights(
|
||||
|
||||
@ -29,7 +29,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
|
||||
prepare_moe_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
GroupShape, is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported,
|
||||
cutlass_fp8_supported, maybe_create_device_identity,
|
||||
@ -202,9 +202,17 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
and current_platform.is_fp8_fnuz())
|
||||
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
self.act_q_static = self.quant_config.activation_scheme == "static"
|
||||
# Use per-token quantization for better perf if dynamic and cutlass
|
||||
if not self.act_q_static and cutlass_fp8_supported():
|
||||
self.act_q_group_shape = GroupShape.PER_TOKEN
|
||||
else:
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
# Default to using per_token quantization if cutlass is supported
|
||||
use_per_token_if_dynamic=cutlass_fp8_supported())
|
||||
act_quant_static=self.act_q_static,
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
cutlass_fp8_supported=cutlass_fp8_supported())
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
|
||||
103
vllm/model_executor/layers/quantization/input_quant_fp8.py
Normal file
103
vllm/model_executor/layers/quantization/input_quant_fp8.py
Normal file
@ -0,0 +1,103 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm.
_FP8_DTYPE = current_platform.fp8_dtype()
_FP8_FINFO = torch.finfo(_FP8_DTYPE)
# fnuz (ROCm) clamps to +/-224.0 (see note above); otherwise use the
# dtype's true finite range from finfo.
_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max
_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
# Lower bound on dynamic scales; prevents blow-up when the input is
# (near-)zero and the computed scale would approach zero.
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
|
||||
|
||||
|
||||
@CustomOp.register("quant_fp8")
class QuantFP8(CustomOp):
    """
    Quantize input tensor to per-tensor or per-token FP8.
    This CustomOp supports both static and dynamic quantization.

    Both forward paths return a ``(quantized, scale)`` pair; the CUDA path
    delegates to ``ops.scaled_fp8_quant`` and the native path reimplements
    the same quantization in pure PyTorch (usable under torch.compile).
    """

    def __init__(self,
                 static: bool,
                 group_shape: GroupShape,
                 num_token_padding: Optional[int] = None):
        """

        :param static: static or dynamic quantization
        :param group_shape: quantization group shape (PER_TOKEN or PER_TENSOR)
        :param num_token_padding: Pad the token dimension of output to this size
        """
        super().__init__()
        self.num_token_padding = num_token_padding
        # Only per-tensor and per-token group shapes are implemented.
        assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
        # Static quantization requires a single pre-computed (per-tensor) scale.
        assert not static or group_shape == GroupShape.PER_TENSOR, \
            "Only per-tensor scales supported for static quantization."
        self.static = static
        self.group_shape = group_shape
        self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN

    def forward_cuda(
        self,
        x: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        scale_ub: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """CUDA path: delegate to the fused scaled_fp8_quant custom op.

        :param x: input tensor to quantize
        :param scale: pre-computed scale; required iff static quantization
        :param scale_ub: optional upper bound on the dynamic per-token scale
        :return: (quantized tensor, scale)
        """
        # A scale is provided exactly when quantization is static.
        assert (scale is not None) == self.static
        # scale_ub only makes sense for dynamic per-token quantization,
        # and must be a single element.
        assert scale_ub is None or (not self.static and self.group_shape
                                    == GroupShape.PER_TOKEN
                                    and scale_ub.numel() == 1)

        return ops.scaled_fp8_quant(
            x,
            scale,
            num_token_padding=self.num_token_padding,
            scale_ub=scale_ub,
            use_per_token_if_dynamic=self.use_per_token_if_dynamic)

    def forward_native(
        self,
        x: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        scale_ub: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pure-PyTorch path mirroring forward_cuda's contract.

        :param x: input tensor to quantize
        :param scale: pre-computed scale; required iff static quantization
        :param scale_ub: optional upper bound on the dynamic per-token scale
        :return: (quantized tensor, scale)
        """
        # Same preconditions as forward_cuda.
        assert (scale is not None) == self.static
        assert scale_ub is None or (not self.static and self.group_shape
                                    == GroupShape.PER_TOKEN
                                    and scale_ub.numel() == 1)

        if scale is None:
            # Dynamic quantization: derive the scale from the input's
            # absolute maximum (per row for per-token, global otherwise).
            if self.group_shape == GroupShape.PER_TOKEN:
                x_max, _ = x.abs().max(dim=-1)
                x_max = x_max.unsqueeze(-1).to(torch.float32)
                if scale_ub is not None:
                    x_max = x_max.clamp(max=scale_ub)
            else:
                x_max = x.abs().max().unsqueeze(-1).to(torch.float32)

            scale = x_max / _FP8_MAX
            # Avoid near-zero scales (division blow-up on zero inputs).
            scale = scale.clamp(min=_FP8_MIN_SCALING_FACTOR)

        # Even for dynamic per-token scales,
        # reciprocal performs slightly better than division
        out = x.to(torch.float32) * scale.reciprocal()
        out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)

        # This currently generates an extra Triton kernel in compilation.
        # Fortunately, we don't use padding if compiling.
        # TODO(luka): benchmark torch._scaled_mm to hopefully remove padding
        # in general.
        if self.num_token_padding is not None:
            padding = max(self.num_token_padding - out.size(0), 0)
            out = F.pad(out, (0, 0, 0, padding), "constant", 0.0)

        return out, scale
|
||||
@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear, is_fp4_marlin_supported,
|
||||
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
GroupShape, is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, requantize_with_max_scale)
|
||||
from vllm.model_executor.parameter import (ModelWeightParameter,
|
||||
@ -102,7 +102,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.fp8_linear = Fp8LinearOp()
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
|
||||
@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
|
||||
Fp8KVCacheMethod,
|
||||
Fp8LinearMethod)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
GroupShape, is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp)
|
||||
from vllm.platforms import current_platform
|
||||
@ -95,8 +95,10 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
|
||||
super().__init__(quant_config=quant_config)
|
||||
# Force weight quantization
|
||||
self.quant_config.is_checkpoint_fp8_serialized = False
|
||||
self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=False,
|
||||
use_per_token_if_dynamic=True)
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=False,
|
||||
cutlass_fp8_supported=False,
|
||||
act_quant_group_shape=GroupShape.PER_TOKEN)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data,
|
||||
|
||||
@ -7,6 +7,8 @@ import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
@ -28,10 +30,14 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
self.is_static_input_scheme = not cast(
|
||||
bool, input_config.get("is_dynamic"))
|
||||
self.input_qscheme = cast(str, input_config.get("qscheme"))
|
||||
self.use_per_token_if_dynamic = (not self.is_static_input_scheme \
|
||||
and self.input_qscheme == "per_channel")
|
||||
|
||||
per_token = (not self.is_static_input_scheme
|
||||
and self.input_qscheme == "per_channel")
|
||||
self.act_quant_group_shape = GroupShape.PER_TOKEN \
|
||||
if per_token else GroupShape.PER_TENSOR
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
use_per_token_if_dynamic=self.use_per_token_if_dynamic)
|
||||
act_quant_static=self.is_static_input_scheme,
|
||||
act_quant_group_shape=self.act_quant_group_shape)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
@classmethod
|
||||
@ -44,7 +50,7 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
# tensor scales (thus N scales being passed to the kernel),
|
||||
# requantize so we can always run per tensor
|
||||
if self.weight_qscheme == "per_tensor":
|
||||
if current_platform.is_rocm():
|
||||
if current_platform.is_fp8_fnuz():
|
||||
input_scale = getattr(layer, 'input_scale', None)
|
||||
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.weight,
|
||||
@ -82,7 +88,7 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
requires_grad=False)
|
||||
else:
|
||||
weight_scale = layer.weight_scale.data
|
||||
if self.use_per_token_if_dynamic:
|
||||
if self.act_quant_group_shape == GroupShape.PER_TOKEN:
|
||||
weight_scale = weight_scale.view(-1, 1)
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
from collections.abc import Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Optional
|
||||
from typing import ClassVar, NamedTuple, Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
@ -12,13 +12,30 @@ from vllm.model_executor.layers.quantization.qqq import (
|
||||
MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
# Use proxy as NamedTuple direct subclasses cannot have static members
|
||||
class _GroupShape(NamedTuple):
|
||||
row: int
|
||||
col: int
|
||||
|
||||
|
||||
class GroupShape(_GroupShape):
|
||||
"""
|
||||
This class describes the quantization group shape.
|
||||
It includes static members for common shapes (per-tensor, per-token).
|
||||
"""
|
||||
|
||||
# Aliases for common quantization group shapes
|
||||
PER_TENSOR: ClassVar['GroupShape']
|
||||
PER_TOKEN: ClassVar['GroupShape']
|
||||
|
||||
|
||||
GroupShape.PER_TENSOR = GroupShape(-1, -1)
|
||||
GroupShape.PER_TOKEN = GroupShape(1, -1)
|
||||
|
||||
|
||||
# Normalize the group_shape to the full extent for any dims that are -1
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
|
||||
int]):
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
|
||||
# -1 means full extent
|
||||
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
|
||||
group_shape[1] if group_shape[1] > 0 else x.shape[-1])
|
||||
@ -58,7 +75,7 @@ def group_broadcast(t, shape):
|
||||
# (i.e. per-token-per-group)
|
||||
def scaled_quantize(
|
||||
x: torch.Tensor,
|
||||
group_shape: tuple[int, int],
|
||||
group_shape: GroupShape,
|
||||
quant_dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
group_shape = _normalize_quant_group_shape(x, group_shape)
|
||||
@ -99,7 +116,7 @@ def scaled_quantize(
|
||||
def scaled_dequantize(
|
||||
x_q: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
group_shape: Optional[tuple[int, int]] = None,
|
||||
group_shape: Optional[GroupShape] = None,
|
||||
out_dtype: torch.dtype = torch.float32,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if group_shape is not None:
|
||||
@ -332,6 +349,10 @@ def quantize_weights(w: torch.Tensor,
|
||||
)
|
||||
|
||||
|
||||
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
|
||||
def gptq_quantize_weights(w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
|
||||
@ -8,6 +8,9 @@ import torch
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.config import CompilationLevel, get_current_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
@ -271,20 +274,21 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
|
||||
def dispatch_w8a8_scaled_mm(
|
||||
cutlass_fp8_supported: bool, per_tensor_weights: bool,
|
||||
per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool]
|
||||
) -> Callable[..., torch.Tensor]:
|
||||
per_tensor_activations: bool) -> Callable[..., torch.Tensor]:
|
||||
|
||||
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
|
||||
if cutlass_fp8_supported:
|
||||
return cutlass_w8a8_scaled_mm
|
||||
if per_tensor_weights and per_tensor_activations:
|
||||
if current_platform.is_rocm():
|
||||
return rocm_per_tensor_w8a8_scaled_mm
|
||||
return torch_per_tensor_w8a8_scaled_mm
|
||||
# torch.scaled_mm supports per tensor weights + activations only
|
||||
# so fallback to naive if per channel or per token
|
||||
if (use_per_token_if_dynamic and not per_tensor_weights
|
||||
and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM):
|
||||
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
|
||||
if not per_tensor_weights and not per_tensor_activations \
|
||||
and USE_ROWWISE_TORCH_SCALED_MM:
|
||||
return torch_per_token_w8a8_scaled_mm
|
||||
# Normally, torch.scaled_mm supports per tensor weights + activations only
|
||||
# so fallback to naive if per channel or per token
|
||||
return torch_channelwise_w8a8_scaled_mm
|
||||
|
||||
|
||||
@ -299,11 +303,11 @@ class Fp8LinearOp:
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
act_quant_static: bool,
|
||||
cutlass_fp8_supported: bool = cutlass_fp8_supported(),
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
|
||||
pad_output: Optional[bool] = None):
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported
|
||||
self.use_per_token_if_dynamic = use_per_token_if_dynamic
|
||||
|
||||
# Note: we pad the input because torch._scaled_mm is more performant
|
||||
# for matrices with batch dimension > 16.
|
||||
@ -312,9 +316,16 @@ class Fp8LinearOp:
|
||||
# as it breaks with dynamic shapes.
|
||||
if pad_output is None:
|
||||
config = get_current_vllm_config().compilation_config
|
||||
pad_output = config.level < CompilationLevel.PIECEWISE
|
||||
self.output_padding = 17 if (
|
||||
pad_output and not current_platform.is_rocm()) else None
|
||||
pad_output = config.level < CompilationLevel.PIECEWISE and \
|
||||
not cutlass_fp8_supported and \
|
||||
not current_platform.is_rocm()
|
||||
|
||||
self.output_padding = 17 if pad_output else None
|
||||
self.act_quant_static = act_quant_static
|
||||
self.act_quant_group_shape = act_quant_group_shape
|
||||
self.quant_fp8 = QuantFP8(static=act_quant_static,
|
||||
group_shape=act_quant_group_shape,
|
||||
num_token_padding=self.output_padding)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@ -325,8 +336,6 @@ class Fp8LinearOp:
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
input_scale_ub: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
# TODO(luka) remove this parameter in favor of __init__
|
||||
use_per_token_if_dynamic: Optional[bool] = None
|
||||
) -> torch.Tensor:
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
@ -336,40 +345,27 @@ class Fp8LinearOp:
|
||||
input_2d = input.view(-1, input.shape[-1])
|
||||
output_shape = [*input.shape[:-1], weight.shape[1]]
|
||||
|
||||
# TODO(luka) this is here because currently MLA only decides this
|
||||
# during the forward method instead of in __init__.
|
||||
if use_per_token_if_dynamic is None:
|
||||
use_per_token_if_dynamic = self.use_per_token_if_dynamic
|
||||
|
||||
if out_dtype is None:
|
||||
out_dtype = input.dtype
|
||||
|
||||
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
|
||||
if self.cutlass_fp8_supported:
|
||||
assert input.dtype != current_platform.fp8_dtype(
|
||||
), "FP8 input to cutlass is not currently implemented"
|
||||
qinput, x_scale = ops.scaled_fp8_quant(
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
if input.dtype != current_platform.fp8_dtype():
|
||||
qinput, x_scale = self.quant_fp8(
|
||||
input_2d,
|
||||
input_scale,
|
||||
scale_ub=input_scale_ub,
|
||||
use_per_token_if_dynamic=use_per_token_if_dynamic)
|
||||
input_scale_ub,
|
||||
)
|
||||
else:
|
||||
if input.dtype != current_platform.fp8_dtype():
|
||||
# Maybe apply padding to output, see comment in __init__
|
||||
qinput, x_scale = ops.scaled_fp8_quant(
|
||||
input_2d,
|
||||
input_scale,
|
||||
num_token_padding=self.output_padding,
|
||||
use_per_token_if_dynamic=use_per_token_if_dynamic)
|
||||
else:
|
||||
qinput, x_scale = input_2d, input_scale
|
||||
qinput, x_scale = input_2d, input_scale
|
||||
|
||||
per_tensor_weights = (weight_scale.numel() == 1)
|
||||
per_tensor_activations = (x_scale.numel() == 1)
|
||||
|
||||
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
|
||||
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
|
||||
self.cutlass_fp8_supported, per_tensor_weights,
|
||||
per_tensor_activations, use_per_token_if_dynamic)
|
||||
per_tensor_activations)
|
||||
|
||||
return w8a8_scaled_mm_func(qinput=qinput,
|
||||
weight=weight,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user