Revert "[Performance] Move apply_w8a8_block_fp8_linear to an op class… (#25607)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Authored by Tyler Michael Smith on 2025-09-25 04:05:21 -04:00; committed via GitHub.
Parent: af4ee63e0e
Commit: 1260180c67
14 changed files with 205 additions and 346 deletions

File (path not shown):

@@ -17,7 +17,7 @@ from weight_shapes import WEIGHT_SHAPES
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    w8a8_triton_block_scaled_mm,
+    w8a8_block_fp8_matmul,
 )
 from vllm.utils import FlexibleArgumentParser, cdiv
@@ -158,7 +158,7 @@ def bench_fp8(
         "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
             a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
         ),
-        "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm(
+        "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
             a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
         ),
         "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(

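The benchmark hunks above only swap the Triton entry point's name back to w8a8_block_fp8_matmul; the call pattern is unchanged. The snippet below is a hedged sketch of that call, not taken from the benchmark: shapes, dtypes, and the CUDA device are assumptions.

```python
# Sketch only: exercises the restored w8a8_block_fp8_matmul name with made-up
# shapes. Requires a CUDA device; per_token_group_quant_fp8 returns the FP8
# activations plus one scale per 128-wide group.
import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8, w8a8_block_fp8_matmul)

M, N, K = 16, 4096, 4096
block = [128, 128]
a = torch.randn(M, K, device="cuda")
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
b_scale = torch.rand(N // block[0], K // block[1], device="cuda")

q_a, a_scale = per_token_group_quant_fp8(a, block[1])
out = w8a8_block_fp8_matmul(q_a, b, a_scale, b_scale, block, torch.bfloat16)
print(out.shape)  # (M, N)
```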
File (path not shown):

@@ -9,7 +9,7 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
-    w8a8_triton_block_scaled_mm,
+    w8a8_block_fp8_matmul,
 )
 from vllm.triton_utils import triton
 from vllm.utils.deep_gemm import (
@@ -63,7 +63,7 @@ def benchmark_shape(m: int,
     # === vLLM Triton Implementation ===
     def vllm_triton_gemm():
-        return w8a8_triton_block_scaled_mm(A_vllm,
+        return w8a8_block_fp8_matmul(A_vllm,
                                      B_vllm,
                                      A_scale_vllm,
                                      B_scale_vllm,

File (path not shown):

@@ -11,7 +11,7 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
                                        native_w8a8_block_matmul)
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
+    cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import (fp8_gemm_nt,
@@ -91,8 +91,7 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
     ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
                                        out_dtype)
-    out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
-                                      out_dtype)
+    out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

     rel_diff = (torch.mean(
         torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /

File (path not shown):

@@ -20,11 +20,9 @@ from vllm.platforms import current_platform
     (8, 513, 64),  # Non-divisible (native only)
 ])
 @pytest.mark.parametrize("seed", [42])
-@pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
 def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
-                                      group_size: int, seed: int,
-                                      use_ue8m0: bool) -> None:
+                                      group_size: int, seed: int) -> None:
     """Test QuantFP8 group quantization with various configurations.

     Tests both CUDA and native implementations, column-major scales,
@@ -40,8 +38,7 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     group_shape = GroupShape(1, group_size)
     quant_op = QuantFP8(static=False,
                         group_shape=group_shape,
-                        column_major_scales=False,
-                        use_ue8m0=use_ue8m0)
+                        column_major_scales=False)

     # 1. Test native implementation (always available)
     x_quant_native, scales_native = quant_op.forward_native(x.clone())
@@ -51,15 +48,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     # 2. Test column-major scales configuration
     quant_op_col = QuantFP8(static=False,
                             group_shape=group_shape,
-                            column_major_scales=True,
-                            use_ue8m0=use_ue8m0)
+                            column_major_scales=True)
     _, scales_col = quant_op_col.forward_native(x.clone())
-    assert scales_col.shape == (batch_size, expected_num_groups)
-    assert scales_col.stride(0) == 1
-    assert scales_col.stride(1) == batch_size
-
-    # Test column-major scales consistency
-    assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
+    assert scales_col.shape == (expected_num_groups, batch_size)

     # 3. Test CUDA implementation (only for divisible dimensions)
     if is_divisible:
@@ -77,9 +68,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
 @pytest.mark.parametrize("seed", [42])
-@pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
-def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
+def test_quantfp8_group_multidimensional(seed: int) -> None:
     current_platform.seed_everything(seed)

     group_size = 64
@@ -92,8 +82,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
     group_shape = GroupShape(1, group_size)
     quant_op = QuantFP8(static=False,
                         group_shape=group_shape,
-                        column_major_scales=False,
-                        use_ue8m0=use_ue8m0)
+                        column_major_scales=False)

     x_quant, scales = quant_op.forward_native(x_3d.clone())
     assert x_quant.shape == x_3d.shape
@@ -102,8 +91,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
     # Test column_major_scales with multi-dim
     quant_op_col = QuantFP8(static=False,
                             group_shape=group_shape,
-                            column_major_scales=True,
-                            use_ue8m0=use_ue8m0)
+                            column_major_scales=True)
     _, scales_col = quant_op_col.forward_native(x_3d.clone())
     assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)
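The assertions being reverted capture a layout difference worth spelling out: with column_major_scales=True the reverted PR kept the logical (batch, num_groups) shape backed by column-major strides, while the code restored here returns the transposed, contiguous tensor. A small hedged sketch of the two layouts in plain torch (no vLLM imports):

```python
import torch

batch_size, num_groups = 8, 4
scales = torch.rand(batch_size, num_groups)

# Layout asserted by the deleted test lines: same logical shape,
# column-major strides (stride(0) == 1, stride(1) == batch_size).
col_major = scales.t().contiguous().t()
assert col_major.shape == (batch_size, num_groups)
assert col_major.stride() == (1, batch_size)

# Layout restored by this revert: transposed and made contiguous,
# i.e. shape (num_groups, batch_size), matching the updated assertion.
restored = scales.transpose(-2, -1).contiguous()
assert restored.shape == (num_groups, batch_size)
```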

File (path not shown):

@@ -17,6 +17,8 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
 from vllm.model_executor.layers.layernorm import (RMSNorm,
                                                   dispatch_rocm_rmsnorm_func,
                                                   fused_add_rms_norm, rms_norm)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform

 RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
@@ -109,6 +111,34 @@ def test_enabled_ops_invalid(env: str):
             RMSNorm(1024).enabled()


+@pytest.mark.skipif(
+    not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
+    reason="AITER is a feature exclusive for ROCm and FP8_FNUZ")
+@pytest.mark.parametrize("use_cutlass", [True, False])
+@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
+@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
+def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
+                                  use_rocm_aiter_gemm_w8a8_blockscale: str,
+                                  monkeypatch):
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
+                       use_rocm_aiter_gemm_w8a8_blockscale)
+
+    use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
+        int(use_rocm_aiter_gemm_w8a8_blockscale)))
+
+    block_scale_func = dispatch_w8a8_blockscale_func(
+        use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)
+
+    if use_cutlass:
+        assert block_scale_func == cutlass_scaled_mm
+    elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
+            use_rocm_aiter_gemm_w8a8_blockscale):
+        assert block_scale_func == (
+            torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
+    else:
+        assert block_scale_func == w8a8_block_fp8_matmul
+
+
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
     monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
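The rule this restored test exercises is the dispatch_w8a8_blockscale_func reinstated in the fp8_utils diff further down: cutlass first, then the ROCm AITER op when both environment toggles are on, otherwise the Triton kernel. A minimal sketch of that dispatch with placeholder callables (the stubs below are not vLLM code):

```python
from typing import Callable


def cutlass_scaled_mm(*args): ...                 # placeholder kernel
def rocm_aiter_gemm_w8a8_blockscale(*args): ...   # placeholder kernel
def w8a8_block_fp8_matmul(*args): ...             # placeholder kernel


def dispatch_w8a8_blockscale_func(
        use_cutlass: bool, use_aiter_and_is_supported: bool) -> Callable:
    if use_cutlass:
        return cutlass_scaled_mm
    if use_aiter_and_is_supported:
        return rocm_aiter_gemm_w8a8_blockscale
    return w8a8_block_fp8_matmul


assert dispatch_w8a8_blockscale_func(True, True) is cutlass_scaled_mm
assert dispatch_w8a8_blockscale_func(False, True) is rocm_aiter_gemm_w8a8_blockscale
assert dispatch_w8a8_blockscale_func(False, False) is w8a8_block_fp8_matmul
```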

File (path not shown):

@@ -18,9 +18,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
     CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
-from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    W8A8BlockFp8LinearOp)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     cutlass_fp4_supported)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -745,35 +742,3 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt,
     perplexity = llm.generate_prompt_perplexity([prompt])[0]
     print(perplexity)
     assert perplexity <= exp_perplexity
-
-
-def test_compressed_tensors_fp8_block_enabled(vllm_runner):
-    model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK"
-    with vllm_runner(model_path) as llm:
-        fp8_dtype = current_platform.fp8_dtype()
-
-        def check_model(model):
-            layer = model.model.layers[0]
-
-            qkv_proj = layer.self_attn.qkv_proj
-            assert isinstance(qkv_proj.quant_method,
-                              CompressedTensorsLinearMethod)
-            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
-            assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear,
-                              W8A8BlockFp8LinearOp)
-
-            assert qkv_proj.weight.dtype is fp8_dtype
-            assert qkv_proj.weight_scale.dtype is torch.float32
-            assert len(qkv_proj.weight.shape) == 2
-            assert len(qkv_proj.weight_scale.shape) == 2
-
-            input_quant_op = \
-                qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
-            assert isinstance(input_quant_op, QuantFP8)
-            assert input_quant_op._forward_method == input_quant_op.forward_cuda
-
-        llm.apply_model(check_model)
-
-        output = llm.generate_greedy("Hello my name is", max_tokens=20)
-        assert output

File (path not shown):

@@ -545,23 +545,6 @@ class VllmConfig:
             # local attention.
             self.scheduler_config.disable_hybrid_kv_cache_manager = True

-        def has_blocked_weights():
-            if self.quant_config is not None:
-                if hasattr(self.quant_config, "weight_block_size"):
-                    return self.quant_config.weight_block_size is not None
-                elif hasattr(self.quant_config, "has_blocked_weights"):
-                    return self.quant_config.has_blocked_weights()
-            return False
-
-        # Enable quant_fp8 CUDA ops (TODO disable in follow up)
-        # On H100 the CUDA kernel is faster than
-        # native implementation
-        # https://github.com/vllm-project/vllm/issues/25094
-        if has_blocked_weights():
-            custom_ops = self.compilation_config.custom_ops
-            if "none" not in custom_ops and "-quant_fp8" not in custom_ops:
-                custom_ops.append("+quant_fp8")
-
     def update_sizes_for_sequence_parallelism(self,
                                               possible_sizes: list) -> list:
         # remove the sizes that not multiple of tp_size when
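For context, the hook deleted above appended "+quant_fp8" to the compilation custom_ops list for block-quantized checkpoints unless the op was globally or explicitly disabled. A hedged sketch of that toggle logic on a plain list (custom_ops stands in for CompilationConfig.custom_ops; the helper name is invented for this sketch):

```python
def maybe_enable_quant_fp8(custom_ops: list[str],
                           has_blocked_weights: bool) -> list[str]:
    # Mirrors the reverted-out branch: only append when blocked weights are
    # present and the op was not disabled via "none" or "-quant_fp8".
    if has_blocked_weights:
        if "none" not in custom_ops and "-quant_fp8" not in custom_ops:
            custom_ops.append("+quant_fp8")
    return custom_ops


assert maybe_enable_quant_fp8([], True) == ["+quant_fp8"]
assert maybe_enable_quant_fp8(["-quant_fp8"], True) == ["-quant_fp8"]
assert maybe_enable_quant_fp8(["none"], True) == ["none"]
assert maybe_enable_quant_fp8([], False) == []
```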

File (path not shown):

@@ -644,14 +644,6 @@ class CompressedTensorsConfig(QuantizationConfig):
         # If no matches, return None
         return None

-    def has_blocked_weights(self) -> bool:
-        for scheme in self.target_scheme_map.values():
-            weight_quant = scheme.get("weights")
-            if (weight_quant is not None
-                    and weight_quant.strategy == QuantizationStrategy.BLOCK):
-                return True
-        return False
-
     @staticmethod
     def supports_cutlass_24(
         weight_quant: Optional[QuantizationArgs],

File (path not shown):

@@ -11,7 +11,7 @@ from torch.nn import Parameter
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
+    apply_fp8_block_linear, check_aiter_fp8_linear_support,
     create_fp8_input_scale, create_fp8_scale_parameter,
     create_fp8_weight_parameter, maybe_post_process_fp8_weight_block,
     process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy,
@@ -41,30 +41,16 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         self.strategy = weight_quant.strategy
         self.out_dtype = torch.get_default_dtype()
         self.is_static_input_scheme = is_static_input_scheme
+        self.act_q_group_shape = GroupShape.PER_TENSOR \
+            if is_static_input_scheme else GroupShape.PER_TOKEN
+        self.fp8_linear = Fp8LinearOp(
+            act_quant_static=self.is_static_input_scheme,
+            act_quant_group_shape=self.act_q_group_shape)

         self.weight_block_size = self.weight_quant.block_structure
-        if self.weight_block_size is not None:
-            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
-        else:
-            self.act_q_group_shape = GroupShape.PER_TENSOR \
-                if is_static_input_scheme else GroupShape.PER_TOKEN
-
         self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
         self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()

-        if self.weight_block_size is not None:
-            assert not self.is_static_input_scheme
-            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
-                weight_group_shape=GroupShape(*self.weight_block_size),
-                act_quant_group_shape=self.act_q_group_shape,
-                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
-                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
-            )
-        else:
-            self.fp8_linear = Fp8LinearOp(
-                act_quant_static=self.is_static_input_scheme,
-                act_quant_group_shape=self.act_q_group_shape)
-
     @classmethod
     def get_min_capability(cls) -> int:
         # lovelace and up
@@ -155,14 +141,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
                 x: torch.Tensor,
                 bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        if self.weight_block_size is not None:
-            return self.w8a8_block_fp8_linear.apply(
+        if layer.weight_block_size is not None:
+            return apply_fp8_block_linear(
+                layer,
                 input=x,
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                input_scale=layer.input_scale,
                 bias=bias,
-            )
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
+                use_aiter_and_is_supported=self.use_aiter_and_is_supported)

         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,
View File

@ -42,7 +42,7 @@ def prepare_block_fp8_matmul_inputs(
return M, N, K, C return M, N, K, C
def w8a8_deepgemm_block_scaled_mm( def w8a8_block_fp8_matmul_deepgemm(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, As: torch.Tensor,
@ -58,7 +58,7 @@ def w8a8_deepgemm_block_scaled_mm(
return C return C
def w8a8_deepgemm_block_scaled_mm_fake( def w8a8_block_fp8_matmul_deepgemm_fake(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, As: torch.Tensor,
@ -72,7 +72,7 @@ def w8a8_deepgemm_block_scaled_mm_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="w8a8_deepgemm_block_scaled_mm", op_name="w8a8_block_fp8_matmul_deepgemm",
op_func=w8a8_deepgemm_block_scaled_mm, op_func=w8a8_block_fp8_matmul_deepgemm,
fake_impl=w8a8_deepgemm_block_scaled_mm_fake, fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
) )
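The rename above only touches the op's registered name; the registration pattern itself is unchanged. As a hedged illustration of that pattern, the toy op below pairs a real implementation with a shape-only fake_impl so the op can be traced without running the kernel. The op and its name are invented for this sketch; direct_register_custom_op is the vLLM helper used in the hunk above.

```python
import torch
from vllm.utils import direct_register_custom_op


def toy_block_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Real implementation: (M, K) x (N, K)^T -> (M, N).
    return A @ B.T


def toy_block_matmul_fake(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Fake implementation: only describes the output metadata, so graph
    # tracing does not have to run the real kernel.
    return torch.empty((A.shape[0], B.shape[0]), dtype=A.dtype, device=A.device)


direct_register_custom_op(
    op_name="toy_block_matmul",
    op_func=toy_block_matmul,
    mutates_args=[],
    fake_impl=toy_block_matmul_fake,
)

# Once registered, the op is reachable as torch.ops.vllm.toy_block_matmul,
# which is how the renamed w8a8_block_fp8_matmul_deepgemm op is invoked
# elsewhere in this commit.
```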

File (path not shown):

@@ -33,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
     select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
+    apply_fp8_block_linear, check_aiter_fp8_linear_support,
     create_fp8_input_scale, create_fp8_scale_parameter,
     create_fp8_weight_parameter, expert_weight_is_col_major,
     maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
@@ -242,28 +242,15 @@ class Fp8LinearMethod(LinearMethodBase):
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant = self.weight_block_size is not None
         self.act_q_static = self.quant_config.activation_scheme == "static"
-        if self.weight_block_size:
-            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
-        else:
-            # Use per-token quantization for better perf if dynamic and cutlass
-            if not self.act_q_static and cutlass_fp8_supported():
-                self.act_q_group_shape = GroupShape.PER_TOKEN
-            else:
-                self.act_q_group_shape = GroupShape.PER_TENSOR
-
-        if self.block_quant:
-            assert not self.act_q_static
-            assert self.weight_block_size is not None
-            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
-                weight_group_shape=GroupShape(*self.weight_block_size),
-                act_quant_group_shape=self.act_q_group_shape,
-                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
-                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
-            )
-        else:
-            self.fp8_linear = Fp8LinearOp(
-                act_quant_static=self.act_q_static,
-                act_quant_group_shape=self.act_q_group_shape)
+        # Use per-token quantization for better perf if dynamic and cutlass
+        if not self.act_q_static and cutlass_fp8_supported():
+            self.act_q_group_shape = GroupShape.PER_TOKEN
+        else:
+            self.act_q_group_shape = GroupShape.PER_TENSOR
+
+        self.fp8_linear = Fp8LinearOp(
+            act_quant_static=self.act_q_static,
+            act_quant_group_shape=self.act_q_group_shape)

     def create_weights(
         self,
@@ -412,15 +399,12 @@
                                bias=bias)

         if self.block_quant:
-            assert self.weight_block_size is not None
-            return self.w8a8_block_fp8_linear.apply(
+            return apply_fp8_block_linear(
+                layer,
                 input=x,
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                input_scale=layer.input_scale,
                 bias=bias,
-            )
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
+                use_aiter_and_is_supported=self.use_aiter_and_is_supported)

         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,

File (path not shown):

@@ -27,14 +27,11 @@ class QuantFP8(CustomOp):
     This CustomOp supports both static and dynamic quantization.
     """

-    def __init__(
-        self,
-        static: bool,
-        group_shape: GroupShape,
-        num_token_padding: Optional[int] = None,
-        column_major_scales: bool = False,
-        use_ue8m0: Optional[bool] = None,  # for Torch compile
-    ):
+    def __init__(self,
+                 static: bool,
+                 group_shape: GroupShape,
+                 num_token_padding: Optional[int] = None,
+                 column_major_scales: bool = False):
         """
         :param static: static or dynamic quantization
         :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
@@ -49,7 +46,6 @@ class QuantFP8(CustomOp):
         self.group_shape = group_shape
         self.num_token_padding = num_token_padding
         self.column_major_scales = column_major_scales
-        self.use_ue8m0 = use_ue8m0

         self.is_group_quant = group_shape.is_per_group()
         if self.is_group_quant:
@@ -74,8 +70,7 @@ class QuantFP8(CustomOp):
                 x,
                 group_size=self.group_size,
                 column_major_scales=self.column_major_scales,
-                dtype=_FP8_DTYPE,
-                use_ue8m0=self.use_ue8m0)
+                dtype=_FP8_DTYPE)

         assert (scale is not None) == self.static
         assert scale_ub is None or (not self.static and self.group_shape
@@ -142,10 +137,7 @@ class QuantFP8(CustomOp):
         x_grouped = x.view(-1, num_groups, self.group_size)
         absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
-        scales_raw = absmax / _FP8_MAX
-        if self.use_ue8m0:
-            scales_raw = torch.exp2(torch.ceil(torch.log2(scales_raw)))
-
-        scales = (scales_raw).clamp(min=_FP8_MIN_SCALING_FACTOR)
+        scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
         x_scaled = x_grouped / scales
         x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
@@ -159,6 +151,6 @@ class QuantFP8(CustomOp):
         scales = scales.reshape(orig_shape[:-1] + (num_groups, ))

         if self.column_major_scales:
-            scales = scales.transpose(-2, -1).contiguous().transpose(-1, -2)
+            scales = scales.transpose(-2, -1).contiguous()

         return x_quant, scales
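The use_ue8m0 path removed here rounded each group scale up to a power of two (UE8M0) before clamping; without it the scale is simply absmax / FP8_MAX. A hedged sketch of the two variants, using the same expressions as the deleted lines (the clamp floor below is an assumed stand-in for _FP8_MIN_SCALING_FACTOR):

```python
import torch

_FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)  # assumed value


def group_scales(x_grouped: torch.Tensor, use_ue8m0: bool) -> torch.Tensor:
    absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
    scales_raw = absmax / _FP8_MAX
    if use_ue8m0:
        # The branch this revert removes: round scales up to a power of two.
        scales_raw = torch.exp2(torch.ceil(torch.log2(scales_raw)))
    return scales_raw.clamp(min=_FP8_MIN_SCALING_FACTOR)


x = torch.randn(4, 2, 64)  # (tokens, num_groups, group_size)
# UE8M0 scales are always >= the plain scales they are rounded up from.
assert torch.all(group_scales(x, True) >= group_scales(x, False))
```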

File (path not shown):

@@ -13,9 +13,8 @@ import torch
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    GroupShape, group_broadcast)
+    group_broadcast)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.model_executor.parameter import (BlockQuantScaleParameter,
@@ -25,7 +24,6 @@ from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
 from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
-                                  is_deep_gemm_supported,
                                   should_use_deepgemm_for_fp8_linear)

 logger = init_logger(__name__)
@@ -37,8 +35,6 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
     return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz


-# We need to pass in the is_hopper flag as argument because the function
-# current_platform.is_device_capability() is not supported by Torch compiler.
 def cutlass_scaled_mm(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -46,17 +42,15 @@ def cutlass_scaled_mm(
     Bs: torch.Tensor,
     block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
-    is_hopper: Optional[bool] = None,
 ) -> torch.Tensor:
-    if is_hopper is None:
-        is_hopper = current_platform.is_device_capability(90)
     return ops.cutlass_scaled_mm(
         A,
         B.T,
         out_dtype=output_dtype,
         scale_a=As,
         # SM90 block FP8 requires row-major scale_b, which we do ahead of time
-        scale_b=Bs if block_size is not None and is_hopper else Bs.T)
+        scale_b=Bs if block_size is not None
+        and current_platform.is_device_capability(90) else Bs.T)


 def rocm_aiter_gemm_w8a8_blockscale_impl(
@@ -102,190 +96,122 @@ if current_platform.is_rocm():
     aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)


-# TODO we should be able to change the type of block_size to GroupShape
-# after we resolve GroupShape compilation issue
-# https://github.com/vllm-project/vllm/issues/25270
-def _w8a8_triton_block_scaled_mm_func(
-    qx: torch.Tensor,
-    weight: torch.Tensor,
-    x_scale: torch.Tensor,
-    weight_scale: torch.Tensor,
-    block_size: list[int],
-    output_dtype: torch.dtype,
-) -> torch.Tensor:
-    return w8a8_triton_block_scaled_mm(qx, weight, x_scale, weight_scale,
-                                       block_size, output_dtype)
-
-
-def _w8a8_triton_block_scaled_mm_fake(
-    qx: torch.Tensor,
-    weight: torch.Tensor,
-    x_scale: torch.Tensor,
-    weight_scale: torch.Tensor,
-    block_size: list[int],
-    output_dtype: torch.dtype,
-) -> torch.Tensor:
-    return torch.empty((qx.size(0), weight.size(0)),
-                       dtype=output_dtype,
-                       device=qx.device)
-
-
-# Note: the check can be removed when CPU torch > 2.7
-if not current_platform.is_cpu():
-    direct_register_custom_op(
-        "w8a8_triton_block_scaled_mm_func",
-        _w8a8_triton_block_scaled_mm_func,
-        fake_impl=_w8a8_triton_block_scaled_mm_fake,
-        dispatch_key="CUDA",
-    )
+def dispatch_w8a8_blockscale_func(
+    use_cutlass: bool, use_aiter_and_is_supported: bool
+) -> Callable[[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        list[int],
+        torch.dtype,
+], torch.Tensor]:
+    if use_cutlass:
+        return cutlass_scaled_mm
+    if (use_aiter_and_is_supported):
+        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
+    return w8a8_block_fp8_matmul


 # TODO fix ROCm->Triton custom path:
 # https://github.com/vllm-project/vllm/issues/14397
-class W8A8BlockFp8LinearOp:
-    """
-    This class executes a Blocked FP8 linear layer using cutlass if supported
-    and torch.scaled_mm otherwise.
-    """
-
-    def __init__(
-        self,
-        weight_group_shape: GroupShape,
-        act_quant_group_shape: GroupShape,
-        cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
-        use_aiter_and_is_supported: bool = False,
-    ):
-        self.weight_group_shape = weight_group_shape
-        self.act_quant_group_shape = act_quant_group_shape
-        self.is_deep_gemm_supported = is_deep_gemm_supported()
-        self.is_hopper = current_platform.is_device_capability(90)
-
-        # Get the correct blockscale mul and input quant operations.
-        # We can't use _dispatch_w8a8_blockscale_op to figure out if we want
-        # to use deepgemm because we don't know the shape of weights (and
-        # whether deepgemm supports it) at the init time.
-        self.w8a8_blockscale_op, self.input_quant_op = \
-            self._dispatch_w8a8_blockscale_op(
-                cutlass_block_fp8_supported, use_aiter_and_is_supported)
-        self.deepgemm_input_quant_op = (QuantFP8(
-            False,
-            self.act_quant_group_shape,
-            column_major_scales=True,
-            use_ue8m0=is_deep_gemm_e8m0_used()) if self.is_deep_gemm_supported
-                                        else None)
-
-    def apply(
-        self,
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        input_scale: Optional[torch.Tensor] = None,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        assert input_scale is None
-        # View input as 2D matrix for fp8 methods
-        input_2d = input.view(-1, input.shape[-1])
-        output_shape = [*input.shape[:-1], weight.shape[0]]
-        output_dtype = input.dtype
-
-        if should_use_deepgemm_for_fp8_linear(output_dtype, weight,
-                                              self.is_deep_gemm_supported):
-            output = self._run_deepgemm(input, weight, weight_scale)
-            if bias is not None:
-                output = output + bias
-            return output.to(dtype=input.dtype).view(*output_shape)
-
-        output = self.w8a8_blockscale_op(input_2d, weight, weight_scale)
-        if bias is not None:
-            output = output + bias
-        return output.to(dtype=input.dtype).view(*output_shape)
-
-    def _run_deepgemm(
-        self,
-        input_2d: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-    ) -> torch.Tensor:
-        # ensure DeepGEMM-backed custom op is registered before use
-        import vllm.model_executor.layers.quantization.deepgemm  # noqa: F401
-
-        assert self.deepgemm_input_quant_op is not None
-        q_input, x_scale = self.deepgemm_input_quant_op(input_2d)
-        return torch.ops.vllm.w8a8_deepgemm_block_scaled_mm(
-            q_input,
-            weight,
-            x_scale,
-            weight_scale,
-            self.weight_group_shape,
-            output_dtype=input_2d.dtype)
-
-    def _run_cutlass(
-        self,
-        input_2d: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-    ) -> torch.Tensor:
-        assert self.input_quant_op is not None
-        if self.is_hopper:
-            # We pad unconditionally (even if shape is already divisible by 4)
-            # to support dynamic shape for input_2d.shape[0] in torch.compile
-            x = torch.nn.functional.pad(input_2d,
-                                        (0, 0, 0, -input_2d.shape[0] % 4))
-        else:
-            x = input_2d
-
-        q_input, x_scale = self.input_quant_op(x)
-        output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale,
-                                   list(self.weight_group_shape),
-                                   input_2d.dtype, self.is_hopper)
-        output = output[0:input_2d.shape[0], ...]
-        return output
-
-    def _run_aiter(
-        self,
-        input_2d: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-    ) -> torch.Tensor:
-        assert self.act_quant_group_shape == GroupShape(1, 128)
-        q_input, x_scale = aiter_per1x128_quant(
-            input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
-        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
-            q_input, weight, x_scale, weight_scale, self.weight_group_shape,
-            input_2d.dtype)
-
-    def _run_triton(
-        self,
-        input_2d: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-    ) -> torch.Tensor:
-        assert self.input_quant_op is not None
-        q_input, x_scale = self.input_quant_op(input_2d)
-        return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
-            q_input, weight, x_scale, weight_scale, self.weight_group_shape,
-            input_2d.dtype)
-
-    def _dispatch_w8a8_blockscale_op(
-        self,
-        use_cutlass: bool,
-        use_aiter_and_is_supported: bool,
-    ) -> tuple[Callable[[
-        torch.Tensor,
-        torch.Tensor,
-        torch.Tensor,
-    ], torch.Tensor], Optional[QuantFP8]]:
-        if use_cutlass:
-            return self._run_cutlass, (QuantFP8(False,
-                                                self.act_quant_group_shape,
-                                                column_major_scales=True,
-                                                use_ue8m0=False))
-        if use_aiter_and_is_supported:
-            return self._run_aiter, None
-        return self._run_triton, (QuantFP8(False,
-                                           self.act_quant_group_shape,
-                                           column_major_scales=False,
-                                           use_ue8m0=False))
+def apply_w8a8_block_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: list[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
+    use_aiter_and_is_supported: bool = False,
+) -> torch.Tensor:
+    assert input_scale is None
+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+    output_dtype = input.dtype
+
+    if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
+
+        input_2d = input.view(-1, input.shape[-1])
+        output_shape = [*input.shape[:-1], weight.shape[0]]
+
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d,
+            block_size[1],
+            column_major_scales=True,
+        )
+
+        # ensure DeepGEMM-backed custom op is registered before use
+        import vllm.model_executor.layers.quantization.deepgemm  # noqa: F401
+
+        output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
+            q_input,
+            weight,
+            x_scale,
+            weight_scale,
+            block_size,
+            output_dtype=output_dtype)
+        if bias is not None:
+            output += bias
+        return output.to(dtype=output_dtype).view(*output_shape)
+
+    w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
+        cutlass_block_fp8_supported, use_aiter_and_is_supported)
+    if cutlass_block_fp8_supported:
+        num_pad = 0
+        if current_platform.is_device_capability(90):
+            # pad first dimension to be divisible by 4 due to
+            # cutlass blockwise gemm limitation for hopper
+            num_pad = 4 - (input_2d.shape[0] % 4)
+            if num_pad > 0:
+                input_2d = torch.nn.functional.pad(input_2d,
+                                                   (0, 0, 0, num_pad),
+                                                   "constant", 0)
+        q_input, x_scale = per_token_group_quant_fp8(input_2d,
+                                                     block_size[1],
+                                                     column_major_scales=True)
+        output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
+                                      block_size, input.dtype)
+        if num_pad > 0:
+            output = output[:-num_pad]
+    else:
+        if use_aiter_and_is_supported:
+            q_input, x_scale = aiter_per1x128_quant(
+                input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
+        else:
+            q_input, x_scale = per_token_group_quant_fp8(
+                input_2d, block_size[1], column_major_scales=False)
+
+        output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
+                                      block_size, input.dtype)
+
+    if bias is not None:
+        output = output + bias
+
+    return output.to(dtype=input.dtype).view(*output_shape)
+
+
+def apply_w8a8_block_fp8_linear_fake(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: list[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
+    use_aiter_and_is_supported: bool = False,
+) -> torch.Tensor:
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+    return torch.empty(output_shape, dtype=input.dtype, device=input.device)
+
+
+if not current_platform.is_cpu():
+    direct_register_custom_op(
+        op_name="apply_w8a8_block_fp8_linear",
+        op_func=apply_w8a8_block_fp8_linear,
+        mutates_args=[],
+        fake_impl=apply_w8a8_block_fp8_linear_fake,
+    )


 def input_to_float8(
@@ -537,7 +463,7 @@ def per_token_group_quant_fp8(

 @triton.jit
-def _w8a8_triton_block_scaled_mm(
+def _w8a8_block_fp8_matmul(
     # Pointers to inputs and output
     A,
     B,
@@ -662,7 +588,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
     return None


-def w8a8_triton_block_scaled_mm(
+def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -722,7 +648,7 @@ def w8a8_triton_block_scaled_mm(
         return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
                 triton.cdiv(N, META["BLOCK_SIZE_N"]), )

-    _w8a8_triton_block_scaled_mm[grid](
+    _w8a8_block_fp8_matmul[grid](
         A,
         B,
         C,
@@ -1005,6 +931,25 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
         layer.weight_scale.data.T.contiguous(), requires_grad=False)


+def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor,
+                           bias: Optional[torch.Tensor],
+                           cutlass_block_fp8_supported: bool,
+                           use_aiter_and_is_supported: bool) -> torch.Tensor:
+    """Apply block-wise FP8 linear operation."""
+    assert layer.weight_block_size is not None
+    return torch.ops.vllm.apply_w8a8_block_fp8_linear(
+        input=input,
+        weight=layer.weight,
+        block_size=layer.weight_block_size,
+        weight_scale=layer.weight_scale,
+        input_scale=layer.input_scale,
+        bias=bias,
+        cutlass_block_fp8_supported=cutlass_block_fp8_supported,
+        use_aiter_and_is_supported=use_aiter_and_is_supported,
+    )
+
+
 def expert_weight_is_col_major(x: torch.Tensor) -> bool:
     assert x.dim() == 3
     b, m, n = x.shape
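One functional difference between the removed class and the restored function is the Hopper padding rule in the cutlass path, visible in the hunk above: the class pads with `-M % 4` (zero extra rows when M is already a multiple of 4), while the restored function computes `4 - (M % 4)` and therefore pads a full 4 rows in that case, slicing them back off the output. A tiny sketch of the two expressions:

```python
# Padding added to the first dimension (M rows) before the SM90 cutlass
# blockwise GEMM, for a few example row counts.
for M in (5, 7, 8, 12):
    pad_in_removed_class = -M % 4        # W8A8BlockFp8LinearOp._run_cutlass
    pad_in_restored_func = 4 - (M % 4)   # restored apply_w8a8_block_fp8_linear
    print(f"M={M}: class pads {pad_in_removed_class}, "
          f"function pads {pad_in_restored_func}")
```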

File (path not shown):

@@ -9,7 +9,7 @@ from __future__ import annotations
 import functools
 import importlib
 import os
-from typing import Any, Callable, NoReturn, Optional
+from typing import Any, Callable, NoReturn

 import torch
@@ -184,13 +184,9 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
     return 1 - sim


-def should_use_deepgemm_for_fp8_linear(
-        output_dtype: torch.dtype,
-        weight: torch.Tensor,
-        supports_deep_gemm: Optional[bool] = None):
-    if supports_deep_gemm is None:
-        supports_deep_gemm = is_deep_gemm_supported()
-    return (supports_deep_gemm and output_dtype == torch.bfloat16
+def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
+                                       weight: torch.Tensor):
+    return (is_deep_gemm_supported() and output_dtype == torch.bfloat16
             and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
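After the revert, this gating helper probes DeepGEMM support itself rather than accepting a supports_deep_gemm argument. A hedged sketch of the restored logic with the probe stubbed out (in vLLM the probe is is_deep_gemm_supported()):

```python
import torch


def is_deep_gemm_supported() -> bool:  # stub for this sketch
    return True


def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
                                       weight: torch.Tensor) -> bool:
    # bfloat16 output and weight dims divisible by the 128x128 block size.
    return (is_deep_gemm_supported() and output_dtype == torch.bfloat16
            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)


w = torch.empty(256, 512)
assert should_use_deepgemm_for_fp8_linear(torch.bfloat16, w)
assert not should_use_deepgemm_for_fp8_linear(torch.float16, w)
assert not should_use_deepgemm_for_fp8_linear(torch.bfloat16, torch.empty(250, 512))
```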