mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 13:37:53 +08:00
[Kernels] Enable Torch Symmetric Memory All-Reduce By Default (#24111)
Signed-off-by: ilmarkov <markovilya197@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent
bcbe2a4d9e
commit
1fdd5c42d7
486
benchmarks/kernels/benchmark_device_communicators.py
Normal file
486
benchmarks/kernels/benchmark_device_communicators.py
Normal file
@ -0,0 +1,486 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""
|
||||||
|
Benchmark script for device communicators:
|
||||||
|
CustomAllreduce (oneshot, twoshot), PyNcclCommunicator,
|
||||||
|
and SymmMemCommunicator (multimem, two-shot).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
torchrun --nproc_per_node=<N> benchmark_device_communicators.py [options]
|
||||||
|
|
||||||
|
Example:
|
||||||
|
torchrun --nproc_per_node=2 benchmark_device_communicators.py
|
||||||
|
--sequence-lengths 512 1024 2048 --num-warmup 10 --num-trials 100
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
|
||||||
|
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||||
|
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
logger = init_logger(__name__)

# Default values for --sequence-lengths; each one is the first dimension of
# the (sequence_length, HIDDEN_SIZE) tensor that gets all-reduced.
DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192]

# Second tensor dimension and element type, fixed for every benchmark run.
HIDDEN_SIZE = 8192
BENCHMARK_DTYPE = torch.bfloat16

