diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py
new file mode 100644
index 0000000000000..a61c17edc1e28
--- /dev/null
+++ b/benchmarks/kernels/benchmark_device_communicators.py
@@ -0,0 +1,486 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""
+Benchmark script for device communicators:
+CustomAllreduce (oneshot, twoshot), PyNcclCommunicator,
+and SymmMemCommunicator (multimem, two-shot).
+
+Usage:
+    torchrun --nproc_per_node=<num_gpus> benchmark_device_communicators.py [options]
+
+Example:
+    torchrun --nproc_per_node=2 benchmark_device_communicators.py
+        --sequence-lengths 512 1024 2048 --num-warmup 10 --num-trials 100
+"""
+
+import json
+import os
+import time
+from contextlib import nullcontext
+from typing import Callable, Optional
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
+from vllm.logger import init_logger
+from vllm.utils import FlexibleArgumentParser
+
+logger = init_logger(__name__)
+
+# Default sequence lengths to benchmark
+DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192]
+
+# Fixed hidden size and dtype for all benchmarks
+HIDDEN_SIZE = 8192
+BENCHMARK_DTYPE = torch.bfloat16
+
+# CUDA graph settings
+CUDA_GRAPH_CAPTURE_CYCLES = 10
+
+
+class CommunicatorBenchmark:
+    """Benchmark class for testing device communicators."""
+
+    def __init__(
+        self,
+        rank: int,
+        world_size: int,
+        device: torch.device,
+        cpu_group: ProcessGroup,
+        sequence_lengths: list[int],
+    ):
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.cpu_group = cpu_group
+
+        # Calculate max_size_override based on largest sequence length
+        max_seq_len = max(sequence_lengths)
+        max_tensor_elements = max_seq_len * HIDDEN_SIZE
+        self.max_size_override = max_tensor_elements * BENCHMARK_DTYPE.itemsize + 1
+
+        # Initialize communicators
+        self.custom_allreduce = None
+        self.pynccl_comm = None
+        self.symm_mem_comm = None
+        self.symm_mem_comm_multimem = None
+        self.symm_mem_comm_two_shot = None
+
+        self._init_communicators()
+
+    def _init_communicators(self):
+        """Initialize all available communicators."""
+        try:
+            self.custom_allreduce = CustomAllreduce(
+                group=self.cpu_group,
+                device=self.device,
+                max_size=self.max_size_override,
+            )
+            if not self.custom_allreduce.disabled:
+                logger.info("Rank %s: CustomAllreduce initialized", self.rank)
+            else:
+                logger.info("Rank %s: CustomAllreduce disabled", self.rank)
+        except Exception as e:
+            logger.warning(
+                "Rank %s: Failed to initialize CustomAllreduce: %s", self.rank, e
+            )
+            self.custom_allreduce = None
+
+        try:
+            self.pynccl_comm = PyNcclCommunicator(
+                group=self.cpu_group, device=self.device
+            )
+            if not self.pynccl_comm.disabled:
+                logger.info("Rank %s: PyNcclCommunicator initialized", self.rank)
+            else:
+                logger.info("Rank %s: PyNcclCommunicator disabled", self.rank)
+                self.pynccl_comm = None
+        except Exception as e:
+            logger.warning(
+                "Rank %s: Failed to initialize PyNcclCommunicator: %s", self.rank, e
+            )
+            self.pynccl_comm = None
+
+        # Initialize variants for SymmMemCommunicator
+        try:
+            self.symm_mem_comm_multimem = SymmMemCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+                force_multimem=True,
+                max_size_override=self.max_size_override,
+            )
+            if not self.symm_mem_comm_multimem.disabled:
+                logger.info(
+                    "Rank %s: SymmMemCommunicator (multimem) initialized", self.rank
+                )
+            else:
+                self.symm_mem_comm_multimem = None
+        except Exception as e:
+            logger.warning(
+                "Rank %s: Failed to initialize SymmMemCommunicator (multimem): %s",
+                self.rank,
+                e,
+            )
+            self.symm_mem_comm_multimem = None
+
+        try:
+            self.symm_mem_comm_two_shot = SymmMemCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+                force_multimem=False,
+                max_size_override=self.max_size_override,
+            )
+            if not self.symm_mem_comm_two_shot.disabled:
+                logger.info(
+                    "Rank %s: SymmMemCommunicator (two_shot) initialized", self.rank
+                )
+            else:
+                self.symm_mem_comm_two_shot = None
+        except Exception as e:
+            logger.warning(
+                "Rank %s: Failed to initialize SymmMemCommunicator (two_shot): %s",
+                self.rank,
+                e,
+            )
+            self.symm_mem_comm_two_shot = None
+
+    def benchmark_allreduce(
+        self, sequence_length: int, num_warmup: int, num_trials: int
+    ) -> dict[str, float]:
+        """Benchmark allreduce operations for all available communicators."""
+
+        results = {}
+
+        # Define communicators with their benchmark functions
+        communicators = []
+
+        if self.custom_allreduce is not None:
+            comm = self.custom_allreduce
+            # CustomAllreduce one-shot
+            communicators.append(
+                (
+                    "ca_1stage",
+                    lambda t, c=comm: c.custom_all_reduce(t),
+                    lambda t, c=comm: c.should_custom_ar(t),
+                    comm.capture(),
+                    "1stage",  # env variable value
+                )
+            )
+            # CustomAllreduce two-shot
+            communicators.append(
+                (
+                    "ca_2stage",
+                    lambda t, c=comm: c.custom_all_reduce(t),
+                    lambda t, c=comm: c.should_custom_ar(t),
+                    comm.capture(),
+                    "2stage",  # env variable value
+                )
+            )
+
+        if self.pynccl_comm is not None:
+            comm = self.pynccl_comm
+            communicators.append(
+                (
+                    "pynccl",
+                    lambda t, c=comm: c.all_reduce(t),
+                    lambda t: True,  # Always available if initialized
+                    nullcontext(),
+                    None,  # no env variable needed
+                )
+            )
+
+        if self.symm_mem_comm_multimem is not None:
+            comm = self.symm_mem_comm_multimem
+            communicators.append(
+                (
+                    "symm_mem_multimem",
+                    lambda t, c=comm: c.all_reduce(t),
+                    lambda t, c=comm: c.should_use_symm_mem(t),
+                    nullcontext(),
+                    None,  # no env variable needed
+                )
+            )
+
+        if self.symm_mem_comm_two_shot is not None:
+            comm = self.symm_mem_comm_two_shot
+            communicators.append(
+                (
+                    "symm_mem_two_shot",
+                    lambda t, c=comm: c.all_reduce(t),
+                    lambda t, c=comm: c.should_use_symm_mem(t),
+                    nullcontext(),
+                    None,  # no env variable needed
+                )
+            )
+
+        # Benchmark each communicator
+        for name, allreduce_fn, should_use_fn, context, env_var in communicators:
+            # Set environment variable if needed
+            if env_var is not None:
+                os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var
+            else:
+                # Clear the environment variable to avoid interference
+                os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None)
+
+            latency = self.benchmark_allreduce_single(
+                sequence_length,
+                allreduce_fn,
+                should_use_fn,
+                context,
+                num_warmup,
+                num_trials,
+            )
+            if latency is not None:
+                results[name] = latency
+
+        return results
+
+    def benchmark_allreduce_single(
+        self,
+        sequence_length: int,
+        allreduce_fn: Callable[[torch.Tensor], Optional[torch.Tensor]],
+        should_use_fn: Callable[[torch.Tensor], bool],
+        context,
+        num_warmup: int,
+        num_trials: int,
+    ) -> Optional[float]:
+        """Benchmark method with CUDA graph optimization."""
+        try:
+            # Create test tensor (2D: sequence_length x hidden_size)
+            tensor = torch.randn(
+                sequence_length, HIDDEN_SIZE, dtype=BENCHMARK_DTYPE, device=self.device
+            )
+            if not should_use_fn(tensor):
+                return None
+
+            torch.cuda.synchronize()
+            stream = torch.cuda.Stream()
+            with torch.cuda.stream(stream):
+                graph_input = tensor.clone()
+
+                # Warmup before capture
+                for _ in range(3):
+                    allreduce_fn(graph_input)
+
+                # Capture the graph using context manager
+                with context:
+                    graph = torch.cuda.CUDAGraph()
+                    with torch.cuda.graph(graph):
+                        for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
+                            allreduce_fn(graph_input)
+
+            torch.cuda.synchronize()
+            for _ in range(num_warmup):
+                graph.replay()
+            torch.cuda.synchronize()
+
+            torch.cuda.synchronize()
+            start_time = time.perf_counter()
+
+            for _ in range(num_trials):
+                graph.replay()
+            torch.cuda.synchronize()
+
+            end_time = time.perf_counter()
+
+            # Convert to ms and divide by CUDA_GRAPH_CAPTURE_CYCLES
+            return (
+                (end_time - start_time) / num_trials / CUDA_GRAPH_CAPTURE_CYCLES * 1000
+            )
+
+        except Exception as e:
+            logger.error("CUDA graph benchmark failed: %s", e)
+            raise RuntimeError(
+                f"CUDA graph benchmark failed for communicator: {e}"
+            ) from e
+
+
+def _calculate_speedup_info(comm_results: dict[str, float]) -> str:
+    """Calculate speedup information for a single tensor size."""
+    if not comm_results:
+        return "N/A"
+
+    # Find the fastest communicator
+    fastest_comm = min(comm_results.keys(), key=lambda k: comm_results[k])
+    fastest_time = comm_results[fastest_comm]
+
+    # Calculate speedup vs PyNccl if available
+    if "pynccl" in comm_results:
+        pynccl_time = comm_results["pynccl"]
+        speedup = pynccl_time / fastest_time
+        return f"{fastest_comm} ({speedup:.2f}x)"
+    else:
+        return f"{fastest_comm} (N/A)"
+
+
+def print_results(
+    results: dict[str, dict[str, float]], sequence_lengths: list[int], world_size: int
+):
+    """Print benchmark results in a formatted table."""
+
+    print(f"\n{'=' * 130}")
+    print("Device Communicator Benchmark Results")
+    print(
+        f"World Size: {world_size}, Data Type: {BENCHMARK_DTYPE}, "
+        f"Hidden Size: {HIDDEN_SIZE}"
+    )
+    print(f"{'=' * 130}")
+
+    # Get all communicator names
+    all_comms = set()
+    for size_results in results.values():
+        all_comms.update(size_results.keys())
+
+    all_comms = sorted(list(all_comms))
+
+    # Print header
+    header = f"{'Tensor Shape':<20}{'Tensor Size':<15}"
+    for comm in all_comms:
+        header += f"{comm:<20}"
+    header += f"{'Best (Speedup vs PyNccl)':<30}"
+    print(header)
+    print("-" * len(header))
+
+    # Print results for each sequence length
+    for seq_len in sequence_lengths:
+        if seq_len in results:
+            # Calculate tensor size in elements and bytes
+            tensor_elements = seq_len * HIDDEN_SIZE
+            tensor_bytes = tensor_elements * BENCHMARK_DTYPE.itemsize
+
+            # Format tensor size (MB)
+            tensor_size_mb = tensor_bytes / (1024 * 1024)
+            tensor_size_str = f"{tensor_size_mb:.2f} MB"
+
+            # Format tensor shape
+            tensor_shape = f"({seq_len}, {HIDDEN_SIZE})"
+
+            row = f"{tensor_shape:<20}{tensor_size_str:<15}"
+            for comm in all_comms:
+                if comm in results[seq_len]:
+                    row += f"{results[seq_len][comm]:<20.3f}"
+                else:
+                    row += f"{'N/A':<20}"
+
+            # Calculate speedup information
+            speedup_info = _calculate_speedup_info(results[seq_len])
+            row += f"{speedup_info:<30}"
+
+            print(row)
+
+    print(f"{'=' * 130}")
+    print("All times are in milliseconds (ms) per allreduce operation")
+    print("Speedup column shows: fastest_algorithm (speedup_vs_pynccl)")
+
+
+def main():
+    parser = FlexibleArgumentParser(description="Benchmark device communicators")
+
+    parser.add_argument(
+        "--sequence-lengths",
+        type=int,
+        nargs="+",
+        default=DEFAULT_SEQUENCE_LENGTHS,
+        help="Sequence lengths to benchmark (tensor shape: seq_len x hidden_size)",
+    )
+
+    parser.add_argument(
+        "--num-warmup", type=int, default=5, help="Number of warmup iterations"
+    )
+
+    parser.add_argument(
+        "--num-trials", type=int, default=50, help="Number of benchmark trials"
+    )
+
+    parser.add_argument("--output-json", type=str, help="Output results to JSON file")
+
+    args = parser.parse_args()
+
+    # Initialize distributed
+    if not dist.is_initialized():
+        dist.init_process_group(backend="gloo")
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+
+    # Set device
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+
+    # Get CPU process group
+    cpu_group = dist.new_group(backend="gloo")
+
+    # Disable USE_SYMM_MEM to avoid affecting the max_sizes
+    # in symm_mem and custom_all_reduce for benchmark
+    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
+
+    # Initialize benchmark
+    benchmark = CommunicatorBenchmark(
+        rank, world_size, device, cpu_group, args.sequence_lengths
+    )
+
+    # Run benchmarks
+    all_results = {}
+
+    for seq_len in args.sequence_lengths:
+        if rank == 0:
+            logger.info(
+                "Benchmarking sequence length: %s (tensor shape: %s x %s)",
+                seq_len,
+                seq_len,
+                HIDDEN_SIZE,
+            )
+
+        results = benchmark.benchmark_allreduce(
+            sequence_length=seq_len,
+            num_warmup=args.num_warmup,
+            num_trials=args.num_trials,
+        )
+
+        all_results[seq_len] = results
+
+        # Synchronize between ranks
+        dist.barrier()
+
+    # Print results (only rank 0)
+    if rank == 0:
+        print_results(all_results, args.sequence_lengths, world_size)
+
+        # Save to JSON if requested
+        if args.output_json:
+            # Add speedup information to results
+            enhanced_results = {}
+            for seq_len, comm_results in all_results.items():
+                enhanced_results[seq_len] = {
+                    "timings": comm_results,
+                    "speedup_info": _calculate_speedup_info(comm_results),
+                }
+
+            output_data = {
+                "world_size": world_size,
+                "dtype": str(BENCHMARK_DTYPE),
+                "hidden_size": HIDDEN_SIZE,
+                "sequence_lengths": args.sequence_lengths,
+                "num_warmup": args.num_warmup,
+                "num_trials": args.num_trials,
+                "cuda_graph_capture_cycles": CUDA_GRAPH_CAPTURE_CYCLES,
+                "results": enhanced_results,
+            }
+
+            with open(args.output_json, "w") as f:
+                json.dump(output_data, f, indent=2)
+
+            logger.info("Results saved to %s", args.output_json)
+
+    # Cleanup
+    if cpu_group != dist.group.WORLD:
+        dist.destroy_process_group(cpu_group)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh
index 44709b4597765..58926f6429dd3 100644
--- a/csrc/custom_all_reduce.cuh
+++ b/csrc/custom_all_reduce.cuh
@@ -15,6 +15,8 @@ typedef __hip_bfloat16 nv_bfloat16;
 #include
 #include
 #include
+#include <cstdlib>
+#include <cstring>
 
 namespace vllm {
 #define CUDACHECK(cmd) \
@@ -555,22 +557,47 @@ class CustomAllreduce {
     size /= d;
     auto bytes = size * sizeof(typename packed_t<T>::P);
     int blocks = std::min(block_limit, (size + threads - 1) / threads);
+
+    // Check environment variable once
+    const char* env_algo = std::getenv("VLLM_CUSTOM_ALLREDUCE_ALGO");
+    bool force_1stage = false;
+    bool force_2stage = false;
+    if (env_algo != nullptr) {
+      if (std::strcmp(env_algo, "1stage") == 0 ||
+          std::strcmp(env_algo, "oneshot") == 0) {
+        force_1stage = true;
+      } else if (std::strcmp(env_algo, "2stage") == 0 ||
+                 std::strcmp(env_algo, "twoshot") == 0) {
+        force_2stage = true;
+      } else {
+        throw std::runtime_error(
+            "Invalid VLLM_CUSTOM_ALLREDUCE_ALGO: " + std::string(env_algo) +
+            ". Valid values: 1stage, oneshot, 2stage, twoshot");
+      }
+    }
+
 #define KL(ngpus, name) \
   name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                  rank_, size);
-#define REDUCE_CASE(ngpus) \
-  case ngpus: { \
-    if (world_size_ == 2) { \
-      KL(ngpus, cross_device_reduce_1stage); \
-    } else if (fully_connected_) { \
-      if ((world_size_ <= 4 && bytes < 512 * 1024) || \
-          (world_size_ <= 8 && bytes < 256 * 1024)) { \
-        KL(ngpus, cross_device_reduce_1stage); \
-      } else { \
-        KL(ngpus, cross_device_reduce_2stage); \
-      } \
-    } \
-    break; \
+#define REDUCE_CASE(ngpus) \
+  case ngpus: { \
+    if (force_1stage) { \
+      KL(ngpus, cross_device_reduce_1stage); \
+    } else if (force_2stage) { \
+      KL(ngpus, cross_device_reduce_2stage); \
+    } else { \
+      if (world_size_ == 2) { \
+        KL(ngpus, cross_device_reduce_1stage); \
+      } else if (fully_connected_) { \
+        if ((world_size_ <= 4 && bytes < 512 * 1024) || \
+            (world_size_ <= 8 && bytes < 256 * 1024)) { \
+          KL(ngpus, cross_device_reduce_1stage); \
+        } else { \
+          KL(ngpus, cross_device_reduce_2stage); \
+        } \
+      } \
+    } \
+    break; \
   }
 
     switch (world_size_) {
diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py
index 5c64e7d5c4ba3..805a88854b77c 100644
--- a/vllm/distributed/device_communicators/all_reduce_utils.py
+++ b/vllm/distributed/device_communicators/all_reduce_utils.py
@@ -36,8 +36,8 @@ CUSTOM_ALL_REDUCE_MAX_SIZES = {
     "10.0": {
         2: 2 * MiB,  # 2 MB
         4: 2 * MiB,  # 2 MB
-        6: 2 * MiB,  # 2 MB
-        8: 2 * MiB,  # 2 MB
+        6: 1 * MiB,  # 1 MB
+        8: 1 * MiB,  # 1 MB
     }
 }
 
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index eef3f9f75f9f1..78c90b006ffc8 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -57,11 +57,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
         self.ca_comm: Optional[CustomAllreduce] = None
         self.qr_comm: Optional[QuickAllReduce] = None
         self.symm_mem_comm: Optional[SymmMemCommunicator] = None
+        if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
+            self.symm_mem_comm = SymmMemCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+            )
+
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             self.ca_comm = CustomAllreduce(
                 group=self.cpu_group,
                 device=self.device,
+                symm_mem_enabled=(self.symm_mem_comm is not None
+                                  and not self.symm_mem_comm.disabled),
             )
 
         if current_platform.is_rocm():
@@ -72,11 +80,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
             # currently be an MI300 series.
             self.qr_comm = QuickAllReduce(group=self.cpu_group,
                                           device=self.device)
 
-        if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
-            self.symm_mem_comm = SymmMemCommunicator(
-                group=self.cpu_group,
-                device=self.device,
-            )
         if self.use_all2all:
             all2all_backend = envs.VLLM_ALL2ALL_BACKEND
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index c8cc35f99785c..3cc4bbb258244 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -54,7 +54,8 @@ class CustomAllreduce:
     def __init__(self,
                  group: ProcessGroup,
                  device: Union[int, str, torch.device],
-                 max_size=8192 * 1024) -> None:
+                 max_size=8192 * 1024,
+                 symm_mem_enabled=False) -> None:
         """
         Args:
             group: the process group to work on. If None, it will use the
@@ -111,7 +112,7 @@ class CustomAllreduce:
         self.device = device
         device_capability = current_platform.get_device_capability(
         ).as_version_str()
-        if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
+        if (current_platform.is_cuda() and symm_mem_enabled
                 and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
             max_size = min(
                 CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py
index d907e1b833d04..09012d16978d9 100644
--- a/vllm/distributed/device_communicators/symm_mem.py
+++ b/vllm/distributed/device_communicators/symm_mem.py
@@ -27,8 +27,13 @@ class SymmMemCommunicator:
         "10.0": [6, 8],
     }
 
-    def __init__(self, group: ProcessGroup, device: Union[int, str,
-                                                          torch.device]):
+    def __init__(
+            self,
+            group: ProcessGroup,
+            device: Union[int, str, torch.device],
+            # add options for testing
+            force_multimem: Optional[bool] = None,
+            max_size_override: Optional[int] = None):
         self.disabled = True
 
         if not symm_mem_available:
@@ -64,8 +69,17 @@
                 self.world_size,
             )
             return
-        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
-            self.world_size]
+        # Use override max_size if provided, otherwise use default
+        if max_size_override is not None:
+            self.max_size = max_size_override
+            logger.info(
+                "SymmMemCommunicator: Using override max_size: %s bytes",
+                self.max_size,
+            )
+        else:
+            self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[
+                self.device_capability][self.world_size]
+
         self.buffer = torch_symm_mem.empty(
             self.max_size // self.dtype.itemsize,
             device=self.device,
@@ -76,6 +90,7 @@
             logger.warning("SymmMemCommunicator: symmetric memory "
                            "multicast operations are not supported.")
             return
+        self.force_multimem = force_multimem
         self.disabled = False
 
     def should_use_symm_mem(self, inp: torch.Tensor):
@@ -98,8 +113,18 @@
         if out is None:
             out = torch.empty_like(inp)
         self.buffer[:inp.numel()].copy_(inp.view(-1))
-        if self.world_size in self._WORLD_SIZES_MULTIMEM[
-                self.device_capability]:
+
+        # Determine which algorithm to use
+        use_multimem = False
+        if self.force_multimem is not None:
+            # Test override: use forced setting
+            use_multimem = self.force_multimem
+        else:
+            # Normal logic: use multimem for supported world sizes
+            use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM[
+                self.device_capability]
+
+        if use_multimem:
             torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
                                                     "sum",
                                                     self.group.group_name)
diff --git a/vllm/envs.py b/vllm/envs.py
index 184d339089a21..d7956a2adff68 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -166,7 +166,7 @@ if TYPE_CHECKING:
     VLLM_HAS_FLASHINFER_CUBIN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
    VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
-    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
+    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
     VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
     VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
@@ -1203,7 +1203,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
 
     # Whether to use pytorch symmetric memory for allreduce
     "VLLM_ALLREDUCE_USE_SYMM_MEM":
-    lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
+    lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))),
 
     # Allows vllm to find tuned config under customized folder
     "VLLM_TUNED_CONFIG_FOLDER":
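
Note (not part of the patch): a minimal sketch of how the two knobs introduced above are meant to be driven from a benchmark or test process, assuming vLLM's distributed setup is already initialized as in the script in this patch. The environment variable names and accepted values come from the diff; everything else is illustrative.

    import os

    # Select the custom all-reduce kernel explicitly; accepted values per the
    # patch are "1stage"/"oneshot" and "2stage"/"twoshot". Leaving the variable
    # unset keeps the size-based heuristic in csrc/custom_all_reduce.cuh.
    os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = "1stage"

    # Symmetric-memory all-reduce now defaults to on; the benchmark's main()
    # sets it to "0" so the symm-mem max_size caps do not affect the
    # CustomAllreduce configuration being measured.
    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"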