# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Benchmark for FlashInfer fused collective operations vs standard operations.

This benchmark compares:
1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant)
2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations

Usage with torchrun:
    torchrun --nproc_per_node=2 benchmark_fused_collective.py

"""

import argparse
import itertools
import os
import time

import pandas as pd
import torch  # type: ignore
import torch.distributed as dist  # type: ignore

from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import (
    get_tp_group,
    tensor_model_parallel_all_reduce,
)
from vllm.distributed.parallel_state import (
    graph_capture,
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm  # noqa
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8  # noqa
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape  # noqa
from vllm.platforms import current_platform  # noqa

RMS_NORM_OP = torch.ops._C.rms_norm
FUSED_ADD_RMS_NORM_OP = torch.ops._C.fused_add_rms_norm
RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant
FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = (
    torch.ops._C.fused_add_rms_norm_static_fp8_quant
)
SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant

logger = init_logger(__name__)

# Try to import FlashInfer
try:
    import flashinfer.comm as flashinfer_comm  # type: ignore

    if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"):
        flashinfer_comm = None
        logger.warning(
            "FlashInfer comm module found but missing trtllm_allreduce_fusion"
        )
except ImportError:
    flashinfer_comm = None
    logger.warning("FlashInfer not found, only benchmarking standard operations")

# Constants
FP8_DTYPE = current_platform.fp8_dtype()
MiB = 1024 * 1024

# FlashInfer max sizes per world size
# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes
# use --no-oneshot to disable oneshot mode for very large input sizes
_FI_MAX_SIZES = {
    2: 64 * MiB,  # 64MB
    4: 64 * MiB,  # 64MB
    8: 64 * MiB,  # 64MB
}

# Global workspace tensor for FlashInfer
_FI_WORKSPACE_TENSOR = None


def setup_flashinfer_workspace(
    world_size: int,
    rank: int,
    hidden_dim: int,
    max_token_num: int,
    use_fp32_lamport: bool = False,
):
    """Setup FlashInfer workspace for fused allreduce operations."""
    global _FI_WORKSPACE_TENSOR

    if flashinfer_comm is None:
        return None, None

    if world_size not in _FI_MAX_SIZES:
        logger.warning("FlashInfer not supported for world size %s", world_size)
        return None, None

    try:
        # Create IPC workspace
        ipc_handles, workspace_tensor = (
            flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                tp_rank=rank,
                tp_size=world_size,
                max_token_num=max_token_num,
                hidden_dim=hidden_dim,
                group=get_tp_group().device_group,
                use_fp32_lamport=use_fp32_lamport,
            )
        )

        _FI_WORKSPACE_TENSOR = workspace_tensor
        return ipc_handles, workspace_tensor
    except Exception as e:
        logger.error("Failed to setup FlashInfer workspace: %s", e)
        return None, None


def cleanup_flashinfer_workspace(ipc_handles):
    """Cleanup FlashInfer workspace."""
    if flashinfer_comm is None or ipc_handles is None:
        return

    try:
        group = get_tp_group().device_group
        flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group)
    except Exception as e:
        logger.error("Failed to cleanup FlashInfer workspace: %s", e)


class FlashInferFusedAllReduceParams:
    """Parameters for FlashInfer fused allreduce operations."""

    def __init__(
        self,
        rank: int,
        world_size: int,
        use_fp32_lamport: bool = False,
        max_token_num: int = 1024,
    ):
        self.rank = rank
        self.world_size = world_size
        self.use_fp32_lamport = use_fp32_lamport
        self.trigger_completion_at_end = True
        self.launch_with_pdl = True
        self.fp32_acc = True
        self.max_token_num = max_token_num

    def get_trtllm_fused_allreduce_kwargs(self):
        return {
            "world_rank": self.rank,
            "world_size": self.world_size,
            "launch_with_pdl": self.launch_with_pdl,
            "trigger_completion_at_end": self.trigger_completion_at_end,
            "fp32_acc": self.fp32_acc,
        }


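# Note on the three FlashInfer wrappers below (descriptive only): they share one
# output convention. When norm_out is None the normalized result is written back
# into input_tensor and residual_out aliases residual; otherwise the result goes
# to norm_out and input_tensor is reused as residual_out. For the FP4 variant,
# quant_out is a uint8 tensor holding two 4-bit values per byte (hidden_dim // 2
# columns, see create_test_tensors) and output_scale receives the block scales.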
def flashinfer_fused_allreduce_rmsnorm(
    input_tensor: torch.Tensor,
    residual: torch.Tensor | None,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    allreduce_params: "FlashInferFusedAllReduceParams",
    use_oneshot: bool,
    norm_out: torch.Tensor | None = None,
):
    """FlashInfer fused allreduce + rmsnorm operation."""
    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
        raise RuntimeError("FlashInfer not available or workspace not initialized")

    if norm_out is None:
        norm_out = input_tensor
        residual_out = residual
    else:
        residual_out = input_tensor

    flashinfer_comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        token_num=input_tensor.shape[0],
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        hidden_dim=input_tensor.shape[-1],
        workspace_ptrs=_FI_WORKSPACE_TENSOR,
        pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
        allreduce_out=None,
        quant_out=None,
        scale_out=None,
        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
        scale_factor=None,
        use_oneshot=use_oneshot,
        **allreduce_params.get_trtllm_fused_allreduce_kwargs(),
    )


def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
    input_tensor: torch.Tensor,
    residual: torch.Tensor | None,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    scale_factor: torch.Tensor,
    allreduce_params: FlashInferFusedAllReduceParams,
    use_oneshot: bool = True,
    norm_out: torch.Tensor | None = None,
    quant_out: torch.Tensor | None = None,
):
    """FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
        raise RuntimeError("FlashInfer not available or workspace not initialized")

    if norm_out is None:
        norm_out = input_tensor
        residual_out = residual
    else:
        residual_out = input_tensor

    flashinfer_comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        token_num=input_tensor.shape[0],
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        hidden_dim=input_tensor.shape[-1],
        workspace_ptrs=_FI_WORKSPACE_TENSOR,
        pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
        allreduce_out=None,
        quant_out=quant_out,
        scale_out=None,
        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
        scale_factor=scale_factor,
        use_oneshot=use_oneshot,
        **allreduce_params.get_trtllm_fused_allreduce_kwargs(),
    )


def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
    input_tensor: torch.Tensor,
    residual: torch.Tensor | None,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    input_global_scale: torch.Tensor,
    allreduce_params: FlashInferFusedAllReduceParams,
    quant_out: torch.Tensor,
    use_oneshot: bool,
    output_scale: torch.Tensor,
    norm_out: torch.Tensor | None = None,
):
    """FlashInfer fused allreduce + rmsnorm + FP4 quantization."""
    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
        raise RuntimeError("FlashInfer not available or workspace not initialized")

    if norm_out is None:
        norm_out = input_tensor
        residual_out = residual
    else:
        residual_out = input_tensor

    flashinfer_comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        token_num=input_tensor.shape[0],
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        hidden_dim=input_tensor.shape[-1],
        workspace_ptrs=_FI_WORKSPACE_TENSOR,
        pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant,
        allreduce_out=None,
        quant_out=quant_out,
        scale_out=output_scale,
        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
        scale_factor=input_global_scale,
        use_oneshot=use_oneshot,
        **allreduce_params.get_trtllm_fused_allreduce_kwargs(),
    )


class VllmFusedAllreduce:
    def __init__(self, hidden_dim, dtype):
        self.rms_eps = 1e-6
        self.rms_norm = RMSNorm(hidden_dim, eps=self.rms_eps, dtype=dtype)
        self.fp8_quant = QuantFP8(
            static=True,
            group_shape=GroupShape.PER_TENSOR,
        )

    def allreduce_rmsnorm(
        self, input_tensor: torch.Tensor, residual: torch.Tensor | None
    ):
        allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
        return self.rms_norm(allreduce_out, residual)

    def allreduce_rmsnorm_fp8_quant(
        self,
        input_tensor: torch.Tensor,
        residual: torch.Tensor | None,
        scale_factor: torch.Tensor,
    ):
        allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
        rms_out = self.rms_norm(allreduce_out, residual)
        if residual is None:
            quant_out = self.fp8_quant(rms_out, scale_factor)
            return quant_out
        else:
            rms_out, residual_out = rms_out
            quant_out = self.fp8_quant(rms_out, scale_factor)
            return quant_out, residual_out

    def allreduce_rmsnorm_fp4_quant(
        self,
        input_tensor: torch.Tensor,
        residual: torch.Tensor | None,
        input_global_scale: torch.Tensor,
        quant_out: torch.Tensor,
        output_scale: torch.Tensor,
    ):
        allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
        rms_out = self.rms_norm(allreduce_out, residual)
        if residual is None:
            SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale)
            return quant_out, output_scale
        else:
            rms_out, residual_out = rms_out
            SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale)
            return quant_out, residual_out, output_scale


def create_test_tensors(
    num_tokens: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True
):
    """Create test tensors for benchmarking."""
    input_tensor = torch.randn(num_tokens, hidden_dim, dtype=dtype)
    residual = (
        torch.randn_like(input_tensor)
        if use_residual
        else torch.zeros_like(input_tensor)
    )
    rms_gamma = torch.ones(hidden_dim, dtype=dtype)
    norm_out = None if use_residual else torch.empty_like(input_tensor)

    # Quantization scales
    scale_fp8 = torch.tensor(1.0, dtype=torch.float32)
    scale_fp4 = torch.tensor(1.0, dtype=torch.float32)
    quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE)
    # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks)
    fp4_quant_out = torch.empty((num_tokens, hidden_dim // 2), dtype=torch.uint8)
    fp4_output_scale = torch.empty((128, 4), dtype=torch.int32)

    return (
        input_tensor,
        norm_out,
        residual,
        rms_gamma,
        scale_fp8,
        quant_out_fp8,
        scale_fp4,
        fp4_quant_out,
        fp4_output_scale,
    )


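# A note on timing (descriptive, not changing behavior): benchmark_operation
# captures num_op_per_cudagraph (10) back-to-back calls into one CUDA graph and
# replays that graph trials // 10 times, so roughly `trials` operation launches
# are timed in total and the reported value is the average time per single call.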
def benchmark_operation(
    operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs
):
    """Benchmark a single operation using CUDA graphs."""
    # Warmup before graph capture
    for _ in range(warmup):
        operation_func(*args, **kwargs)
    torch.cuda.synchronize()

    # Create CUDA graph
    graph = torch.cuda.CUDAGraph()
    num_op_per_cudagraph = 10

    # Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
    with graph_capture(device=device), torch.cuda.graph(graph):
        for _ in range(num_op_per_cudagraph):
            operation_func(*args, **kwargs)

    # Graph warmup
    torch.cuda.synchronize()
    for _ in range(warmup):
        graph.replay()

    # Benchmark with CUDA graph
    torch.cuda.synchronize()
    start_time = time.perf_counter()

    for _ in range(trials // num_op_per_cudagraph):
        # operation_func(*args, **kwargs)
        graph.replay()

    torch.cuda.synchronize()
    end_time = time.perf_counter()

    avg_time_ms = ((end_time - start_time) / trials) * 1000
    return avg_time_ms


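# Result-key conventions used in run_benchmarks below (a descriptive guide):
# "standard_*" entries are the unfused allreduce + RMSNorm (+ quant) baselines,
# "*_native_compiled" wraps that baseline in torch.compile, "flashinfer_*"
# entries use trtllm_allreduce_fusion with "_oneshot"/"_twoshot" reflecting
# use_oneshot, and "custom"/"native" refer to vLLM custom CUDA ops ("+op") vs
# the native PyTorch implementations ("-op") selected via custom_ops.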
def run_benchmarks(
    num_tokens: int,
    hidden_dim: int,
    dtype: torch.dtype,
    use_residual: bool,
    allreduce_params: FlashInferFusedAllReduceParams | None,
    quant_modes: set[str],
    no_oneshot: bool,
):
    """Run all benchmarks for given configuration.

    Args:
        quant_modes: any subset of {"none", "fp8", "fp4"} selecting which
            benchmark groups to run
    """
    (
        input_tensor,
        norm_out,
        residual,
        rms_gamma,
        scale_fp8,
        quant_out_fp8,
        scale_fp4,
        fp4_quant_out,
        fp4_output_scale,
    ) = create_test_tensors(num_tokens, hidden_dim, dtype, use_residual)

    rms_eps = 1e-6
    results = {}
    vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
    use_oneshot_options = [False] if no_oneshot else [True, False]

    # Create RMSNorm and QuantFP8 layers once for native benchmarks

    if "none" in quant_modes:
        # Standard AllReduce + RMSNorm
        for custom_op in ["-rms_norm", "+rms_norm"]:
            with set_current_vllm_config(
                VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op]))
            ):
                try:
                    suffix = (
                        "_custom_rms_norm" if "+" in custom_op else "_native_rms_norm"
                    )
                    time_ms = benchmark_operation(
                        vllm_fused_allreduce.allreduce_rmsnorm,
                        input_tensor,
                        residual=residual,
                    )
                    results[f"standard_allreduce_{suffix}"] = time_ms
                except Exception as e:
                    logger.error("Standard AllReduce+RMSNorm failed: %s", e)
                    results[f"standard_allreduce_{suffix}"] = float("inf")

        # Standard AllReduce + RMSNorm Native Compiled
        with set_current_vllm_config(
            VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
        ):
            try:
                standard_allreduce_rmsnorm_native_compiled = torch.compile(
                    vllm_fused_allreduce.allreduce_rmsnorm,
                    fullgraph=True,
                    dynamic=False,
                )
                time_ms = benchmark_operation(
                    standard_allreduce_rmsnorm_native_compiled,
                    input_tensor,
                    residual=residual,
                )
                results["standard_allreduce_rmsnorm_native_compiled"] = time_ms
            except Exception as e:
                logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e)
                results["standard_allreduce_rmsnorm_native_compiled"] = float("inf")

        # FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot
        if flashinfer_comm is not None and allreduce_params is not None:
            for use_oneshot in use_oneshot_options:
                suffix = "_oneshot" if use_oneshot else "_twoshot"
                try:
                    time_ms = benchmark_operation(
                        flashinfer_fused_allreduce_rmsnorm,
                        input_tensor,
                        residual=residual,
                        norm_out=norm_out,
                        rms_gamma=rms_gamma,
                        rms_eps=rms_eps,
                        allreduce_params=allreduce_params,
                        use_oneshot=use_oneshot,
                    )
                    results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = time_ms
                except Exception as e:
                    logger.error("FlashInfer Fused AllReduce+RMSNorm failed: %s", e)
                    results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = float(
                        "inf"
                    )

if "fp8" in quant_modes:
|
|
# Standard AllReduce + RMSNorm + FP8 Quant
|
|
for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]:
|
|
suffix = (
|
|
"_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
|
|
)
|
|
for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]:
|
|
suffix += (
|
|
"_custom_quant_fp8"
|
|
if "+" in quant_fp8_custom_op
|
|
else "_native_quant_fp8"
|
|
)
|
|
with set_current_vllm_config(
|
|
VllmConfig(
|
|
compilation_config=CompilationConfig(
|
|
custom_ops=[rms_norm_custom_op, quant_fp8_custom_op]
|
|
)
|
|
)
|
|
):
|
|
try:
|
|
time_ms = benchmark_operation(
|
|
vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
|
|
input_tensor,
|
|
residual=residual,
|
|
scale_factor=scale_fp8,
|
|
)
|
|
results[f"standard_allreduce{suffix}"] = time_ms
|
|
except Exception as e:
|
|
logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e)
|
|
results[f"standard_allreduce{suffix}"] = float("inf")
|
|
|
|
# Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
|
|
with set_current_vllm_config(
|
|
VllmConfig(
|
|
compilation_config=CompilationConfig(
|
|
custom_ops=["-rms_norm", "-quant_fp8"]
|
|
)
|
|
)
|
|
):
|
|
try:
|
|
standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile(
|
|
vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
|
|
fullgraph=True,
|
|
dynamic=False,
|
|
)
|
|
time_ms = benchmark_operation(
|
|
standard_allreduce_rmsnorm_fp8_quant_native_compiled,
|
|
input_tensor,
|
|
residual=residual,
|
|
scale_factor=scale_fp8,
|
|
)
|
|
results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = (
|
|
time_ms
|
|
)
|
|
except Exception as e:
|
|
logger.error(
|
|
"Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e
|
|
)
|
|
results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float(
|
|
"inf"
|
|
)
|
|
|
|
# FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot
|
|
if flashinfer_comm is not None and allreduce_params is not None:
|
|
for use_oneshot in use_oneshot_options:
|
|
suffix = "_oneshot" if use_oneshot else "_twoshot"
|
|
try:
|
|
time_ms = benchmark_operation(
|
|
flashinfer_fused_allreduce_rmsnorm_fp8_quant,
|
|
input_tensor,
|
|
norm_out=norm_out,
|
|
residual=residual,
|
|
rms_gamma=rms_gamma,
|
|
rms_eps=rms_eps,
|
|
scale_factor=scale_fp8,
|
|
quant_out=quant_out_fp8,
|
|
allreduce_params=allreduce_params,
|
|
use_oneshot=use_oneshot,
|
|
)
|
|
results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = (
|
|
time_ms
|
|
)
|
|
except Exception as e:
|
|
logger.error(
|
|
"FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s",
|
|
e,
|
|
)
|
|
results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = (
|
|
float("inf")
|
|
)
|
|
|
|
if "fp4" in quant_modes and current_platform.has_device_capability(100):
|
|
# Standard AllReduce + RMSNorm + FP4 Quant
|
|
for rms_norm_custom_op in ["-rms_norm", "+rms_norm"]:
|
|
suffix = (
|
|
"_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
|
|
)
|
|
with set_current_vllm_config(
|
|
VllmConfig(
|
|
compilation_config=CompilationConfig(
|
|
custom_ops=[rms_norm_custom_op]
|
|
)
|
|
)
|
|
):
|
|
try:
|
|
time_ms = benchmark_operation(
|
|
vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
|
|
input_tensor,
|
|
residual=residual,
|
|
input_global_scale=scale_fp4,
|
|
quant_out=fp4_quant_out,
|
|
output_scale=fp4_output_scale,
|
|
)
|
|
results[f"standard_allreduce_{suffix}_fp4_quant"] = time_ms
|
|
except Exception as e:
|
|
logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e)
|
|
results[f"standard_allreduce_{suffix}_fp4_quant"] = float("inf")
|
|
|
|
# Standard AllReduce + RMSNorm + FP4 Quant Native Compiled
|
|
with set_current_vllm_config(
|
|
VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
|
|
):
|
|
try:
|
|
standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile(
|
|
vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
|
|
fullgraph=True,
|
|
dynamic=False,
|
|
)
|
|
time_ms = benchmark_operation(
|
|
standard_allreduce_rmsnorm_fp4_quant_native_compiled,
|
|
input_tensor,
|
|
residual=residual,
|
|
quant_out=fp4_quant_out,
|
|
input_global_scale=scale_fp4,
|
|
output_scale=fp4_output_scale,
|
|
)
|
|
results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = (
|
|
time_ms
|
|
)
|
|
except Exception as e:
|
|
logger.error(
|
|
"Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e
|
|
)
|
|
results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float(
|
|
"inf"
|
|
)
|
|
|
|
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot
|
|
if flashinfer_comm is not None and allreduce_params is not None:
|
|
for use_oneshot in use_oneshot_options:
|
|
suffix = "_oneshot" if use_oneshot else "_twoshot"
|
|
try:
|
|
time_ms = benchmark_operation(
|
|
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
|
|
input_tensor,
|
|
residual=residual,
|
|
norm_out=norm_out,
|
|
rms_gamma=rms_gamma,
|
|
rms_eps=rms_eps,
|
|
input_global_scale=scale_fp4,
|
|
allreduce_params=allreduce_params,
|
|
quant_out=fp4_quant_out,
|
|
output_scale=fp4_output_scale,
|
|
use_oneshot=use_oneshot,
|
|
)
|
|
results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = (
|
|
time_ms
|
|
)
|
|
except Exception as e:
|
|
logger.error(
|
|
"FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s",
|
|
e,
|
|
)
|
|
results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = (
|
|
float("inf")
|
|
)
|
|
|
|
# FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot
|
|
if flashinfer_comm is not None and allreduce_params is not None:
|
|
try:
|
|
time_ms = benchmark_operation(
|
|
flashinfer_fused_allreduce_rmsnorm_fp4_quant,
|
|
input_tensor,
|
|
residual=residual,
|
|
norm_out=norm_out,
|
|
rms_gamma=rms_gamma,
|
|
rms_eps=rms_eps,
|
|
input_global_scale=scale_fp4,
|
|
allreduce_params=allreduce_params,
|
|
quant_out=fp4_quant_out,
|
|
output_scale=fp4_output_scale,
|
|
use_oneshot=False,
|
|
)
|
|
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = (
|
|
time_ms
|
|
)
|
|
except Exception as e:
|
|
logger.error(
|
|
"FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s",
|
|
e,
|
|
)
|
|
results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float(
|
|
"inf"
|
|
)
|
|
|
|
return results


def prepare_results_with_speedups(results_dict):
    """Prepare results with speedup calculations based on dynamic baseline selection."""
    prepared_results = []

    # Determine the fastest baseline for each operation type
    def get_fastest_baseline(op_name, results_dict):
        """Get the fastest baseline between standard and native_compiled versions."""
        if "fp8_quant" in op_name:
            candidates = [
                "standard_allreduce_rmsnorm_fp8_quant",
                "standard_allreduce_rmsnorm_fp8_quant_native_compiled",
            ]
        elif "fp4_quant" in op_name:
            candidates = [
                "standard_allreduce_rmsnorm_fp4_quant",
                "standard_allreduce_rmsnorm_fp4_quant_native_compiled",
            ]
        else:
            candidates = [
                "standard_allreduce_rmsnorm",
                "standard_allreduce_rmsnorm_native_compiled",
            ]

        # Find the fastest among available candidates
        fastest_time = float("inf")
        fastest_baseline = None

        for candidate in candidates:
            if (
                candidate in results_dict
                and results_dict[candidate] != float("inf")
                and results_dict[candidate] < fastest_time
            ):
                fastest_time = results_dict[candidate]
                fastest_baseline = candidate

        return fastest_baseline

    # Create dynamic baseline mapping
    dynamic_baseline_mapping = {}
    for op_name in results_dict:
        if op_name.startswith("flashinfer_") or (
            op_name.startswith("standard_")
            and not op_name.endswith("_native_compiled")
        ):
            dynamic_baseline_mapping[op_name] = get_fastest_baseline(
                op_name, results_dict
            )

    for op_name, time_ms in results_dict.items():
        if time_ms == float("inf"):
            speedup_str = "FAILED"
            time_str = "FAILED"
        else:
            time_str = f"{time_ms:.3f}"
            # Find the appropriate baseline for this operation
            baseline_op = dynamic_baseline_mapping.get(op_name)
            if baseline_op and baseline_op in results_dict:
                baseline_time = results_dict[baseline_op]
                if baseline_time != float("inf") and baseline_time > 0:
                    speedup = baseline_time / time_ms
                    speedup_str = f"{speedup:.2f}x"
                else:
                    speedup_str = "N/A"
            else:
                # For baseline operations, determine if this is the fastest baseline
                if op_name.endswith("_native_compiled") or (
                    op_name.startswith("standard_")
                    and not op_name.endswith("_native_compiled")
                ):
                    fastest_baseline = get_fastest_baseline(op_name, results_dict)
                    if fastest_baseline == op_name:
                        speedup_str = "baseline"
                    else:
                        if fastest_baseline and fastest_baseline in results_dict:
                            baseline_time = results_dict[fastest_baseline]
                            if baseline_time != float("inf") and baseline_time > 0:
                                speedup = baseline_time / time_ms
                                speedup_str = f"{speedup:.2f}x"
                            else:
                                speedup_str = "N/A"
                        else:
                            speedup_str = "N/A"
                else:
                    speedup_str = "N/A"

        prepared_results.append(
            {
                "operation": op_name,
                "time_ms": time_ms,
                "time_str": time_str,
                "speedup_str": speedup_str,
            }
        )

    return prepared_results


def print_results(
    results_dict,
    num_tokens,
    hidden_dim,
    dtype,
    use_residual,
    quant_modes,
    input_size_mb,
):
    """Print benchmark results in a formatted table."""
    print(f"\n{'=' * 80}")
    print(
        f"Results: num_tokens={num_tokens}, hidden_dim={hidden_dim} "
        f"(input size: {input_size_mb:.2f} MB)"
    )
    print(
        f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, "
        f"quant_modes={','.join(sorted(list(quant_modes)))}"
    )
    print(f"{'=' * 80}")
    print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}")
    print(f"{'-' * 80}")

    # Prepare results with speedup calculations
    prepared_results = prepare_results_with_speedups(results_dict)

    for result in prepared_results:
        if result["time_ms"] == float("inf"):
            time_display = result["time_str"]
        else:
            time_display = f"{result['time_ms']:.3f}"

        print(
            f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}"
        )


def format_results_markdown(
    all_results: list[dict], world_size: int, args: argparse.Namespace
) -> str:
    """Format all benchmark results as markdown."""
    lines: list[str] = []
    lines.append("# FlashInfer Fused Collective Operations Benchmark Results")
    lines.append("")
    lines.append(f"**World Size:** {world_size} ")
    lines.append(f"**Hidden Dimension:** {args.hidden_dim} ")
    lines.append(f"**Warmup Iterations:** {args.warmup} ")
    lines.append(f"**Benchmark Trials:** {args.trials} ")
    modes = ",".join(all_results[0]["quant_modes"]) if all_results else "N/A"
    lines.append(f"**Quantization Modes:** {modes} ")
    lines.append("")
    lines.append("---")
    lines.append("")

    for entry in all_results:
        num_tokens = entry["num_tokens"]
        dtype = entry["dtype"]
        use_residual = entry["use_residual"]
        results_dict = entry["results"]
        input_size_mb = entry["input_size_mb"]
        residual_str = "with residual" if use_residual else "no residual"

        lines.append(
            f"## Configuration: num_tokens={num_tokens}, dtype={dtype}, {residual_str}"
        )
        lines.append(f"**Input Size:** {input_size_mb:.2f} MB")
        lines.append("")

        prepared = prepare_results_with_speedups(results_dict)
        # Build DataFrame for markdown export
        rows = [
            {
                "Operation": r["operation"].replace("_", " ").title(),
                "Time (ms)": r["time_str"],
                "Speedup": r["speedup_str"],
            }
            for r in prepared
        ]
        df = pd.DataFrame(rows)
        if df.empty:
            lines.append("No results.")
        else:
            lines.append(df.to_markdown(index=False))
        lines.append("")

    return "\n".join(lines)


def save_results_to_file(
    all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int
):
    """Save benchmark results to markdown file (only on rank 0)."""
    if rank != 0:
        return

    if not all_results:
        logger.warning("No results to save")
        return

    output_path = args.output_file

    try:
        markdown_content = format_results_markdown(all_results, world_size, args)

        with open(output_path, "a") as f:
            f.write(markdown_content)

    except Exception as e:
        logger.error("Failed to save results to file: %s", e)


def main():
    parser = argparse.ArgumentParser(
        description="Benchmark fused collective operations"
    )
    parser.add_argument(
        "--num-tokens",
        type=int,
        nargs="+",
        default=[128, 512, 1024, 2048],
        help="Numbers of tokens to test",
    )
    parser.add_argument(
        "--hidden-dim", type=int, default=8192, help="Hidden dimension size"
    )
    parser.add_argument(
        "--dtypes",
        type=str,
        nargs="+",
        default=["bfloat16"],
        choices=["float16", "bfloat16", "float32"],
        help="Data types to test",
    )
    parser.add_argument(
        "--no-residual",
        action="store_true",
        help="Skip residual connection tests",
    )

    parser.add_argument(
        "--quant-modes",
        type=str,
        default="none,fp8,fp4",
        help=(
            "Comma-separated quantization modes to run: none, fp8, fp4. "
            "Default: none,fp8,fp4"
        ),
    )

    parser.add_argument(
        "--warmup", type=int, default=5, help="Number of warmup iterations"
    )
    parser.add_argument(
        "--trials", type=int, default=20, help="Number of benchmark trials"
    )
    parser.add_argument(
        "--output-file",
        type=str,
        help=(
            "Output file path for markdown results "
            "(results are appended; if omitted, results are only printed)"
        ),
    )

    parser.add_argument(
        "--no-oneshot",
        action="store_true",
        help="Skip oneshot benchmarks",
    )

    args = parser.parse_args()

    # Check if running with torchrun (required for collective operations)
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        raise RuntimeError(
            "Must run with torchrun for distributed benchmarking. "
            "Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py"
        )

    # Initialize distributed environment
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)

    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Validate world size (must be > 1 for collective operations)
    if world_size <= 1:
        raise ValueError(
            "World size must be > 1 for collective operations benchmarking. "
            f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1."
        )

    # Parse quantization modes
    valid_quant_modes = {"none", "fp8", "fp4"}
    raw_modes = [
        m.strip().lower() for m in (args.quant_modes or "").split(",") if m.strip()
    ]
    quant_modes = set(raw_modes) if raw_modes else {"none", "fp8", "fp4"}
    invalid = sorted(list(quant_modes - valid_quant_modes))
    if invalid:
        raise ValueError(
            f"Invalid --quant-modes entries: {','.join(invalid)}. "
            f"Valid options are: {','.join(sorted(valid_quant_modes))}."
        )

    if rank == 0:
        logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank)
        logger.info("Quantization modes: %s", ",".join(sorted(list(quant_modes))))
        if flashinfer_comm is not None:
            logger.info(
                "FlashInfer available - will benchmark fused operations",
            )
        else:
            logger.info(
                "FlashInfer not available - only benchmarking standard operations"
            )

    # Convert dtype strings to torch dtypes
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    dtypes = [dtype_map[dt] for dt in args.dtypes]

    # Test configurations
    residual_options = [True] if not args.no_residual else [False]

    configs = list(itertools.product(args.num_tokens, dtypes, residual_options))

    # Setup FlashInfer workspace if available
    ipc_handles = None
    allreduce_params = None

    if flashinfer_comm is not None:
        # Use the largest hidden dimension for workspace setup
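        # Sizing note (descriptive): _FI_MAX_SIZES caps the per-rank workspace in
        # bytes, so the token budget below is that cap divided by the bytes one
        # token contributes across the group (hidden_dim elements * 2 bytes for a
        # 16-bit dtype * world_size ranks).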
        max_num_token = _FI_MAX_SIZES.get(world_size) // (
            args.hidden_dim * world_size * 2
        )

        ipc_handles, workspace_tensor = setup_flashinfer_workspace(
            world_size, rank, args.hidden_dim, max_num_token
        )

        if workspace_tensor is not None:
            allreduce_params = FlashInferFusedAllReduceParams(
                rank=rank,
                world_size=world_size,
                max_token_num=max_num_token,
            )

    # Collect all results for markdown export
    all_results = []

    try:
        # Run benchmarks
        for num_tokens, dtype, use_residual in configs:
            if rank == 0:
                logger.info(
                    "\nTesting: num_tokens=%s, hidden_dim=%s, dtype=%s, residual=%s",
                    num_tokens,
                    args.hidden_dim,
                    dtype,
                    use_residual,
                )

            results = run_benchmarks(
                num_tokens,
                args.hidden_dim,
                dtype,
                use_residual,
                allreduce_params,
                quant_modes=quant_modes,
                no_oneshot=args.no_oneshot,
            )

            # Store results for markdown export
            if rank == 0:
                # Calculate input size in MB
                input_size_mb = (
                    num_tokens * args.hidden_dim * torch.finfo(dtype).bits
                ) / (8 * 1024 * 1024)
                all_results.append(
                    {
                        "num_tokens": num_tokens,
                        "hidden_dim": args.hidden_dim,
                        "dtype": str(dtype).replace("torch.", ""),
                        "use_residual": use_residual,
                        "quant_modes": sorted(list(quant_modes)),
                        "input_size_mb": input_size_mb,
                        "results": results,
                    }
                )

                print_results(
                    results,
                    num_tokens,
                    args.hidden_dim,
                    dtype,
                    use_residual,
                    quant_modes,
                    input_size_mb,
                )

        # Save results to markdown file
        if args.output_file and rank == 0:
            save_results_to_file(all_results, world_size, args, rank)

    finally:
        # Cleanup
        if ipc_handles is not None:
            cleanup_flashinfer_workspace(ipc_handles)

        dist.barrier()


if __name__ == "__main__":
    main()