mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 15:34:29 +08:00
enable DeepGEMM swapAB from FlashInfer for M<32 linear gemms
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
This commit is contained in:
parent
3d429d63a6
commit
5a5506c661
@ -3,10 +3,8 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# Disable DeepGEMM for this benchmark
|
# Disable DeepGEMM for this benchmark to use CUTLASS
|
||||||
os.environ["VLLM_USE_DEEP_GEMM"] = "0"
|
os.environ["VLLM_USE_DEEP_GEMM"] = "0"
|
||||||
# Enable FlashInfer FP8 linear (will be used when provider="flashinfer-block-fp8")
|
|
||||||
os.environ["VLLM_USE_FLASHINFER_FP8_LINEAR"] = "1"
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -96,15 +94,6 @@ plot_title = "BF16 vs W8A8 Block FP8 GEMMs"
|
|||||||
if CUTLASS_BLOCK_FP8_SUPPORTED:
|
if CUTLASS_BLOCK_FP8_SUPPORTED:
|
||||||
available_providers.append("w8a8-block-fp8-cutlass")
|
available_providers.append("w8a8-block-fp8-cutlass")
|
||||||
|
|
||||||
# Check if FlashInfer block GEMM is available
|
|
||||||
try:
|
|
||||||
from vllm.utils.flashinfer import has_flashinfer_block_gemm
|
|
||||||
|
|
||||||
if has_flashinfer_block_gemm():
|
|
||||||
available_providers.append("flashinfer-block-fp8")
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@vllm_triton.testing.perf_report(
|
@vllm_triton.testing.perf_report(
|
||||||
vllm_triton.testing.Benchmark(
|
vllm_triton.testing.Benchmark(
|
||||||
@ -145,14 +134,6 @@ def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)):
|
|||||||
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
|
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
|
||||||
lambda: run_w8a8_cutlass(), quantiles=quantiles
|
lambda: run_w8a8_cutlass(), quantiles=quantiles
|
||||||
)
|
)
|
||||||
elif provider == "flashinfer-block-fp8":
|
|
||||||
# Use the same W8A8 setup as other providers for fair comparison
|
|
||||||
run_w8a8_flashinfer = build_w8a8_block_fp8_runner(
|
|
||||||
M, N, K, block_size, device, use_cutlass=False
|
|
||||||
)
|
|
||||||
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
|
|
||||||
lambda: run_w8a8_flashinfer(), quantiles=quantiles
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown provider: {provider}")
|
raise ValueError(f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
|||||||
@ -24,6 +24,10 @@ from vllm.utils.deep_gemm import (
|
|||||||
per_block_cast_to_fp8,
|
per_block_cast_to_fp8,
|
||||||
should_use_deepgemm_for_fp8_linear,
|
should_use_deepgemm_for_fp8_linear,
|
||||||
)
|
)
|
||||||
|
from vllm.utils.flashinfer import (
|
||||||
|
flashinfer_fp8_blockscale_gemm,
|
||||||
|
has_flashinfer_fp8_blockscale_gemm,
|
||||||
|
)
|
||||||
from vllm.utils.import_utils import has_deep_gemm
|
from vllm.utils.import_utils import has_deep_gemm
|
||||||
|
|
||||||
if current_platform.get_device_capability() < (9, 0):
|
if current_platform.get_device_capability() < (9, 0):
|
||||||
@ -212,237 +216,46 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
|||||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
|
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
|
||||||
)
|
)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_flashinfer_block_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
def test_w8a8_block_fp8_flashinfer_matmul(M, N, K, block_size, out_dtype, seed):
|
||||||
"""
|
if not has_flashinfer_fp8_blockscale_gemm():
|
||||||
Test FlashInfer FP8 block-scale GEMM through W8A8BlockFp8LinearOp.
|
|
||||||
|
|
||||||
This tests the FP8 + FP8 → BF16 path (W8A8 full quantization).
|
|
||||||
Matches TensorRT-LLM's test_fp8_block_scale_gemm behavior.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
|
|
||||||
from vllm.utils.flashinfer import has_flashinfer_block_gemm
|
|
||||||
|
|
||||||
if not has_flashinfer_block_gemm():
|
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
"FlashInfer block GEMM not available (requires SM90+ and FlashInfer)"
|
"FlashInfer block GEMM not available (requires SM90+ and FlashInfer)"
|
||||||
)
|
)
|
||||||
|
# only aligned sizes
|
||||||
# Skip tests for dimensions that don't have pre-compiled kernels in FlashInfer
|
if K % 128 != 0 or N % 64 != 0:
|
||||||
# These cause CUDA runtime errors
|
pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
|
||||||
if K == 3884 or N == 7748:
|
|
||||||
pytest.skip(f"FlashInfer does not have pre-compiled kernels for K={K} or N={N}")
|
|
||||||
|
|
||||||
# Enable FlashInfer backend (required for W8A8BlockFp8LinearOp to use FlashInfer)
|
|
||||||
os.environ["VLLM_USE_FLASHINFER_FP8_LINEAR"] = "1"
|
|
||||||
# Reload envs module to pick up the env var change
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
from vllm import envs
|
|
||||||
|
|
||||||
importlib.reload(envs)
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
|
||||||
W8A8BlockFp8LinearOp,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|
||||||
GroupShape,
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
fp8_max = fp8_info.max
|
||||||
|
|
||||||
# Create BF16 inputs (normalized like TRT-LLM)
|
A_bf16 = (torch.rand(M, K, dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
|
||||||
A_bf16 = torch.randn(M, K, dtype=torch.bfloat16) / K
|
B_bf16 = (torch.rand(N, K, dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
|
||||||
B_bf16 = torch.randn(N, K, dtype=torch.bfloat16) / K
|
|
||||||
|
|
||||||
# Quantize weight with per-block scales
|
A_fp8, As_fp8 = per_token_group_quant_fp8(A_bf16, block_size[1], use_ue8m0=False)
|
||||||
B_fp8, Bs = per_block_cast_to_fp8(B_bf16, block_size=block_size)
|
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_bf16, block_size, use_ue8m0=False)
|
||||||
|
|
||||||
# Create W8A8BlockFp8LinearOp to handle input quantization
|
As = As_fp8.to(torch.float32)
|
||||||
block_n, block_k = block_size[0], block_size[1]
|
Bs = Bs_fp8.to(torch.float32)
|
||||||
weight_group_shape = GroupShape(block_n, block_k)
|
|
||||||
act_quant_group_shape = GroupShape(1, block_k) # Per-token quantization
|
|
||||||
|
|
||||||
linear_op = W8A8BlockFp8LinearOp(
|
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||||
weight_group_shape=weight_group_shape,
|
As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
|
||||||
act_quant_group_shape=act_quant_group_shape,
|
|
||||||
cutlass_block_fp8_supported=False, # Disable CUTLASS
|
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||||
use_aiter_and_is_supported=False, # Disable AITER
|
|
||||||
|
assert As_fp8.shape == (M, (K + 127) // 128), (
|
||||||
|
f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify FlashInfer backend is selected
|
out = flashinfer_fp8_blockscale_gemm(
|
||||||
assert linear_op.w8a8_blockscale_op == linear_op._run_flashinfer, (
|
|
||||||
"FlashInfer backend not selected! "
|
|
||||||
"Make sure VLLM_USE_FLASHINFER_FP8_LINEAR=1 is set before running tests."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute reference: BF16 × BF16 matmul (before quantization)
|
|
||||||
ref_out = torch.matmul(A_bf16, B_bf16.T)
|
|
||||||
|
|
||||||
# Run W8A8 FlashInfer GEMM (input will be quantized internally)
|
|
||||||
out = linear_op.apply(
|
|
||||||
input=A_bf16,
|
input=A_bf16,
|
||||||
weight=B_fp8,
|
weight=B_fp8,
|
||||||
|
input_scale=None,
|
||||||
weight_scale=Bs,
|
weight_scale=Bs,
|
||||||
input_scale=None, # Will quantize dynamically
|
out_dtype=out_dtype,
|
||||||
bias=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compare results using TensorRT-LLM's calc_diff metric
|
rel_diff = torch.mean(
|
||||||
# This measures normalized similarity: sim = 2*<x,y> / (||x||² + ||y||²)
|
torch.abs(out.to(torch.bfloat16) - ref_out.to(torch.bfloat16))
|
||||||
out_fp64 = out.to(torch.float64)
|
) / torch.mean(torch.abs(ref_out.to(torch.bfloat16)))
|
||||||
ref_fp64 = ref_out.to(torch.float64)
|
assert rel_diff < 0.001
|
||||||
denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum()
|
|
||||||
sim = 2 * (out_fp64 * ref_fp64).sum() / denominator
|
|
||||||
diff = 1 - sim
|
|
||||||
|
|
||||||
# W8A8 threshold from TensorRT-LLM: diff < 0.001 (99.9% similarity)
|
|
||||||
assert diff < 0.001, (
|
|
||||||
f"Similarity difference {diff:.6f} exceeds threshold (similarity: {sim:.6f})"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"M,N,K,block_size,seed",
|
|
||||||
[
|
|
||||||
(1, 1024, 4096, [128, 128], 0),
|
|
||||||
(32, 4096, 512, [128, 128], 0),
|
|
||||||
(128, 1024, 4096, [128, 128], 0),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"input_dtype,weight_dtype",
|
|
||||||
[
|
|
||||||
(torch.bfloat16, torch.bfloat16), # BF16 + BF16 (internal quantization)
|
|
||||||
(torch.bfloat16, torch.float8_e4m3fn), # BF16 + FP8 (weight-only)
|
|
||||||
(torch.float8_e4m3fn, torch.float8_e4m3fn), # FP8 + FP8 (W8A8)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@torch.inference_mode()
|
|
||||||
def test_flashinfer_block_gemm_dtypes(
|
|
||||||
M, N, K, block_size, input_dtype, weight_dtype, seed
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Test all three supported dtype combinations for FlashInfer FP8 block-scale GEMM.
|
|
||||||
|
|
||||||
Tests:
|
|
||||||
- BF16 + BF16 → BF16: Both inputs BF16, internal quantization
|
|
||||||
- BF16 + FP8 → BF16: Weight-only quantization
|
|
||||||
- FP8 + FP8 → BF16: W8A8 full quantization
|
|
||||||
|
|
||||||
This mirrors FlashInfer's own test_fp8_blockscale_gemm_dtypes and TRT-LLM's tests.
|
|
||||||
"""
|
|
||||||
from vllm.utils.flashinfer import has_flashinfer_block_gemm
|
|
||||||
|
|
||||||
if not has_flashinfer_block_gemm():
|
|
||||||
pytest.skip(
|
|
||||||
"FlashInfer block GEMM not available (requires SM90+ and FlashInfer)"
|
|
||||||
)
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.utils.flashinfer_block_gemm import (
|
|
||||||
flashinfer_block_gemm,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add debug output to verify test execution
|
|
||||||
print(f"\n{'=' * 80}")
|
|
||||||
print(f"TEST: M={M}, N={N}, K={K} | Input: {input_dtype}, Weight: {weight_dtype}")
|
|
||||||
print(f"{'=' * 80}")
|
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
|
|
||||||
# Create BF16 data for reference (same as FlashInfer tests)
|
|
||||||
input_bf16 = torch.randn(M, K, dtype=torch.bfloat16)
|
|
||||||
weight_bf16 = torch.randn(N, K, dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
# Quantize input based on dtype
|
|
||||||
if input_dtype == torch.float8_e4m3fn:
|
|
||||||
input_tensor, input_scale = per_token_group_quant_fp8(input_bf16, block_size[1])
|
|
||||||
else:
|
|
||||||
input_tensor, input_scale = input_bf16, None
|
|
||||||
|
|
||||||
# Quantize weight based on dtype
|
|
||||||
if weight_dtype == torch.float8_e4m3fn:
|
|
||||||
weight_tensor, weight_scale = per_block_cast_to_fp8(
|
|
||||||
weight_bf16, block_size=block_size
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
weight_tensor, weight_scale = weight_bf16, None
|
|
||||||
|
|
||||||
# Run FlashInfer FP8 block-scale GEMM
|
|
||||||
output = flashinfer_block_gemm(
|
|
||||||
input=input_tensor,
|
|
||||||
weight=weight_tensor,
|
|
||||||
scales_a=input_scale,
|
|
||||||
scales_b=weight_scale,
|
|
||||||
out_dtype=torch.bfloat16,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify output properties
|
|
||||||
assert output.shape == (M, N), f"Expected shape {(M, N)}, got {output.shape}"
|
|
||||||
assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}"
|
|
||||||
|
|
||||||
# Compute reference based on dtype combination
|
|
||||||
if input_dtype == torch.float8_e4m3fn and weight_dtype == torch.float8_e4m3fn:
|
|
||||||
# W8A8: Compare against dequantized FP8 reference (tests kernel correctness)
|
|
||||||
block_n, block_k = block_size[0], block_size[1]
|
|
||||||
k_tiles = (K + block_k - 1) // block_k
|
|
||||||
n_tiles = (N + block_n - 1) // block_n
|
|
||||||
|
|
||||||
input_dequant = torch.zeros_like(input_bf16)
|
|
||||||
for i in range(M):
|
|
||||||
for k_tile in range(k_tiles):
|
|
||||||
start, end = k_tile * block_k, min((k_tile + 1) * block_k, K)
|
|
||||||
input_dequant[i, start:end] = (
|
|
||||||
input_tensor[i, start:end].to(torch.bfloat16)
|
|
||||||
* input_scale[i, k_tile]
|
|
||||||
)
|
|
||||||
|
|
||||||
weight_dequant = torch.zeros_like(weight_bf16)
|
|
||||||
for j in range(N):
|
|
||||||
for k_tile in range(k_tiles):
|
|
||||||
start, end = k_tile * block_k, min((k_tile + 1) * block_k, K)
|
|
||||||
weight_dequant[j, start:end] = (
|
|
||||||
weight_tensor[j, start:end].to(torch.bfloat16)
|
|
||||||
* weight_scale[j // block_n, k_tile]
|
|
||||||
)
|
|
||||||
|
|
||||||
reference = torch.matmul(input_dequant, weight_dequant.T)
|
|
||||||
|
|
||||||
# W8A8: Use TRT-LLM's calc_diff metric with strict threshold
|
|
||||||
out_fp64 = output.to(torch.float64)
|
|
||||||
ref_fp64 = reference.to(torch.float64)
|
|
||||||
denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum()
|
|
||||||
sim = 2 * (out_fp64 * ref_fp64).sum() / denominator
|
|
||||||
diff = 1 - sim
|
|
||||||
|
|
||||||
# W8A8 achieves very high accuracy: diff < 0.001 (99.9% similarity)
|
|
||||||
assert diff < 0.001, (
|
|
||||||
f"W8A8 similarity difference {diff:.6f} too high (expected < 0.001, similarity: {sim:.6f})"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# BF16+BF16 or BF16+FP8: Compare against original BF16 reference
|
|
||||||
reference = torch.matmul(input_bf16, weight_bf16.T)
|
|
||||||
|
|
||||||
out_fp64 = output.to(torch.float64)
|
|
||||||
ref_fp64 = reference.to(torch.float64)
|
|
||||||
denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum()
|
|
||||||
sim = 2 * (out_fp64 * ref_fp64).sum() / denominator
|
|
||||||
diff = 1 - sim
|
|
||||||
|
|
||||||
if input_dtype == torch.bfloat16 and weight_dtype == torch.bfloat16:
|
|
||||||
# BF16+BF16: Highest accuracy (internal quantization)
|
|
||||||
threshold = 0.001
|
|
||||||
threshold_desc = "0.1%"
|
|
||||||
elif input_dtype == torch.bfloat16 and weight_dtype == torch.float8_e4m3fn:
|
|
||||||
# BF16+FP8: Weight-only quantization, higher error
|
|
||||||
threshold = 0.01
|
|
||||||
threshold_desc = "1%"
|
|
||||||
else:
|
|
||||||
# Other combinations
|
|
||||||
threshold = 0.01
|
|
||||||
threshold_desc = "1%"
|
|
||||||
|
|
||||||
assert diff < threshold, (
|
|
||||||
f"Similarity difference {diff:.6f} too high for "
|
|
||||||
f"{input_dtype} + {weight_dtype} (expected < {threshold_desc}, similarity: {sim:.6f})"
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,57 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
"""
|
|
||||||
FlashInfer FP8 Block-Scale GEMM wrapper for vLLM.
|
|
||||||
|
|
||||||
This module provides a thin wrapper around FlashInfer's FP8 block-scale GEMM
|
|
||||||
implementation, which uses TensorRT-LLM's optimized kernels for NVIDIA Hopper (SM90+).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def flashinfer_block_gemm(
|
|
||||||
input: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
scales_a: torch.Tensor | None,
|
|
||||||
scales_b: torch.Tensor,
|
|
||||||
out_dtype: torch.dtype,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Wrapper for FlashInfer's FP8 block-scale GEMM.
|
|
||||||
|
|
||||||
Computes: output = (input @ weight.T) with per-block scaling for quantization.
|
|
||||||
|
|
||||||
Supports three modes:
|
|
||||||
1. BF16 + BF16 → BF16: Both inputs BF16, internal quantization (scales_a=None, scales_b used internally)
|
|
||||||
2. BF16 + FP8 → BF16: Weight-only quantization (scales_a=None, scales_b for weight)
|
|
||||||
3. FP8 + FP8 → BF16: W8A8 full quantization (scales_a for input, scales_b for weight)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input: Input tensor (M, K) - BF16 or FP8 e4m3
|
|
||||||
weight: Weight tensor (N, K) - BF16 or FP8 e4m3
|
|
||||||
scales_a: Input scales (M, K//block_k) or None - FP32
|
|
||||||
None: input is BF16 (will be quantized internally for BF16+BF16 or left as-is for BF16+FP8)
|
|
||||||
Provided: input is pre-quantized FP8 (W8A8 mode)
|
|
||||||
scales_b: Weight scales (N//block_n, K//block_k) - FP32
|
|
||||||
out_dtype: Output dtype (typically torch.bfloat16)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
output: Result tensor (M, N) in out_dtype
|
|
||||||
|
|
||||||
Note:
|
|
||||||
- Requires SM90+ GPU (NVIDIA Hopper)
|
|
||||||
- Uses TensorRT-LLM's optimized CUTLASS kernels via FlashInfer
|
|
||||||
- For M < 32, automatically uses SwapAB kernel optimization
|
|
||||||
"""
|
|
||||||
from flashinfer.gemm import fp8_blockscale_gemm_swapab
|
|
||||||
|
|
||||||
return fp8_blockscale_gemm_swapab(
|
|
||||||
input=input,
|
|
||||||
weight=weight,
|
|
||||||
input_scale=scales_a,
|
|
||||||
weight_scale=scales_b,
|
|
||||||
out_dtype=out_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
@ -35,7 +35,11 @@ from vllm.utils.deep_gemm import (
|
|||||||
should_use_deepgemm_for_fp8_linear,
|
should_use_deepgemm_for_fp8_linear,
|
||||||
transform_sf_into_required_layout,
|
transform_sf_into_required_layout,
|
||||||
)
|
)
|
||||||
from vllm.utils.flashinfer import has_flashinfer_block_gemm
|
from vllm.utils.flashinfer import (
|
||||||
|
flashinfer_fp8_blockscale_gemm,
|
||||||
|
is_flashinfer_fp8_blockscale_gemm_supported,
|
||||||
|
should_use_flashinfer_for_block_scale_fp8_linear,
|
||||||
|
)
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -227,6 +231,76 @@ direct_register_custom_op(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _flashinfer_fp8_blockscale_gemm_impl(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
group_size: int,
|
||||||
|
use_deep_gemm_e8m0: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
def use_flashinfer(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return flashinfer_fp8_blockscale_gemm(
|
||||||
|
input=input,
|
||||||
|
weight=weight,
|
||||||
|
weight_scale=weight_scale,
|
||||||
|
out_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
|
||||||
|
def use_deepgemm(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
q_input, input_scale = per_token_group_quant_fp8(
|
||||||
|
input,
|
||||||
|
group_size=group_size,
|
||||||
|
column_major_scales=True,
|
||||||
|
use_ue8m0=use_deep_gemm_e8m0,
|
||||||
|
)
|
||||||
|
output = torch.empty(
|
||||||
|
(q_input.shape[0], weight.shape[0]),
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
device=q_input.device,
|
||||||
|
)
|
||||||
|
fp8_gemm_nt(
|
||||||
|
(q_input, input_scale),
|
||||||
|
(weight, weight_scale),
|
||||||
|
output,
|
||||||
|
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
condition = input.shape[0] < 32
|
||||||
|
|
||||||
|
# Pass all required variables through operands
|
||||||
|
return torch.cond(
|
||||||
|
condition, use_flashinfer, use_deepgemm, (input, weight, weight_scale)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _flashinfer_fp8_blockscale_gemm_fake(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
group_size: int,
|
||||||
|
use_deep_gemm_e8m0: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.empty(
|
||||||
|
input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
"flashinfer_fp8_blockscale_gemm",
|
||||||
|
_flashinfer_fp8_blockscale_gemm_impl,
|
||||||
|
fake_impl=_flashinfer_fp8_blockscale_gemm_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO fix ROCm->Triton custom path:
|
# TODO fix ROCm->Triton custom path:
|
||||||
# https://github.com/vllm-project/vllm/issues/14397
|
# https://github.com/vllm-project/vllm/issues/14397
|
||||||
class W8A8BlockFp8LinearOp:
|
class W8A8BlockFp8LinearOp:
|
||||||
@ -247,6 +321,7 @@ class W8A8BlockFp8LinearOp:
|
|||||||
self.is_deep_gemm_supported = is_deep_gemm_supported()
|
self.is_deep_gemm_supported = is_deep_gemm_supported()
|
||||||
self.is_hopper = current_platform.is_device_capability(90)
|
self.is_hopper = current_platform.is_device_capability(90)
|
||||||
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
|
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
|
||||||
|
self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported()
|
||||||
|
|
||||||
# Get the correct blockscale mul and input quant operations.
|
# Get the correct blockscale mul and input quant operations.
|
||||||
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
|
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
|
||||||
@ -282,7 +357,12 @@ class W8A8BlockFp8LinearOp:
|
|||||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||||
output_dtype = input.dtype
|
output_dtype = input.dtype
|
||||||
|
|
||||||
if should_use_deepgemm_for_fp8_linear(
|
if should_use_flashinfer_for_block_scale_fp8_linear(
|
||||||
|
self.is_flashinfer_supported, output_dtype, input_2d, weight
|
||||||
|
):
|
||||||
|
output = self._run_flashinfer(input_2d, weight, weight_scale)
|
||||||
|
|
||||||
|
elif should_use_deepgemm_for_fp8_linear(
|
||||||
output_dtype, weight, self.is_deep_gemm_supported
|
output_dtype, weight, self.is_deep_gemm_supported
|
||||||
):
|
):
|
||||||
output = self._run_deepgemm(input_2d, weight, weight_scale)
|
output = self._run_deepgemm(input_2d, weight, weight_scale)
|
||||||
@ -415,7 +495,6 @@ class W8A8BlockFp8LinearOp:
|
|||||||
input_2d: torch.Tensor,
|
input_2d: torch.Tensor,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
weight_scale: torch.Tensor,
|
weight_scale: torch.Tensor,
|
||||||
input_scale: torch.Tensor | None = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Run FlashInfer FP8 block-scale GEMM.
|
Run FlashInfer FP8 block-scale GEMM.
|
||||||
@ -423,23 +502,16 @@ class W8A8BlockFp8LinearOp:
|
|||||||
This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels
|
This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels
|
||||||
and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper).
|
and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper).
|
||||||
"""
|
"""
|
||||||
from vllm.model_executor.layers.quantization.utils.flashinfer_block_gemm import (
|
# Now call FlashInfer with BF16 input + FP8 weight, input will be
|
||||||
flashinfer_block_gemm,
|
# quantized with FlashInfer kernel (W8A8)
|
||||||
)
|
output = torch.ops.vllm.flashinfer_fp8_blockscale_gemm(
|
||||||
|
input=input_2d, # BF16 input
|
||||||
# Quantize input dynamically if not pre-quantized (same as CUTLASS)
|
|
||||||
assert input_scale is None
|
|
||||||
assert self.input_quant_op is not None
|
|
||||||
q_input, input_scale = self.input_quant_op(input_2d)
|
|
||||||
|
|
||||||
# Now call FlashInfer with FP8 input + FP8 weight (W8A8)
|
|
||||||
return flashinfer_block_gemm(
|
|
||||||
input=q_input, # FP8 quantized input
|
|
||||||
weight=weight, # FP8 weight
|
weight=weight, # FP8 weight
|
||||||
scales_a=input_scale, # Input scales (computed dynamically)
|
weight_scale=weight_scale, # Weight scales
|
||||||
scales_b=weight_scale, # Weight scales
|
group_size=self.act_quant_group_shape.col,
|
||||||
out_dtype=input_2d.dtype,
|
use_deep_gemm_e8m0=self.use_deep_gemm_e8m0,
|
||||||
)
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
def _dispatch_w8a8_blockscale_op(
|
def _dispatch_w8a8_blockscale_op(
|
||||||
self,
|
self,
|
||||||
@ -457,22 +529,6 @@ class W8A8BlockFp8LinearOp:
|
|||||||
],
|
],
|
||||||
QuantFP8 | None,
|
QuantFP8 | None,
|
||||||
]:
|
]:
|
||||||
# Prefer FlashInfer on SM90+ if available (Hopper optimized)
|
|
||||||
if (
|
|
||||||
has_flashinfer_block_gemm()
|
|
||||||
and envs.VLLM_USE_FLASHINFER_FP8_LINEAR
|
|
||||||
and not use_aiter_and_is_supported
|
|
||||||
):
|
|
||||||
logger.info_once("Using FlashInfer FP8 block-scale GEMM for linear layers")
|
|
||||||
return self._run_flashinfer, (
|
|
||||||
QuantFP8(
|
|
||||||
False,
|
|
||||||
self.act_quant_group_shape,
|
|
||||||
column_major_scales=False,
|
|
||||||
use_ue8m0=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_cutlass:
|
if use_cutlass:
|
||||||
return self._run_cutlass, (
|
return self._run_cutlass, (
|
||||||
QuantFP8(
|
QuantFP8(
|
||||||
|
|||||||
@ -540,27 +540,52 @@ def flashinfer_scaled_fp8_mm(
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
flashinfer_fp8_blockscale_gemm = _lazy_import_wrapper(
|
||||||
|
"flashinfer.gemm", "fp8_blockscale_gemm_sm90"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@functools.cache
|
@functools.cache
|
||||||
def has_flashinfer_block_gemm() -> bool:
|
def has_flashinfer_fp8_blockscale_gemm() -> bool:
|
||||||
"""Return `True` if FlashInfer FP8 block-scale GEMM is available."""
|
"""Return `True` if FlashInfer block-scale FP8 GEMM is available."""
|
||||||
if not has_flashinfer():
|
return has_flashinfer() and hasattr(
|
||||||
return False
|
_get_submodule("flashinfer.gemm"), "fp8_blockscale_gemm_sm90"
|
||||||
if not current_platform.is_cuda():
|
)
|
||||||
return False
|
|
||||||
# Only SM90+ (Hopper) supports this kernel
|
|
||||||
if not current_platform.is_device_capability(90):
|
@functools.cache
|
||||||
|
def is_flashinfer_fp8_blockscale_gemm_supported() -> bool:
|
||||||
|
"""Return `True` if FlashInfer block-scale FP8 GEMM is supported."""
|
||||||
|
return envs.VLLM_USE_FLASHINFER_FP8_LINEAR and has_flashinfer_fp8_blockscale_gemm()
|
||||||
|
|
||||||
|
|
||||||
|
def should_use_flashinfer_for_block_scale_fp8_linear(
|
||||||
|
is_flashinfer_supported: bool,
|
||||||
|
output_dtype: torch.dtype,
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
):
|
||||||
|
if not is_flashinfer_supported:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
# Verify DeepGEMM N/K dims requirements
|
||||||
import flashinfer
|
# NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul
|
||||||
|
# test inside kernels/quantization/test_block_fp8.py
|
||||||
|
N_MULTIPLE = 64
|
||||||
|
K_MULTIPLE = 128
|
||||||
|
|
||||||
# Check if the module has the required binding
|
weight_dtype = weight.dtype
|
||||||
return hasattr(flashinfer, "Fp8BlockScaleGemmRunner")
|
input_dtype = input.dtype
|
||||||
except (ImportError, AttributeError):
|
|
||||||
logger.debug_once(
|
should_use_flashinfer = (
|
||||||
"FlashInfer block-scale GEMM not available: module or binding not found"
|
output_dtype == torch.bfloat16
|
||||||
)
|
and input_dtype == torch.bfloat16
|
||||||
return False
|
and weight_dtype == torch.float8_e4m3fn
|
||||||
|
and weight.shape[0] % N_MULTIPLE == 0
|
||||||
|
and weight.shape[1] % K_MULTIPLE == 0
|
||||||
|
)
|
||||||
|
|
||||||
|
return should_use_flashinfer
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -579,11 +604,14 @@ __all__ = [
|
|||||||
"has_flashinfer_all2all",
|
"has_flashinfer_all2all",
|
||||||
"has_flashinfer_cutlass_fused_moe",
|
"has_flashinfer_cutlass_fused_moe",
|
||||||
"has_flashinfer_cutedsl_grouped_gemm_nt_masked",
|
"has_flashinfer_cutedsl_grouped_gemm_nt_masked",
|
||||||
|
"has_flashinfer_fp8_blockscale_gemm",
|
||||||
"has_nvidia_artifactory",
|
"has_nvidia_artifactory",
|
||||||
"supports_trtllm_attention",
|
"supports_trtllm_attention",
|
||||||
"can_use_trtllm_attention",
|
"can_use_trtllm_attention",
|
||||||
"use_trtllm_attention",
|
"use_trtllm_attention",
|
||||||
"flashinfer_scaled_fp4_mm",
|
"flashinfer_scaled_fp4_mm",
|
||||||
"flashinfer_scaled_fp8_mm",
|
"flashinfer_scaled_fp8_mm",
|
||||||
"has_flashinfer_block_gemm",
|
"flashinfer_fp8_blockscale_gemm",
|
||||||
|
"should_use_flashinfer_for_block_scale_fp8_linear",
|
||||||
|
"is_flashinfer_fp8_blockscale_gemm_supported",
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user