mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 03:24:27 +08:00
Enable linear deepgemm_swapAB
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
This commit is contained in:
parent
09dc7c690c
commit
3d429d63a6
@ -3,8 +3,10 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# Disable DeepGEMM for this benchmark to use CUTLASS
|
# Disable DeepGEMM for this benchmark
|
||||||
os.environ["VLLM_USE_DEEP_GEMM"] = "0"
|
os.environ["VLLM_USE_DEEP_GEMM"] = "0"
|
||||||
|
# Enable FlashInfer FP8 linear (will be used when provider="flashinfer-block-fp8")
|
||||||
|
os.environ["VLLM_USE_FLASHINFER_FP8_LINEAR"] = "1"
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -94,6 +96,15 @@ plot_title = "BF16 vs W8A8 Block FP8 GEMMs"
|
|||||||
if CUTLASS_BLOCK_FP8_SUPPORTED:
|
if CUTLASS_BLOCK_FP8_SUPPORTED:
|
||||||
available_providers.append("w8a8-block-fp8-cutlass")
|
available_providers.append("w8a8-block-fp8-cutlass")
|
||||||
|
|
||||||
|
# Check if FlashInfer block GEMM is available
|
||||||
|
try:
|
||||||
|
from vllm.utils.flashinfer import has_flashinfer_block_gemm
|
||||||
|
|
||||||
|
if has_flashinfer_block_gemm():
|
||||||
|
available_providers.append("flashinfer-block-fp8")
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@vllm_triton.testing.perf_report(
|
@vllm_triton.testing.perf_report(
|
||||||
vllm_triton.testing.Benchmark(
|
vllm_triton.testing.Benchmark(
|
||||||
@ -134,6 +145,14 @@ def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)):
|
|||||||
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
|
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
|
||||||
lambda: run_w8a8_cutlass(), quantiles=quantiles
|
lambda: run_w8a8_cutlass(), quantiles=quantiles
|
||||||
)
|
)
|
||||||
|
elif provider == "flashinfer-block-fp8":
|
||||||
|
# Use the same W8A8 setup as other providers for fair comparison
|
||||||
|
run_w8a8_flashinfer = build_w8a8_block_fp8_runner(
|
||||||
|
M, N, K, block_size, device, use_cutlass=False
|
||||||
|
)
|
||||||
|
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: run_w8a8_flashinfer(), quantiles=quantiles
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown provider: {provider}")
|
raise ValueError(f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
|||||||
@ -205,3 +205,244 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
|||||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||||
assert rel_diff < 0.001
|
assert rel_diff < 0.001
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"M,N,K,block_size,out_dtype,seed",
|
||||||
|
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
|
||||||
|
)
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_flashinfer_block_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||||
|
"""
|
||||||
|
Test FlashInfer FP8 block-scale GEMM through W8A8BlockFp8LinearOp.
|
||||||
|
|
||||||
|
This tests the FP8 + FP8 → BF16 path (W8A8 full quantization).
|
||||||
|
Matches TensorRT-LLM's test_fp8_block_scale_gemm behavior.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
from vllm.utils.flashinfer import has_flashinfer_block_gemm
|
||||||
|
|
||||||
|
if not has_flashinfer_block_gemm():
|
||||||
|
pytest.skip(
|
||||||
|
"FlashInfer block GEMM not available (requires SM90+ and FlashInfer)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Skip tests for dimensions that don't have pre-compiled kernels in FlashInfer
|
||||||
|
# These cause CUDA runtime errors
|
||||||
|
if K == 3884 or N == 7748:
|
||||||
|
pytest.skip(f"FlashInfer does not have pre-compiled kernels for K={K} or N={N}")
|
||||||
|
|
||||||
|
# Enable FlashInfer backend (required for W8A8BlockFp8LinearOp to use FlashInfer)
|
||||||
|
os.environ["VLLM_USE_FLASHINFER_FP8_LINEAR"] = "1"
|
||||||
|
# Reload envs module to pick up the env var change
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
|
||||||
|
importlib.reload(envs)
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
W8A8BlockFp8LinearOp,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
GroupShape,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
# Create BF16 inputs (normalized like TRT-LLM)
|
||||||
|
A_bf16 = torch.randn(M, K, dtype=torch.bfloat16) / K
|
||||||
|
B_bf16 = torch.randn(N, K, dtype=torch.bfloat16) / K
|
||||||
|
|
||||||
|
# Quantize weight with per-block scales
|
||||||
|
B_fp8, Bs = per_block_cast_to_fp8(B_bf16, block_size=block_size)
|
||||||
|
|
||||||
|
# Create W8A8BlockFp8LinearOp to handle input quantization
|
||||||
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
|
weight_group_shape = GroupShape(block_n, block_k)
|
||||||
|
act_quant_group_shape = GroupShape(1, block_k) # Per-token quantization
|
||||||
|
|
||||||
|
linear_op = W8A8BlockFp8LinearOp(
|
||||||
|
weight_group_shape=weight_group_shape,
|
||||||
|
act_quant_group_shape=act_quant_group_shape,
|
||||||
|
cutlass_block_fp8_supported=False, # Disable CUTLASS
|
||||||
|
use_aiter_and_is_supported=False, # Disable AITER
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify FlashInfer backend is selected
|
||||||
|
assert linear_op.w8a8_blockscale_op == linear_op._run_flashinfer, (
|
||||||
|
"FlashInfer backend not selected! "
|
||||||
|
"Make sure VLLM_USE_FLASHINFER_FP8_LINEAR=1 is set before running tests."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute reference: BF16 × BF16 matmul (before quantization)
|
||||||
|
ref_out = torch.matmul(A_bf16, B_bf16.T)
|
||||||
|
|
||||||
|
# Run W8A8 FlashInfer GEMM (input will be quantized internally)
|
||||||
|
out = linear_op.apply(
|
||||||
|
input=A_bf16,
|
||||||
|
weight=B_fp8,
|
||||||
|
weight_scale=Bs,
|
||||||
|
input_scale=None, # Will quantize dynamically
|
||||||
|
bias=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compare results using TensorRT-LLM's calc_diff metric
|
||||||
|
# This measures normalized similarity: sim = 2*<x,y> / (||x||² + ||y||²)
|
||||||
|
out_fp64 = out.to(torch.float64)
|
||||||
|
ref_fp64 = ref_out.to(torch.float64)
|
||||||
|
denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum()
|
||||||
|
sim = 2 * (out_fp64 * ref_fp64).sum() / denominator
|
||||||
|
diff = 1 - sim
|
||||||
|
|
||||||
|
# W8A8 threshold from TensorRT-LLM: diff < 0.001 (99.9% similarity)
|
||||||
|
assert diff < 0.001, (
|
||||||
|
f"Similarity difference {diff:.6f} exceeds threshold (similarity: {sim:.6f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"M,N,K,block_size,seed",
|
||||||
|
[
|
||||||
|
(1, 1024, 4096, [128, 128], 0),
|
||||||
|
(32, 4096, 512, [128, 128], 0),
|
||||||
|
(128, 1024, 4096, [128, 128], 0),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_dtype,weight_dtype",
|
||||||
|
[
|
||||||
|
(torch.bfloat16, torch.bfloat16), # BF16 + BF16 (internal quantization)
|
||||||
|
(torch.bfloat16, torch.float8_e4m3fn), # BF16 + FP8 (weight-only)
|
||||||
|
(torch.float8_e4m3fn, torch.float8_e4m3fn), # FP8 + FP8 (W8A8)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_flashinfer_block_gemm_dtypes(
|
||||||
|
M, N, K, block_size, input_dtype, weight_dtype, seed
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test all three supported dtype combinations for FlashInfer FP8 block-scale GEMM.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
- BF16 + BF16 → BF16: Both inputs BF16, internal quantization
|
||||||
|
- BF16 + FP8 → BF16: Weight-only quantization
|
||||||
|
- FP8 + FP8 → BF16: W8A8 full quantization
|
||||||
|
|
||||||
|
This mirrors FlashInfer's own test_fp8_blockscale_gemm_dtypes and TRT-LLM's tests.
|
||||||
|
"""
|
||||||
|
from vllm.utils.flashinfer import has_flashinfer_block_gemm
|
||||||
|
|
||||||
|
if not has_flashinfer_block_gemm():
|
||||||
|
pytest.skip(
|
||||||
|
"FlashInfer block GEMM not available (requires SM90+ and FlashInfer)"
|
||||||
|
)
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils.flashinfer_block_gemm import (
|
||||||
|
flashinfer_block_gemm,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add debug output to verify test execution
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"TEST: M={M}, N={N}, K={K} | Input: {input_dtype}, Weight: {weight_dtype}")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
# Create BF16 data for reference (same as FlashInfer tests)
|
||||||
|
input_bf16 = torch.randn(M, K, dtype=torch.bfloat16)
|
||||||
|
weight_bf16 = torch.randn(N, K, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Quantize input based on dtype
|
||||||
|
if input_dtype == torch.float8_e4m3fn:
|
||||||
|
input_tensor, input_scale = per_token_group_quant_fp8(input_bf16, block_size[1])
|
||||||
|
else:
|
||||||
|
input_tensor, input_scale = input_bf16, None
|
||||||
|
|
||||||
|
# Quantize weight based on dtype
|
||||||
|
if weight_dtype == torch.float8_e4m3fn:
|
||||||
|
weight_tensor, weight_scale = per_block_cast_to_fp8(
|
||||||
|
weight_bf16, block_size=block_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
weight_tensor, weight_scale = weight_bf16, None
|
||||||
|
|
||||||
|
# Run FlashInfer FP8 block-scale GEMM
|
||||||
|
output = flashinfer_block_gemm(
|
||||||
|
input=input_tensor,
|
||||||
|
weight=weight_tensor,
|
||||||
|
scales_a=input_scale,
|
||||||
|
scales_b=weight_scale,
|
||||||
|
out_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify output properties
|
||||||
|
assert output.shape == (M, N), f"Expected shape {(M, N)}, got {output.shape}"
|
||||||
|
assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}"
|
||||||
|
|
||||||
|
# Compute reference based on dtype combination
|
||||||
|
if input_dtype == torch.float8_e4m3fn and weight_dtype == torch.float8_e4m3fn:
|
||||||
|
# W8A8: Compare against dequantized FP8 reference (tests kernel correctness)
|
||||||
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
|
k_tiles = (K + block_k - 1) // block_k
|
||||||
|
n_tiles = (N + block_n - 1) // block_n
|
||||||
|
|
||||||
|
input_dequant = torch.zeros_like(input_bf16)
|
||||||
|
for i in range(M):
|
||||||
|
for k_tile in range(k_tiles):
|
||||||
|
start, end = k_tile * block_k, min((k_tile + 1) * block_k, K)
|
||||||
|
input_dequant[i, start:end] = (
|
||||||
|
input_tensor[i, start:end].to(torch.bfloat16)
|
||||||
|
* input_scale[i, k_tile]
|
||||||
|
)
|
||||||
|
|
||||||
|
weight_dequant = torch.zeros_like(weight_bf16)
|
||||||
|
for j in range(N):
|
||||||
|
for k_tile in range(k_tiles):
|
||||||
|
start, end = k_tile * block_k, min((k_tile + 1) * block_k, K)
|
||||||
|
weight_dequant[j, start:end] = (
|
||||||
|
weight_tensor[j, start:end].to(torch.bfloat16)
|
||||||
|
* weight_scale[j // block_n, k_tile]
|
||||||
|
)
|
||||||
|
|
||||||
|
reference = torch.matmul(input_dequant, weight_dequant.T)
|
||||||
|
|
||||||
|
# W8A8: Use TRT-LLM's calc_diff metric with strict threshold
|
||||||
|
out_fp64 = output.to(torch.float64)
|
||||||
|
ref_fp64 = reference.to(torch.float64)
|
||||||
|
denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum()
|
||||||
|
sim = 2 * (out_fp64 * ref_fp64).sum() / denominator
|
||||||
|
diff = 1 - sim
|
||||||
|
|
||||||
|
# W8A8 achieves very high accuracy: diff < 0.001 (99.9% similarity)
|
||||||
|
assert diff < 0.001, (
|
||||||
|
f"W8A8 similarity difference {diff:.6f} too high (expected < 0.001, similarity: {sim:.6f})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# BF16+BF16 or BF16+FP8: Compare against original BF16 reference
|
||||||
|
reference = torch.matmul(input_bf16, weight_bf16.T)
|
||||||
|
|
||||||
|
out_fp64 = output.to(torch.float64)
|
||||||
|
ref_fp64 = reference.to(torch.float64)
|
||||||
|
denominator = (out_fp64 * out_fp64 + ref_fp64 * ref_fp64).sum()
|
||||||
|
sim = 2 * (out_fp64 * ref_fp64).sum() / denominator
|
||||||
|
diff = 1 - sim
|
||||||
|
|
||||||
|
if input_dtype == torch.bfloat16 and weight_dtype == torch.bfloat16:
|
||||||
|
# BF16+BF16: Highest accuracy (internal quantization)
|
||||||
|
threshold = 0.001
|
||||||
|
threshold_desc = "0.1%"
|
||||||
|
elif input_dtype == torch.bfloat16 and weight_dtype == torch.float8_e4m3fn:
|
||||||
|
# BF16+FP8: Weight-only quantization, higher error
|
||||||
|
threshold = 0.01
|
||||||
|
threshold_desc = "1%"
|
||||||
|
else:
|
||||||
|
# Other combinations
|
||||||
|
threshold = 0.01
|
||||||
|
threshold_desc = "1%"
|
||||||
|
|
||||||
|
assert diff < threshold, (
|
||||||
|
f"Similarity difference {diff:.6f} too high for "
|
||||||
|
f"{input_dtype} + {weight_dtype} (expected < {threshold_desc}, similarity: {sim:.6f})"
|
||||||
|
)
|
||||||
|
|||||||
@ -168,6 +168,7 @@ if TYPE_CHECKING:
|
|||||||
"relax",
|
"relax",
|
||||||
] = "relax"
|
] = "relax"
|
||||||
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
|
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
|
||||||
|
VLLM_USE_FLASHINFER_FP8_LINEAR: bool = False
|
||||||
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
|
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
|
||||||
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
|
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
|
||||||
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
|
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
|
||||||
@ -1208,6 +1209,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool(
|
"VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool(
|
||||||
int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))
|
int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))
|
||||||
),
|
),
|
||||||
|
# Allow use of FlashInfer FP8 block-scale GEMM for linear layers.
|
||||||
|
# This uses TensorRT-LLM kernels and requires SM90+ (Hopper).
|
||||||
|
"VLLM_USE_FLASHINFER_FP8_LINEAR": lambda: bool(
|
||||||
|
int(os.getenv("VLLM_USE_FLASHINFER_FP8_LINEAR", "0"))
|
||||||
|
),
|
||||||
# Allow use of FlashInfer MoE kernels for fused moe ops.
|
# Allow use of FlashInfer MoE kernels for fused moe ops.
|
||||||
"VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool(
|
"VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool(
|
||||||
int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0"))
|
int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0"))
|
||||||
|
|||||||
@ -0,0 +1,57 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""
|
||||||
|
FlashInfer FP8 Block-Scale GEMM wrapper for vLLM.
|
||||||
|
|
||||||
|
This module provides a thin wrapper around FlashInfer's FP8 block-scale GEMM
|
||||||
|
implementation, which uses TensorRT-LLM's optimized kernels for NVIDIA Hopper (SM90+).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def flashinfer_block_gemm(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
scales_a: torch.Tensor | None,
|
||||||
|
scales_b: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Wrapper for FlashInfer's FP8 block-scale GEMM.
|
||||||
|
|
||||||
|
Computes: output = (input @ weight.T) with per-block scaling for quantization.
|
||||||
|
|
||||||
|
Supports three modes:
|
||||||
|
1. BF16 + BF16 → BF16: Both inputs BF16, internal quantization (scales_a=None, scales_b used internally)
|
||||||
|
2. BF16 + FP8 → BF16: Weight-only quantization (scales_a=None, scales_b for weight)
|
||||||
|
3. FP8 + FP8 → BF16: W8A8 full quantization (scales_a for input, scales_b for weight)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: Input tensor (M, K) - BF16 or FP8 e4m3
|
||||||
|
weight: Weight tensor (N, K) - BF16 or FP8 e4m3
|
||||||
|
scales_a: Input scales (M, K//block_k) or None - FP32
|
||||||
|
None: input is BF16 (will be quantized internally for BF16+BF16 or left as-is for BF16+FP8)
|
||||||
|
Provided: input is pre-quantized FP8 (W8A8 mode)
|
||||||
|
scales_b: Weight scales (N//block_n, K//block_k) - FP32
|
||||||
|
out_dtype: Output dtype (typically torch.bfloat16)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output: Result tensor (M, N) in out_dtype
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- Requires SM90+ GPU (NVIDIA Hopper)
|
||||||
|
- Uses TensorRT-LLM's optimized CUTLASS kernels via FlashInfer
|
||||||
|
- For M < 32, automatically uses SwapAB kernel optimization
|
||||||
|
"""
|
||||||
|
from flashinfer.gemm import fp8_blockscale_gemm_swapab
|
||||||
|
|
||||||
|
return fp8_blockscale_gemm_swapab(
|
||||||
|
input=input,
|
||||||
|
weight=weight,
|
||||||
|
input_scale=scales_a,
|
||||||
|
weight_scale=scales_b,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
@ -35,6 +35,7 @@ from vllm.utils.deep_gemm import (
|
|||||||
should_use_deepgemm_for_fp8_linear,
|
should_use_deepgemm_for_fp8_linear,
|
||||||
transform_sf_into_required_layout,
|
transform_sf_into_required_layout,
|
||||||
)
|
)
|
||||||
|
from vllm.utils.flashinfer import has_flashinfer_block_gemm
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -409,6 +410,37 @@ class W8A8BlockFp8LinearOp:
|
|||||||
input_2d.dtype,
|
input_2d.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _run_flashinfer(
|
||||||
|
self,
|
||||||
|
input_2d: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
input_scale: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Run FlashInfer FP8 block-scale GEMM.
|
||||||
|
|
||||||
|
This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels
|
||||||
|
and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper).
|
||||||
|
"""
|
||||||
|
from vllm.model_executor.layers.quantization.utils.flashinfer_block_gemm import (
|
||||||
|
flashinfer_block_gemm,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Quantize input dynamically if not pre-quantized (same as CUTLASS)
|
||||||
|
assert input_scale is None
|
||||||
|
assert self.input_quant_op is not None
|
||||||
|
q_input, input_scale = self.input_quant_op(input_2d)
|
||||||
|
|
||||||
|
# Now call FlashInfer with FP8 input + FP8 weight (W8A8)
|
||||||
|
return flashinfer_block_gemm(
|
||||||
|
input=q_input, # FP8 quantized input
|
||||||
|
weight=weight, # FP8 weight
|
||||||
|
scales_a=input_scale, # Input scales (computed dynamically)
|
||||||
|
scales_b=weight_scale, # Weight scales
|
||||||
|
out_dtype=input_2d.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
def _dispatch_w8a8_blockscale_op(
|
def _dispatch_w8a8_blockscale_op(
|
||||||
self,
|
self,
|
||||||
use_cutlass: bool,
|
use_cutlass: bool,
|
||||||
@ -425,6 +457,22 @@ class W8A8BlockFp8LinearOp:
|
|||||||
],
|
],
|
||||||
QuantFP8 | None,
|
QuantFP8 | None,
|
||||||
]:
|
]:
|
||||||
|
# Prefer FlashInfer on SM90+ if available (Hopper optimized)
|
||||||
|
if (
|
||||||
|
has_flashinfer_block_gemm()
|
||||||
|
and envs.VLLM_USE_FLASHINFER_FP8_LINEAR
|
||||||
|
and not use_aiter_and_is_supported
|
||||||
|
):
|
||||||
|
logger.info_once("Using FlashInfer FP8 block-scale GEMM for linear layers")
|
||||||
|
return self._run_flashinfer, (
|
||||||
|
QuantFP8(
|
||||||
|
False,
|
||||||
|
self.act_quant_group_shape,
|
||||||
|
column_major_scales=False,
|
||||||
|
use_ue8m0=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if use_cutlass:
|
if use_cutlass:
|
||||||
return self._run_cutlass, (
|
return self._run_cutlass, (
|
||||||
QuantFP8(
|
QuantFP8(
|
||||||
|
|||||||
@ -540,6 +540,29 @@ def flashinfer_scaled_fp8_mm(
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def has_flashinfer_block_gemm() -> bool:
|
||||||
|
"""Return `True` if FlashInfer FP8 block-scale GEMM is available."""
|
||||||
|
if not has_flashinfer():
|
||||||
|
return False
|
||||||
|
if not current_platform.is_cuda():
|
||||||
|
return False
|
||||||
|
# Only SM90+ (Hopper) supports this kernel
|
||||||
|
if not current_platform.is_device_capability(90):
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import flashinfer
|
||||||
|
|
||||||
|
# Check if the module has the required binding
|
||||||
|
return hasattr(flashinfer, "Fp8BlockScaleGemmRunner")
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
logger.debug_once(
|
||||||
|
"FlashInfer block-scale GEMM not available: module or binding not found"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"has_flashinfer",
|
"has_flashinfer",
|
||||||
"flashinfer_trtllm_fp8_block_scale_moe",
|
"flashinfer_trtllm_fp8_block_scale_moe",
|
||||||
@ -562,4 +585,5 @@ __all__ = [
|
|||||||
"use_trtllm_attention",
|
"use_trtllm_attention",
|
||||||
"flashinfer_scaled_fp4_mm",
|
"flashinfer_scaled_fp4_mm",
|
||||||
"flashinfer_scaled_fp8_mm",
|
"flashinfer_scaled_fp8_mm",
|
||||||
|
"has_flashinfer_block_gemm",
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user