# Number of all-reduce calls recorded into each CUDA graph; the measured
# replay time is divided by this value to get a per-call latency.
CUDA_GRAPH_CAPTURE_CYCLES = 10
|
||||||
|
|
||||||
|
|
||||||
|
class CommunicatorBenchmark:
    """Benchmark harness for the device all-reduce communicators.

    On construction it tries to create a CustomAllreduce, a
    PyNcclCommunicator and two SymmMemCommunicator variants (forced multimem
    and forced non-multimem). A communicator that fails to initialize or
    reports itself as disabled is stored as ``None`` and skipped later.
    """

    def __init__(
        self,
        rank: int,
        world_size: int,
        device: torch.device,
        cpu_group: ProcessGroup,
        sequence_lengths: list[int],
    ):
        """Store the run configuration and create the communicators.

        Args:
            rank: Rank of this process; only used in log messages here.
            world_size: Number of participating processes.
            device: Device the benchmark tensors are allocated on.
            cpu_group: Process group handed to every communicator.
            sequence_lengths: Sequence lengths that will be benchmarked; the
                largest one determines the communicator buffer size.
        """
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.cpu_group = cpu_group

        # Buffer size (bytes) requested from the communicators: the largest
        # benchmark tensor plus one byte.
        # NOTE(review): the extra byte presumably keeps a tensor of exactly
        # the largest size under a strict "< max_size" check - confirm.
        largest_tensor_elements = max(sequence_lengths) * HIDDEN_SIZE
        self.max_size_override = (
            largest_tensor_elements * BENCHMARK_DTYPE.itemsize + 1
        )

        # Communicators stay None unless initialization succeeds.
        self.custom_allreduce = None
        self.pynccl_comm = None
        self.symm_mem_comm = None
        self.symm_mem_comm_multimem = None
        self.symm_mem_comm_two_shot = None

        self._init_communicators()

    def _init_one(self, label: str, factory: Callable):
        """Build one communicator, returning ``None`` if it is unusable.

        Args:
            label: Human-readable communicator name used in log messages.
            factory: Zero-argument callable that constructs the communicator.

        Returns:
            The communicator, or ``None`` when construction raised or the
            communicator reports ``disabled``.
        """
        try:
            comm = factory()
            disabled = comm.disabled
        except Exception as e:
            # Best effort: an unavailable backend must not abort the run.
            logger.warning(
                "Rank %s: Failed to initialize %s: %s", self.rank, label, e
            )
            return None

        if disabled:
            logger.info("Rank %s: %s disabled", self.rank, label)
            return None

        logger.info("Rank %s: %s initialized", self.rank, label)
        return comm

    def _init_communicators(self):
        """Initialize all available communicators."""
        self.custom_allreduce = self._init_one(
            "CustomAllreduce",
            lambda: CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
                max_size=self.max_size_override,
            ),
        )

        self.pynccl_comm = self._init_one(
            "PyNcclCommunicator",
            lambda: PyNcclCommunicator(group=self.cpu_group, device=self.device),
        )

        # Two SymmMemCommunicator variants, differing only in force_multimem.
        self.symm_mem_comm_multimem = self._init_one(
            "SymmMemCommunicator (multimem)",
            lambda: SymmMemCommunicator(
                group=self.cpu_group,
                device=self.device,
                force_multimem=True,
                max_size_override=self.max_size_override,
            ),
        )

        self.symm_mem_comm_two_shot = self._init_one(
            "SymmMemCommunicator (two_shot)",
            lambda: SymmMemCommunicator(
                group=self.cpu_group,
                device=self.device,
                force_multimem=False,
                max_size_override=self.max_size_override,
            ),
        )

    def benchmark_allreduce(
        self, sequence_length: int, num_warmup: int, num_trials: int
    ) -> dict[str, float]:
        """Benchmark allreduce operations for all available communicators.

        Args:
            sequence_length: First dimension of the benchmark tensor.
            num_warmup: Number of untimed graph replays before measuring.
            num_trials: Number of timed graph replays.

        Returns:
            Mapping from communicator name to latency in milliseconds per
            all-reduce. Communicators that decline the tensor are omitted.

        Raises:
            RuntimeError: Propagated from ``benchmark_allreduce_single``.
        """
        algo_env_key = "VLLM_CUSTOM_ALLREDUCE_ALGO"

        # Each entry: (name, allreduce_fn, should_use_fn, capture context,
        # value for VLLM_CUSTOM_ALLREDUCE_ALGO or None).
        communicators = []

        if self.custom_allreduce is not None:
            comm = self.custom_allreduce
            # Same communicator benchmarked twice, once per forced algorithm.
            for name, algo in (("ca_1stage", "1stage"), ("ca_2stage", "2stage")):
                communicators.append(
                    (
                        name,
                        lambda t, c=comm: c.custom_all_reduce(t),
                        lambda t, c=comm: c.should_custom_ar(t),
                        comm.capture(),
                        algo,
                    )
                )

        if self.pynccl_comm is not None:
            comm = self.pynccl_comm
            communicators.append(
                (
                    "pynccl",
                    lambda t, c=comm: c.all_reduce(t),
                    lambda t: True,  # Always available if initialized
                    nullcontext(),
                    None,
                )
            )

        for name, comm in (
            ("symm_mem_multimem", self.symm_mem_comm_multimem),
            ("symm_mem_two_shot", self.symm_mem_comm_two_shot),
        ):
            if comm is None:
                continue
            communicators.append(
                (
                    name,
                    lambda t, c=comm: c.all_reduce(t),
                    lambda t, c=comm: c.should_use_symm_mem(t),
                    nullcontext(),
                    None,
                )
            )

        results = {}
        # Remember the caller's setting so the benchmark does not leak its
        # per-communicator override into the rest of the process.
        saved_algo = os.environ.get(algo_env_key)
        try:
            for name, allreduce_fn, should_use_fn, context, env_var in communicators:
                if env_var is not None:
                    os.environ[algo_env_key] = env_var
                else:
                    # Clear the environment variable to avoid interference
                    os.environ.pop(algo_env_key, None)

                latency = self.benchmark_allreduce_single(
                    sequence_length,
                    allreduce_fn,
                    should_use_fn,
                    context,
                    num_warmup,
                    num_trials,
                )
                if latency is not None:
                    results[name] = latency
        finally:
            if saved_algo is None:
                os.environ.pop(algo_env_key, None)
            else:
                os.environ[algo_env_key] = saved_algo

        return results

    def benchmark_allreduce_single(
        self,
        sequence_length: int,
        allreduce_fn: Callable[[torch.Tensor], Optional[torch.Tensor]],
        should_use_fn: Callable[[torch.Tensor], bool],
        context,
        num_warmup: int,
        num_trials: int,
    ) -> Optional[float]:
        """Time one all-reduce implementation using CUDA graph replay.

        Args:
            sequence_length: First dimension of the benchmark tensor.
            allreduce_fn: Performs the all-reduce on a tensor.
            should_use_fn: Returns whether the communicator accepts a tensor.
            context: Context manager entered around graph capture.
            num_warmup: Number of untimed graph replays.
            num_trials: Number of timed graph replays.

        Returns:
            Latency in milliseconds per all-reduce call, or ``None`` when
            ``should_use_fn`` rejects the tensor.

        Raises:
            RuntimeError: If any step of the benchmark raises.
        """
        try:
            # Create test tensor (2D: sequence_length x hidden_size)
            tensor = torch.randn(
                sequence_length, HIDDEN_SIZE, dtype=BENCHMARK_DTYPE, device=self.device
            )
            if not should_use_fn(tensor):
                return None

            torch.cuda.synchronize()
            stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                graph_input = tensor.clone()

                # Warmup before capture
                for _ in range(3):
                    allreduce_fn(graph_input)

                # Record CUDA_GRAPH_CAPTURE_CYCLES all-reduce calls into one
                # graph, inside the communicator-specific capture context.
                with context:
                    graph = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(graph):
                        for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
                            allreduce_fn(graph_input)

                torch.cuda.synchronize()
                for _ in range(num_warmup):
                    graph.replay()
                torch.cuda.synchronize()

                # Timed section: wall-clock time around replay + synchronize.
                torch.cuda.synchronize()
                start_time = time.perf_counter()

                for _ in range(num_trials):
                    graph.replay()
                torch.cuda.synchronize()

                end_time = time.perf_counter()

                # Convert to ms and divide by CUDA_GRAPH_CAPTURE_CYCLES
                return (
                    (end_time - start_time) / num_trials / CUDA_GRAPH_CAPTURE_CYCLES * 1000
                )

        except Exception as e:
            logger.error("CUDA graph benchmark failed: %s", e)
            raise RuntimeError(
                f"CUDA graph benchmark failed for communicator: {e}"
            ) from e
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_speedup_info(comm_results: dict[str, float]) -> str:
|
||||||
|
"""Calculate speedup information for a single tensor size."""
|
||||||
|
if not comm_results:
|
||||||
|
return "N/A"
|
||||||
|
|
||||||
|
# Find the fastest communicator
|
||||||
|
fastest_comm = min(comm_results.keys(), key=lambda k: comm_results[k])
|
||||||
|
fastest_time = comm_results[fastest_comm]
|
||||||
|
|
||||||
|
# Calculate speedup vs PyNccl if available
|
||||||
|
if "pynccl" in comm_results:
|
||||||
|
pynccl_time = comm_results["pynccl"]
|
||||||
|
speedup = pynccl_time / fastest_time
|
||||||
|
return f"{fastest_comm} ({speedup:.2f}x)"
|
||||||
|
else:
|
||||||
|
return f"{fastest_comm} (N/A)"
|
||||||
|
|
||||||
|
|
||||||
|
def print_results(
    results: dict[int, dict[str, float]], sequence_lengths: list[int], world_size: int
):
    """Print benchmark results in a formatted table.

    Args:
        results: Per-sequence-length mapping from communicator name to
            latency in milliseconds; keyed by the integer sequence length.
        sequence_lengths: Sequence lengths to print, in row order. Lengths
            without an entry in ``results`` are skipped.
        world_size: Number of participating processes, shown in the banner.
    """
    banner = "=" * 130

    print(f"\n{banner}")
    print("Device Communicator Benchmark Results")
    print(
        f"World Size: {world_size}, Data Type: {BENCHMARK_DTYPE}, "
        f"Hidden Size: {HIDDEN_SIZE}"
    )
    print(banner)

    # One column per communicator seen for any sequence length, sorted by name.
    comm_names = sorted(set().union(*results.values()))

    # Print header
    header = f"{'Tensor Shape':<20}{'Tensor Size':<15}"
    header += "".join(f"{comm:<20}" for comm in comm_names)
    header += f"{'Best (Speedup vs PyNccl)':<30}"
    print(header)
    print("-" * len(header))

    # Print results for each sequence length
    for seq_len in sequence_lengths:
        if seq_len not in results:
            continue
        timings = results[seq_len]

        # Tensor size in MB, from element count and dtype width.
        tensor_bytes = seq_len * HIDDEN_SIZE * BENCHMARK_DTYPE.itemsize
        tensor_size_str = f"{tensor_bytes / (1024 * 1024):.2f} MB"
        tensor_shape = f"({seq_len}, {HIDDEN_SIZE})"

        row = f"{tensor_shape:<20}{tensor_size_str:<15}"
        for comm in comm_names:
            if comm in timings:
                row += f"{timings[comm]:<20.3f}"
            else:
                # Communicator produced no result for this size.
                row += f"{'N/A':<20}"

        # Calculate speedup information
        speedup_info = _calculate_speedup_info(timings)
        row += f"{speedup_info:<30}"

        print(row)

    print(banner)
    print("All times are in milliseconds (ms) per allreduce operation")
    print("Speedup column shows: fastest_algorithm (speedup_vs_pynccl)")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse arguments, run the benchmark on every rank and report on rank 0.

    Intended to be launched under ``torchrun``. Every rank runs the
    benchmark; only rank 0 prints the table and, when ``--output-json`` is
    given, writes the JSON report.
    """
    parser = FlexibleArgumentParser(description="Benchmark device communicators")

    parser.add_argument(
        "--sequence-lengths",
        type=int,
        nargs="+",
        default=DEFAULT_SEQUENCE_LENGTHS,
        help="Sequence lengths to benchmark (tensor shape: seq_len x hidden_size)",
    )

    parser.add_argument(
        "--num-warmup", type=int, default=5, help="Number of warmup iterations"
    )

    parser.add_argument(
        "--num-trials", type=int, default=50, help="Number of benchmark trials"
    )

    parser.add_argument("--output-json", type=str, help="Output results to JSON file")

    args = parser.parse_args()

    # Initialize distributed
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Set device. The device index must be the rank within this node:
    # torchrun exports it as LOCAL_RANK, and the global rank exceeds the
    # local device count on multi-node launches. Fall back to the global
    # rank when LOCAL_RANK is not set.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # Get CPU process group
    cpu_group = dist.new_group(backend="gloo")

    # Disable USE_SYMM_MEM to avoid affecting the max_sizes
    # in symm_mem and custom_all_reduce for benchmark
    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"

    # Initialize benchmark
    benchmark = CommunicatorBenchmark(
        rank, world_size, device, cpu_group, args.sequence_lengths
    )

    # Run benchmarks
    all_results = {}

    for seq_len in args.sequence_lengths:
        if rank == 0:
            logger.info(
                "Benchmarking sequence length: %s (tensor shape: %s x %s)",
                seq_len,
                seq_len,
                HIDDEN_SIZE,
            )

        all_results[seq_len] = benchmark.benchmark_allreduce(
            sequence_length=seq_len,
            num_warmup=args.num_warmup,
            num_trials=args.num_trials,
        )

        # Synchronize between ranks
        dist.barrier()

    # Print results (only rank 0)
    if rank == 0:
        print_results(all_results, args.sequence_lengths, world_size)

        # Save to JSON if requested
        if args.output_json:
            # Add speedup information to results
            enhanced_results = {
                seq_len: {
                    "timings": comm_results,
                    "speedup_info": _calculate_speedup_info(comm_results),
                }
                for seq_len, comm_results in all_results.items()
            }

            output_data = {
                "world_size": world_size,
                "dtype": str(BENCHMARK_DTYPE),
                "hidden_size": HIDDEN_SIZE,
                "sequence_lengths": args.sequence_lengths,
                "num_warmup": args.num_warmup,
                "num_trials": args.num_trials,
                "cuda_graph_capture_cycles": CUDA_GRAPH_CAPTURE_CYCLES,
                "results": enhanced_results,
            }

            with open(args.output_json, "w") as f:
                json.dump(output_data, f, indent=2)

            logger.info("Results saved to %s", args.output_json)

    # Cleanup
    if cpu_group != dist.group.WORLD:
        dist.destroy_process_group(cpu_group)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the benchmark only when executed as a script (e.g. via torchrun),
# not when the module is imported.
if __name__ == "__main__":
    main()
|
||||||
@ -15,6 +15,8 @@ typedef __hip_bfloat16 nv_bfloat16;
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
#define CUDACHECK(cmd) \
|
#define CUDACHECK(cmd) \
|
||||||
@ -555,22 +557,47 @@ class CustomAllreduce {
|
|||||||
size /= d;
|
size /= d;
|
||||||
auto bytes = size * sizeof(typename packed_t<T>::P);
|
auto bytes = size * sizeof(typename packed_t<T>::P);
|
||||||
int blocks = std::min(block_limit, (size + threads - 1) / threads);
|
int blocks = std::min(block_limit, (size + threads - 1) / threads);
|
||||||
|
|
||||||
|
// Check environment variable once
|
||||||
|
const char* env_algo = std::getenv("VLLM_CUSTOM_ALLREDUCE_ALGO");
|
||||||
|
bool force_1stage = false;
|
||||||
|
bool force_2stage = false;
|
||||||
|
if (env_algo != nullptr) {
|
||||||
|
if (std::strcmp(env_algo, "1stage") == 0 ||
|
||||||
|
std::strcmp(env_algo, "oneshot") == 0) {
|
||||||
|
force_1stage = true;
|
||||||
|
} else if (std::strcmp(env_algo, "2stage") == 0 ||
|
||||||
|
std::strcmp(env_algo, "twoshot") == 0) {
|
||||||
|
force_2stage = true;
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Invalid VLLM_CUSTOM_ALLREDUCE_ALGO: " + std::string(env_algo) +
|
||||||
|
". Valid values: 1stage, oneshot, 2stage, twoshot");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#define KL(ngpus, name) \
|
#define KL(ngpus, name) \
|
||||||
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
|
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
|
||||||
rank_, size);
|
rank_, size);
|
||||||
#define REDUCE_CASE(ngpus) \
|
#define REDUCE_CASE(ngpus) \
|
||||||
case ngpus: { \
|
case ngpus: { \
|
||||||
if (world_size_ == 2) { \
|
if (force_1stage) { \
|
||||||
KL(ngpus, cross_device_reduce_1stage); \
|
KL(ngpus, cross_device_reduce_1stage); \
|
||||||
} else if (fully_connected_) { \
|
} else if (force_2stage) { \
|
||||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
|
KL(ngpus, cross_device_reduce_2stage); \
|
||||||
(world_size_ <= 8 && bytes < 256 * 1024)) { \
|
} else { \
|
||||||
KL(ngpus, cross_device_reduce_1stage); \
|
if (world_size_ == 2) { \
|
||||||
} else { \
|
KL(ngpus, cross_device_reduce_1stage); \
|
||||||
KL(ngpus, cross_device_reduce_2stage); \
|
} else if (fully_connected_) { \
|
||||||
} \
|
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
|
||||||
} \
|
(world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||||
break; \
|
KL(ngpus, cross_device_reduce_1stage); \
|
||||||
|
} else { \
|
||||||
|
KL(ngpus, cross_device_reduce_2stage); \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
break; \
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (world_size_) {
|
switch (world_size_) {
|
||||||
|
|||||||
@ -36,8 +36,8 @@ CUSTOM_ALL_REDUCE_MAX_SIZES = {
|
|||||||
"10.0": {
|
"10.0": {
|
||||||
2: 2 * MiB, # 2 MB
|
2: 2 * MiB, # 2 MB
|
||||||
4: 2 * MiB, # 2 MB
|
4: 2 * MiB, # 2 MB
|
||||||
6: 2 * MiB, # 2 MB
|
6: 1 * MiB, # 1 MB
|
||||||
8: 2 * MiB, # 2 MB
|
8: 1 * MiB, # 1 MB
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -57,11 +57,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
|||||||
self.ca_comm: Optional[CustomAllreduce] = None
|
self.ca_comm: Optional[CustomAllreduce] = None
|
||||||
self.qr_comm: Optional[QuickAllReduce] = None
|
self.qr_comm: Optional[QuickAllReduce] = None
|
||||||
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
|
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
|
||||||
|
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
|
||||||
|
self.symm_mem_comm = SymmMemCommunicator(
|
||||||
|
group=self.cpu_group,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
if use_custom_allreduce and self.world_size > 1:
|
if use_custom_allreduce and self.world_size > 1:
|
||||||
# Initialize a custom fast all-reduce implementation.
|
# Initialize a custom fast all-reduce implementation.
|
||||||
self.ca_comm = CustomAllreduce(
|
self.ca_comm = CustomAllreduce(
|
||||||
group=self.cpu_group,
|
group=self.cpu_group,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
symm_mem_enabled=(self.symm_mem_comm is not None
|
||||||
|
and not self.symm_mem_comm.disabled),
|
||||||
)
|
)
|
||||||
|
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
@ -72,11 +80,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
|||||||
# currently be an MI300 series.
|
# currently be an MI300 series.
|
||||||
self.qr_comm = QuickAllReduce(group=self.cpu_group,
|
self.qr_comm = QuickAllReduce(group=self.cpu_group,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
|
|
||||||
self.symm_mem_comm = SymmMemCommunicator(
|
|
||||||
group=self.cpu_group,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.use_all2all:
|
if self.use_all2all:
|
||||||
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
|
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
|
||||||
|
|||||||
@ -54,7 +54,8 @@ class CustomAllreduce:
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
group: ProcessGroup,
|
group: ProcessGroup,
|
||||||
device: Union[int, str, torch.device],
|
device: Union[int, str, torch.device],
|
||||||
max_size=8192 * 1024) -> None:
|
max_size=8192 * 1024,
|
||||||
|
symm_mem_enabled=False) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
group: the process group to work on. If None, it will use the
|
group: the process group to work on. If None, it will use the
|
||||||
@ -111,7 +112,7 @@ class CustomAllreduce:
|
|||||||
self.device = device
|
self.device = device
|
||||||
device_capability = current_platform.get_device_capability(
|
device_capability = current_platform.get_device_capability(
|
||||||
).as_version_str()
|
).as_version_str()
|
||||||
if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
|
if (current_platform.is_cuda() and symm_mem_enabled
|
||||||
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
|
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
|
||||||
max_size = min(
|
max_size = min(
|
||||||
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
|
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
|
||||||
|
|||||||
@ -27,8 +27,13 @@ class SymmMemCommunicator:
|
|||||||
"10.0": [6, 8],
|
"10.0": [6, 8],
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, group: ProcessGroup, device: Union[int, str,
|
def __init__(
|
||||||
torch.device]):
|
self,
|
||||||
|
group: ProcessGroup,
|
||||||
|
device: Union[int, str, torch.device],
|
||||||
|
# add options for testing
|
||||||
|
force_multimem: Optional[bool] = None,
|
||||||
|
max_size_override: Optional[int] = None):
|
||||||
self.disabled = True
|
self.disabled = True
|
||||||
|
|
||||||
if not symm_mem_available:
|
if not symm_mem_available:
|
||||||
@ -64,8 +69,17 @@ class SymmMemCommunicator:
|
|||||||
self.world_size,
|
self.world_size,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
|
# Use override max_size if provided, otherwise use default
|
||||||
self.world_size]
|
if max_size_override is not None:
|
||||||
|
self.max_size = max_size_override
|
||||||
|
logger.info(
|
||||||
|
"SymmMemCommunicator: Using override max_size: %s bytes",
|
||||||
|
self.max_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[
|
||||||
|
self.device_capability][self.world_size]
|
||||||
|
|
||||||
self.buffer = torch_symm_mem.empty(
|
self.buffer = torch_symm_mem.empty(
|
||||||
self.max_size // self.dtype.itemsize,
|
self.max_size // self.dtype.itemsize,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
@ -76,6 +90,7 @@ class SymmMemCommunicator:
|
|||||||
logger.warning("SymmMemCommunicator: symmetric memory "
|
logger.warning("SymmMemCommunicator: symmetric memory "
|
||||||
"multicast operations are not supported.")
|
"multicast operations are not supported.")
|
||||||
return
|
return
|
||||||
|
self.force_multimem = force_multimem
|
||||||
self.disabled = False
|
self.disabled = False
|
||||||
|
|
||||||
def should_use_symm_mem(self, inp: torch.Tensor):
|
def should_use_symm_mem(self, inp: torch.Tensor):
|
||||||
@ -98,8 +113,18 @@ class SymmMemCommunicator:
|
|||||||
if out is None:
|
if out is None:
|
||||||
out = torch.empty_like(inp)
|
out = torch.empty_like(inp)
|
||||||
self.buffer[:inp.numel()].copy_(inp.view(-1))
|
self.buffer[:inp.numel()].copy_(inp.view(-1))
|
||||||
if self.world_size in self._WORLD_SIZES_MULTIMEM[
|
|
||||||
self.device_capability]:
|
# Determine which algorithm to use
|
||||||
|
use_multimem = False
|
||||||
|
if self.force_multimem is not None:
|
||||||
|
# Test override: use forced setting
|
||||||
|
use_multimem = self.force_multimem
|
||||||
|
else:
|
||||||
|
# Normal logic: use multimem for supported world sizes
|
||||||
|
use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM[
|
||||||
|
self.device_capability]
|
||||||
|
|
||||||
|
if use_multimem:
|
||||||
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
|
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
|
||||||
"sum",
|
"sum",
|
||||||
self.group.group_name)
|
self.group.group_name)
|
||||||
|
|||||||
@ -166,7 +166,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_HAS_FLASHINFER_CUBIN: bool = False
|
VLLM_HAS_FLASHINFER_CUBIN: bool = False
|
||||||
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
|
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
|
||||||
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
|
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
|
||||||
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
|
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
|
||||||
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
|
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
|
||||||
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
|
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
|
||||||
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
|
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
|
||||||
@ -1203,7 +1203,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
|
|
||||||
# Whether to use pytorch symmetric memory for allreduce
|
# Whether to use pytorch symmetric memory for allreduce
|
||||||
"VLLM_ALLREDUCE_USE_SYMM_MEM":
|
"VLLM_ALLREDUCE_USE_SYMM_MEM":
|
||||||
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
|
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))),
|
||||||
|
|
||||||
# Allows vllm to find tuned config under customized folder
|
# Allows vllm to find tuned config under customized folder
|
||||||
"VLLM_TUNED_CONFIG_FOLDER":
|
"VLLM_TUNED_CONFIG_FOLDER":
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user