diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index c4ea4b675649c..aef6d709722fa 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -164,6 +164,7 @@ steps: - tests/v1/test_internal_lb_dp.py - tests/v1/test_hybrid_lb_dp.py - tests/v1/engine/test_engine_core_client.py + - tests/distributed/test_symm_mem_allreduce.py commands: # test with torchrun tp=2 and external_dp=2 - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py @@ -188,6 +189,7 @@ steps: - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py - pytest -v -s distributed/test_events.py + - pytest -v -s distributed/test_symm_mem_allreduce.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - pushd ../examples/offline_inference @@ -329,6 +331,8 @@ steps: - python3 offline_inference/basic/classify.py - python3 offline_inference/basic/embed.py - python3 offline_inference/basic/score.py + - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 + - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 - label: Platform Tests (CUDA) # 4min timeout_in_minutes: 15 @@ -1037,3 +1041,4 @@ steps: num_gpus: 2 commands: - pytest -v -s tests/distributed/test_context_parallel.py + - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py diff --git a/benchmarks/auto_tune/auto_tune.sh b/benchmarks/auto_tune/auto_tune.sh index ed3679b66f805..b333ba9cd8e99 100644 --- a/benchmarks/auto_tune/auto_tune.sh +++ b/benchmarks/auto_tune/auto_tune.sh @@ -103,10 +103,15 @@ start_server() { VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 \ vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 & fi + local server_pid=$! # wait for 10 minutes... server_started=0 for i in {1..60}; do + # This line checks whether the server is still alive or not, + # since that we should always have permission to send signal to the server process. + kill -0 $server_pid 2> /dev/null || break + RESPONSE=$(curl -s -X GET "http://0.0.0.0:8004/health" -w "%{http_code}" -o /dev/stdout) STATUS_CODE=$(echo "$RESPONSE" | tail -n 1) if [[ "$STATUS_CODE" -eq 200 ]]; then @@ -118,7 +123,7 @@ start_server() { done if (( ! server_started )); then - echo "server did not start within 10 minutes. Please check server log at $vllm_log". + echo "server did not start within 10 minutes or crashed. Please check server log at $vllm_log". return 1 else return 0 diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index a5a5b52f60397..02f8c593392c4 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -17,7 +17,7 @@ from weight_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - w8a8_block_fp8_matmul, + w8a8_triton_block_scaled_mm, ) from vllm.utils import FlexibleArgumentParser, cdiv @@ -158,7 +158,7 @@ def bench_fp8( "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16) ), - "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul( + "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm( a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128) ), "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm( diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py index a61c17edc1e28..4cbdde5a5b2ca 100644 --- a/benchmarks/kernels/benchmark_device_communicators.py +++ b/benchmarks/kernels/benchmark_device_communicators.py @@ -7,6 +7,10 @@ Benchmark script for device communicators: CustomAllreduce (oneshot, twoshot), PyNcclCommunicator, and SymmMemCommunicator (multimem, two-shot). +for NCCL symmetric memory you need to set the environment variables +NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1, otherwise NCCL does +not use fast NVLS implementation for all reduce. + Usage: torchrun --nproc_per_node= benchmark_device_communicators.py [options] @@ -26,7 +30,13 @@ import torch.distributed as dist from torch.distributed import ProcessGroup from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce -from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.device_communicators.pynccl import ( + PyNcclCommunicator, + register_nccl_symmetric_ops, +) +from vllm.distributed.device_communicators.pynccl_allocator import ( + set_graph_pool_id, +) from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator from vllm.logger import init_logger from vllm.utils import FlexibleArgumentParser @@ -98,6 +108,7 @@ class CommunicatorBenchmark: ) if not self.pynccl_comm.disabled: logger.info("Rank %s: PyNcclCommunicator initialized", self.rank) + register_nccl_symmetric_ops(self.pynccl_comm) else: logger.info("Rank %s: PyNcclCommunicator disabled", self.rank) self.pynccl_comm = None @@ -194,6 +205,15 @@ class CommunicatorBenchmark: None, # no env variable needed ) ) + communicators.append( + ( + "pynccl-symm", + lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t), + lambda t: True, # Always available if initialized + nullcontext(), + None, # no env variable needed + ) + ) if self.symm_mem_comm_multimem is not None: comm = self.symm_mem_comm_multimem @@ -271,7 +291,9 @@ class CommunicatorBenchmark: # Capture the graph using context manager with context: graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): + graph_pool = torch.cuda.graph_pool_handle() + set_graph_pool_id(graph_pool) + with torch.cuda.graph(graph, pool=graph_pool): for _ in range(CUDA_GRAPH_CAPTURE_CYCLES): allreduce_fn(graph_input) diff --git a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py index d4648c18f31d5..0aace571064a0 100644 --- a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py +++ b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py @@ -9,6 +9,9 @@ import torch from tabulate import tabulate from vllm import _custom_ops as ops +from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash, +) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import ( @@ -31,6 +34,8 @@ def run_benchmark( kv_cache_dtype: str, kv_cache_layout: str, num_iters: int, + implementation: str, + benchmark_mode: str, device: str = "cuda", ) -> float: """Return latency (seconds) for given num_tokens.""" @@ -38,6 +43,14 @@ def run_benchmark( if kv_cache_dtype == "fp8" and head_size % 16: raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.") + if implementation not in ("cuda", "triton"): + raise ValueError( + f"Unsupported implementation: {implementation}. " + "Only 'cuda' and 'triton' are supported." + ) + if implementation == "triton" and kv_cache_layout == "HND": + return float("nan") # Triton does not support HND layout yet. + current_platform.seed_everything(42) torch.set_default_device(device) @@ -65,27 +78,49 @@ def run_benchmark( cache_layout=kv_cache_layout, ) key_cache, value_cache = key_caches[0], value_caches[0] + # to free unused memory + del key_caches, value_caches # compute per-kernel scaling factors for fp8 conversion (if used). k_scale = (key.amax() / 64.0).to(torch.float32) v_scale = (value.amax() / 64.0).to(torch.float32) + if implementation == "cuda": + function_under_test = lambda: ops.reshape_and_cache_flash( + key, # noqa: F821 + value, # noqa: F821 + key_cache, # noqa: F821 + value_cache, # noqa: F821 + slot_mapping, # noqa: F821 + kv_cache_dtype, + k_scale, + v_scale, + ) + else: + function_under_test = lambda: triton_reshape_and_cache_flash( + key, # noqa: F821 + value, # noqa: F821 + key_cache, # noqa: F821 + value_cache, # noqa: F821 + slot_mapping, # noqa: F821 + kv_cache_dtype, + k_scale, + v_scale, + ) + if benchmark_mode == "cudagraph": + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + function_under_test() + torch.cuda.synchronize() + function_under_test = lambda: g.replay() + def run_cuda_benchmark(n_iters: int) -> float: nonlocal key, value, key_cache, value_cache, slot_mapping torch.cuda.synchronize() start = time.perf_counter() for _ in range(n_iters): - ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - slot_mapping, - kv_cache_dtype, - k_scale, - v_scale, - ) - torch.cuda.synchronize() + function_under_test() + torch.cuda.synchronize() end = time.perf_counter() return (end - start) / n_iters @@ -116,10 +151,16 @@ def main(args): kv_cache_dtype=args.kv_cache_dtype, kv_cache_layout=layout, num_iters=args.iters, + implementation=args.implementation, + benchmark_mode=args.mode, device="cuda", ) rows.append([n_tok, layout, f"{lat * 1e6:.3f}"]) + print( + f"Benchmark results for implementation {args.implementation}" + f" (measuring with {args.mode}):" + ) print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"])) @@ -151,6 +192,21 @@ if __name__ == "__main__": ) parser.add_argument("--iters", type=int, default=100) + + parser.add_argument( + "--implementation", + type=str, + choices=["cuda", "triton"], + default="cuda", + ) + + parser.add_argument( + "--mode", + type=str, + choices=["cudagraph", "no_graph"], + default="cudagraph", + ) + args = parser.parse_args() main(args) diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index b99c2099f2c38..b3c3742825de7 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -10,7 +10,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( get_col_major_tma_aligned_tensor, per_token_group_quant_fp8, - w8a8_block_fp8_matmul, + w8a8_triton_block_scaled_mm, ) from vllm.triton_utils import triton from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8 @@ -59,7 +59,7 @@ def benchmark_shape(m: int, # === vLLM Triton Implementation === def vllm_triton_gemm(): - return w8a8_block_fp8_matmul(A_vllm, + return w8a8_triton_block_scaled_mm(A_vllm, B_vllm, A_scale_vllm, B_scale_vllm, diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index b5321f748e6be..c93f9d54d780c 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -418,6 +418,15 @@ __device__ inline T neg_inf() { return cuda_cast(-cuda::std::numeric_limits::infinity()); } +template +__device__ inline bool is_finite(const T val) { +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800) + return cuda::std::isfinite(val); +#else + return isfinite(cuda_cast(val)); +#endif +} + template __device__ void topk_with_k2(T* output, T const* input, cg::thread_block_tile<32> const& tile, @@ -533,7 +542,7 @@ __global__ void group_idx_and_topk_idx_kernel( // calculate group_idx int32_t target_num_min = WARP_SIZE - n_group + topk_group; // The check is necessary to avoid abnormal input - if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) { + if (lane_id < n_group && is_finite(group_scores[lane_id])) { value = group_scores[lane_id]; } @@ -568,11 +577,10 @@ __global__ void group_idx_and_topk_idx_kernel( int32_t offset = i_group * num_experts_per_group; for (int32_t i = lane_id; i < align_num_experts_per_group; i += WARP_SIZE) { - T candidates = - (i < num_experts_per_group) && - cuda::std::isfinite(scores_with_bias[offset + i]) - ? scores_with_bias[offset + i] - : neg_inf(); + T candidates = (i < num_experts_per_group) && + is_finite(scores_with_bias[offset + i]) + ? scores_with_bias[offset + i] + : neg_inf(); queue.add(candidates, offset + i); } if (group_scores[i_group] == topk_group_value) { diff --git a/csrc/quantization/fp8/per_token_group_quant.cu b/csrc/quantization/fp8/per_token_group_quant.cu index f5b40e35b6e5a..91d489fdef862 100644 --- a/csrc/quantization/fp8/per_token_group_quant.cu +++ b/csrc/quantization/fp8/per_token_group_quant.cu @@ -12,8 +12,8 @@ #include "../vectorization_utils.cuh" #include "../../dispatch_utils.h" -__device__ __forceinline__ float GroupReduceMax(float val, const int tid) { - unsigned mask = 0xffff; +__device__ __forceinline__ float GroupReduceMax(float val) { + unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff; val = fmaxf(val, __shfl_xor_sync(mask, val, 8)); val = fmaxf(val, __shfl_xor_sync(mask, val, 4)); @@ -86,7 +86,7 @@ __global__ void per_token_group_quant_8bit_kernel( threads_per_group, // stride in group scalar_op_cache); // scalar handler - local_absmax = GroupReduceMax(local_absmax, lane_id); + local_absmax = GroupReduceMax(local_absmax); float y_s = local_absmax / max_8bit; if constexpr (SCALE_UE8M0) { diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index b6ee2656746c1..edf7aff1abaac 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -5,11 +5,14 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, const int64_t rows_per_block); -torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, +torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, + const c10::optional& in_bias, const int64_t CuCount); -void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount); +void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, + const c10::optional& in_bias, at::Tensor& out_c, + const at::Tensor& scale_a, const at::Tensor& scale_b, + const int64_t CuCount); void paged_attention( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index eb47139208c91..52119d52f6d1e 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -292,8 +292,9 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, + wvSplitK_hf_sml_(const int K, const int M, const int Bx, const int By, + const scalar_t* B, const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE / 2; #if defined(__HIP__MI3XX__) @@ -484,7 +485,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]); + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + } C[m + i + n * M] = __float2s(sum[n][i]); } } @@ -529,7 +537,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + if (BIAS) + sum4[n][i][0] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); } } @@ -541,8 +551,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else // !defined(__HIP__GFX9__) TODO: Add NAVI support template -__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, +__global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx, + const int By, const scalar_t* B, + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } @@ -553,8 +565,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, + wvSplitK_hf_(const int K, const int M, const int Bx, const int By, + const scalar_t* B, const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE / 2; #if defined(__HIP__MI3XX__) @@ -772,8 +785,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) + if (commitColumn[i]) { + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + } C[m + i + n * M] = __float2s(sum[n][i]); + } } } } @@ -818,8 +840,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); - C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + if (commitColumn[i]) { + if (BIAS) + sum4[n][i][0] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } @@ -842,8 +868,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else // !defined(__HIP__GFX9__) TODO: Add NAVI support template -__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, +__global__ void wvSplitK_hf_(const int K, const int M, const int Bx, + const int By, const scalar_t* B, + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } @@ -854,8 +882,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, + wvSplitK_hf_big_(const int K, const int M, const int Bx, const int By, + const scalar_t* B, const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE / 2; #if defined(__HIP__MI3XX__) @@ -1124,8 +1153,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) + if (commitColumn[i]) { + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][i] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + } C[m + i + n * M] = __float2s(sum[n][i]); + } } } } @@ -1166,8 +1204,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 63) { for (int n = 0; n < N; n++) { for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); - C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + if (commitColumn[i]) { + if (BIAS) + sum4[n][i][0] += + __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } @@ -1190,8 +1232,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else // !defined(__HIP__GFX9__) TODO: Add NAVI support template -__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, - const scalar_t* __restrict__ A, scalar_t* C, +__global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx, + const int By, const scalar_t* B, + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } @@ -1226,11 +1270,20 @@ int mindiv(int N, int div1, int div2) { return rtn; } -torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, +torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, + const c10::optional& in_bias, const int64_t CuCount) { auto M_in = in_a.size(0); auto K_in = in_a.size(1); auto N_in = in_b.size(0); + auto Bx_in = + (in_bias.has_value() && in_bias->numel() > 0) + ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0) + : 1; + auto By_in = (in_bias.has_value() && in_bias->numel() > 0 && + in_bias->sizes().size() == 2) + ? in_bias->size(0) + : 1; TORCH_CHECK(in_a.dtype() == in_b.dtype()); TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0"); @@ -1254,18 +1307,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSplitK_hf_sml_ \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ - CuCount); \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ } else if (K_in * N_in <= max_lds_len * 1.2) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ wvSplitK_hf_ \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ - CuCount); \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ wvSplitK_hf_big_ \ - <<>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ - CuCount); \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ } \ } @@ -1273,6 +1326,10 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, using fptype = typename scalar::type; fptype* af4 = reinterpret_cast(in_a.data_ptr()); const fptype* bf4 = reinterpret_cast(in_b.data_ptr()); + const fptype* biasf4 = + (in_bias.has_value() && in_bias->numel() > 0) + ? reinterpret_cast(in_bias->data_ptr()) + : nullptr; fptype* c = reinterpret_cast(out_c.data_ptr()); switch (N_in) { case 1: @@ -1300,8 +1357,9 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const fp8_t* B, - const fp8_t* __restrict__ A, scalar_t* C, + wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int Bx, + const int By, const fp8_t* B, const fp8_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, const int CuCount) { @@ -1453,7 +1511,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 0) { for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { - C[m + y + n * M] = __float2s(sum[n][y][0] * sA * sB); + if (y + m >= M) break; // To avoid mem access fault. + sum[n][y][0] *= sA * sB; + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][y][0] += + __bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]); + } + C[m + y + n * M] = __float2s(sum[n][y][0]); // * sA * sB); } } } @@ -1465,7 +1533,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) template __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, - const fp8_t* B, const fp8_t* __restrict__ A, + const int Bx, const int By, const fp8_t* B, + const fp8_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, const int CuCount) { @@ -1477,8 +1547,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B, - const fp8_t* __restrict__ A, scalar_t* C, + wvSplitKQ_hf_(const int K, const int Kp, const int M, const int Bx, + const int By, const fp8_t* B, const fp8_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE; @@ -1626,7 +1697,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { if (y + m >= M) break; // To avoid mem access fault. - C[m + y + n * M] = __float2s(sum[n][y][0] * sA * sB); + sum[n][y][0] *= sA * sB; + if constexpr (std::is_same_v) { + if (BIAS) + sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]); + } else if constexpr (std::is_same_v) { + if (BIAS) + sum[n][y][0] += + __bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]); + } + C[m + y + n * M] = __float2s(sum[n][y][0]); } } } @@ -1638,16 +1718,19 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) template __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, - const fp8_t* B, const fp8_t* __restrict__ A, - scalar_t* C, const float* __restrict__ s_A, + const int Bx, const int By, const fp8_t* B, + const fp8_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, scalar_t* C, + const float* __restrict__ s_A, const float* __restrict__ s_B, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE } #endif // defined(__HIP__MI3XX__) TODO: Add NAVI support -void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - at::Tensor& scale_a, at::Tensor& scale_b, +void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, + const c10::optional& in_bias, at::Tensor& out_c, + const at::Tensor& scale_a, const at::Tensor& scale_b, const int64_t CuCount) { static c10::ScalarType kFp8Type = is_fp8_ocp() ? c10::ScalarType::Float8_e4m3fn @@ -1656,6 +1739,15 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, auto K_in = in_a.size(1); auto N_in = in_b.size(0); auto Kp_in = in_a.stride(0); + auto Bx_in = + (in_bias.has_value() && in_bias->numel() > 0) + ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0) + : 1; + auto By_in = (in_bias.has_value() && in_bias->numel() > 0 && + in_bias->sizes().size() == 2) + ? in_bias->size(0) + : 1; + TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0"); TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type); TORCH_CHECK(out_c.dtype() == torch::kFloat16 || @@ -1673,13 +1765,15 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ wvSplitKQ_hf_sml_ \ - <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ - s_a, s_b, __wvPrGrp, CuCount); \ + <<>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \ + b_ptr, bias_ptr, c_ptr, s_a, s_b, \ + __wvPrGrp, CuCount); \ } else { \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ wvSplitKQ_hf_ \ - <<>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ - s_a, s_b, __wvPrGrp, CuCount); \ + <<>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \ + b_ptr, bias_ptr, c_ptr, s_a, s_b, \ + __wvPrGrp, CuCount); \ } \ } @@ -1691,6 +1785,9 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] { auto a_ptr = in_a.data_ptr(); auto b_ptr = in_b.data_ptr(); + auto bias_ptr = (in_bias.has_value() && in_bias->numel() > 0) + ? reinterpret_cast(in_bias->data_ptr()) + : nullptr; switch (N_in) { case 1: WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1) diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index c0c4daef64f05..518486b1ca5de 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -22,13 +22,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { // Custom gemm op for skinny matrix-matrix multiplication rocm_ops.def( - "wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> " + "wvSplitK(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount) -> " "Tensor"); rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK); // wvSplitK for fp8 rocm_ops.def( - "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, " + "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, " + "Tensor scale_a, " " Tensor scale_b, int CuCount) -> ()"); rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ); diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 004e75b204642..ce078bce0b753 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -49,6 +49,7 @@ def get_custom_mm_prompts(num_prompts): def parse_args(): parser = FlexibleArgumentParser() add_dataset_parser(parser) + parser.add_argument("--test", action="store_true") parser.add_argument( "--method", type=str, @@ -60,6 +61,7 @@ def parse_args(): parser.add_argument("--tp", type=int, default=1) parser.add_argument("--enforce-eager", action="store_true") parser.add_argument("--enable-chunked-prefill", action="store_true") + parser.add_argument("--max-model-len", type=int, default=16384) parser.add_argument("--temp", type=float, default=0) parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=-1) @@ -71,8 +73,7 @@ def parse_args(): return parser.parse_args() -def main(): - args = parse_args() +def main(args): args.endpoint_type = "openai-chat" model_dir = args.model_dir @@ -134,7 +135,7 @@ def main(): gpu_memory_utilization=0.8, speculative_config=speculative_config, disable_log_stats=False, - max_model_len=16384, + max_model_len=args.max_model_len, limit_mm_per_prompt={"image": 5}, disable_chunked_mm_input=True, ) @@ -198,6 +199,39 @@ def main(): acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 print(f"acceptance at token {i}: {acceptance_rate:.2f}") + return acceptance_length + if __name__ == "__main__": - main() + args = parse_args() + acceptance_length = main(args) + + if args.test: + # takes ~30s to run on 1xH100 + assert args.method in ["eagle", "eagle3"] + assert args.tp == 1 + assert args.num_spec_tokens == 3 + assert args.dataset_name == "hf" + assert args.dataset_path == "philschmid/mt-bench" + assert args.num_prompts == 80 + assert args.temp == 0 + assert args.top_p == 1.0 + assert args.top_k == -1 + assert args.enable_chunked_prefill + + # check acceptance length is within 2% of expected value + rtol = 0.02 + expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811 + + assert ( + acceptance_length <= (1 + rtol) * expected_acceptance_length + and acceptance_length >= (1 - rtol) * expected_acceptance_length + ), ( + f"acceptance_length {acceptance_length} is not " + f"within {rtol * 100}% of {expected_acceptance_length}" + ) + + print( + f"Test passed! Expected AL: " + f"{expected_acceptance_length}, got {acceptance_length}" + ) diff --git a/tests/distributed/test_nccl_symm_mem_allreduce.py b/tests/distributed/test_nccl_symm_mem_allreduce.py new file mode 100644 index 0000000000000..ffc913742620d --- /dev/null +++ b/tests/distributed/test_nccl_symm_mem_allreduce.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random +import typing + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import vllm.envs as envs +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.distributed.device_communicators.cuda_communicator import ( + CudaCommunicator) +from vllm.distributed.device_communicators.pynccl import ( + register_nccl_symmetric_ops) +from vllm.distributed.device_communicators.pynccl_allocator import ( + get_nccl_mem_pool, is_symmetric_memory_enabled) +from vllm.distributed.parallel_state import (get_tp_group, + init_distributed_environment, + initialize_model_parallel) +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + +torch.manual_seed(42) +random.seed(44) + +test_size_elements = 4 * 1024 * 1024 + + +def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int): + monkeypatch = pytest.MonkeyPatch() + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + dtype = torch.bfloat16 + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + update_environment_variables({ + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + cuda_communicator = typing.cast(CudaCommunicator, + get_tp_group().device_communicator) + pynccl_comm = cuda_communicator.pynccl_comm + if get_nccl_mem_pool() is None: + pytest.skip("NCCL allocator compilation failed " + "(probably missing NCCL headers).") + if not is_symmetric_memory_enabled(): + pytest.skip("NCCL symmetric memory allreduce is disabled.") + + register_nccl_symmetric_ops(pynccl_comm) + input = torch.randint(1, + 23, (test_size_elements, ), + dtype=dtype, + device=device) + input_clone = input.clone() + output = torch.ops.vllm.all_reduce_symmetric_with_copy(input) + assert output is not None + + group = get_tp_group().device_group + dist.all_reduce(input_clone, group=group) + torch.testing.assert_close(output, input_clone, atol=2.5, rtol=0.1) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="NCCLSymmMemAllreduce is only available for CUDA platforms.", +) +@pytest.mark.parametrize("world_size", [2]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], + reason="Only test on CUDA") +def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size): + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + + # Enable SymmMemCommunicator + monkeypatch.setenv("VLLM_USE_NCCL_SYMM_MEM", "1") + monkeypatch.setenv("NCCL_NVLS_ENABLE", "1") + monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1") + + mp.spawn(nccl_symm_mem_allreduce_worker, + args=(world_size, ), + nprocs=world_size) + cleanup_dist_env_and_memory() diff --git a/tests/distributed/test_symm_mem_allreduce.py b/tests/distributed/test_symm_mem_allreduce.py index 5a804a389123b..83e1fe47aeec0 100644 --- a/tests/distributed/test_symm_mem_allreduce.py +++ b/tests/distributed/test_symm_mem_allreduce.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import queue import random import typing @@ -10,26 +11,31 @@ import torch.distributed as dist import torch.multiprocessing as mp import vllm.envs as envs +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.device_communicators.cuda_communicator import ( CudaCommunicator) -from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, - get_tp_group, +from vllm.distributed.parallel_state import (get_tp_group, init_distributed_environment, initialize_model_parallel) +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine from vllm.platforms import current_platform from vllm.utils import update_environment_variables torch.manual_seed(42) random.seed(44) -test_size_elements = 4 * 1024 * 1024 +test_size_elements = 1024 * 1024 -def symm_mem_allreduce_worker(local_rank: int, world_size: int): +def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue): monkeypatch = pytest.MonkeyPatch() - with monkeypatch.context() as m: + config = VllmConfig(parallel_config=ParallelConfig( + tensor_parallel_size=world_size)) + + with monkeypatch.context() as m, set_current_vllm_config(config): m.delenv("CUDA_VISIBLE_DEVICES", raising=False) dtype = torch.bfloat16 device = torch.device(f"cuda:{local_rank}") @@ -51,22 +57,26 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int): get_tp_group().device_communicator) symm_mem_comm = cuda_communicator.symm_mem_comm if symm_mem_comm is None or symm_mem_comm.disabled: - pytest.skip("SymmMemCommunicator is not available or disabled.") + # can't use skip under multiprocessing + q.put("SymmMemCommunicator is not available or disabled.") + return inp_direct_symm_mem = torch.randint(1, 23, (test_size_elements, ), dtype=dtype, device=device) if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem): - pytest.skip( + # can't use skip under multiprocessing + q.put( "SymmMemCommunicator isn't used for this world and input size." ) + return original_inp_direct_symm_mem = inp_direct_symm_mem.clone() out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem) assert out_direct_symm_mem is not None - group = get_tensor_model_parallel_group().device_group + group = get_tp_group().device_group dist.all_reduce(original_inp_direct_symm_mem, group=group) torch.testing.assert_close(out_direct_symm_mem, original_inp_direct_symm_mem, @@ -100,9 +110,34 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, world_size = tp_size * pipeline_parallel_size if world_size > torch.cuda.device_count(): pytest.skip("Not enough GPUs to run the test.") + q = mp.get_context('spawn').Queue() + mp.spawn(symm_mem_allreduce_worker, + args=(world_size, q), + nprocs=world_size) + try: + val = q.get(timeout=1) + except queue.Empty: + val = None + finally: + cleanup_dist_env_and_memory() + if val is not None: + pytest.skip(val) - # Enable SymmMemCommunicator - monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1") - mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size) - cleanup_dist_env_and_memory() +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="SymmMemAllreduce is only available for CUDA platforms.") +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], + reason="Only test on CUDA") +def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch): + world_size = 4 + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + # Verify that the DataParallel runs without error + engine_args = EngineArgs(model="distilbert/distilgpt2", + enforce_eager=True, + enable_prefix_caching=True, + data_parallel_size=2, + tensor_parallel_size=2, + data_parallel_backend="mp") + LLMEngine.from_engine_args(engine_args) diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index 69e96dfd2cb13..1325e6883132a 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -39,6 +39,8 @@ CUDA_DEVICES = [ # We assume fp8 is always enabled for testing. KV_CACHE_DTYPE = ["auto", "fp8"] +RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"] + @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @pytest.mark.parametrize("num_layers", NUM_LAYERS) @@ -223,6 +225,7 @@ def test_reshape_and_cache( @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS) +@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS) @torch.inference_mode() def test_reshape_and_cache_flash( kv_cache_factory_flashinfer, @@ -236,9 +239,13 @@ def test_reshape_and_cache_flash( device: str, kv_cache_dtype: str, kv_cache_layout: str, + implementation: str, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) + assert implementation in ["cuda", "triton"] + if implementation == "triton" and kv_cache_layout == "HND": + pytest.skip("Triton implementation only supports NHD layout.") # fp8 conversion requires continugous memory buffer. Reduce the number of # blocks and tokens to consume less memory. @@ -298,12 +305,20 @@ def test_reshape_and_cache_flash( cloned_key_cache = key_cache_compact.clone() cloned_value_cache = value_cache_compact.clone() # Call the reshape_and_cache kernel. - opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, - (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, - k_scale, v_scale), - cond=(head_size == HEAD_SIZES[0])) - ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, k_scale, v_scale) + if implementation == "cuda": + opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, + (key, value, key_cache, value_cache, slot_mapping, + kv_cache_dtype, k_scale, v_scale), + cond=(head_size == HEAD_SIZES[0])) + ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, k_scale, + v_scale) + elif implementation == "triton": + from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash) + triton_reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, k_scale, + v_scale) key_cache_compact = permute_and_compact(key_cache) value_cache_compact = permute_and_compact(value_cache) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index c440747316b80..c0b934fc55ae6 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -12,7 +12,7 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( cutlass_scaled_mm, get_col_major_tma_aligned_tensor, - per_token_group_quant_fp8, w8a8_block_fp8_matmul) + per_token_group_quant_fp8, w8a8_triton_block_scaled_mm) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8 @@ -90,7 +90,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py index 720eee62760db..3d4c851a9b889 100644 --- a/tests/kernels/quantization/test_fp8_quant_group.py +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -20,9 +20,11 @@ from vllm.platforms import current_platform (8, 513, 64), # Non-divisible (native only) ]) @pytest.mark.parametrize("seed", [42]) +@pytest.mark.parametrize("use_ue8m0", [True, False]) @torch.inference_mode() def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, - group_size: int, seed: int) -> None: + group_size: int, seed: int, + use_ue8m0: bool) -> None: """Test QuantFP8 group quantization with various configurations. Tests both CUDA and native implementations, column-major scales, @@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, group_shape = GroupShape(1, group_size) quant_op = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=False) + column_major_scales=False, + use_ue8m0=use_ue8m0) # 1. Test native implementation (always available) x_quant_native, scales_native = quant_op.forward_native(x.clone()) @@ -48,9 +51,15 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, # 2. Test column-major scales configuration quant_op_col = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=True) + column_major_scales=True, + use_ue8m0=use_ue8m0) _, scales_col = quant_op_col.forward_native(x.clone()) - assert scales_col.shape == (expected_num_groups, batch_size) + assert scales_col.shape == (batch_size, expected_num_groups) + assert scales_col.stride(0) == 1 + assert scales_col.stride(1) == batch_size + + # Test column-major scales consistency + assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8) # 3. Test CUDA implementation (only for divisible dimensions) if is_divisible: @@ -68,8 +77,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, @pytest.mark.parametrize("seed", [42]) +@pytest.mark.parametrize("use_ue8m0", [True, False]) @torch.inference_mode() -def test_quantfp8_group_multidimensional(seed: int) -> None: +def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None: current_platform.seed_everything(seed) group_size = 64 @@ -82,7 +92,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None: group_shape = GroupShape(1, group_size) quant_op = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=False) + column_major_scales=False, + use_ue8m0=use_ue8m0) x_quant, scales = quant_op.forward_native(x_3d.clone()) assert x_quant.shape == x_3d.shape @@ -91,7 +102,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None: # Test column_major_scales with multi-dim quant_op_col = QuantFP8(static=False, group_shape=group_shape, - column_major_scales=True) + column_major_scales=True, + use_ue8m0=use_ue8m0) _, scales_col = quant_op_col.forward_native(x_3d.clone()) assert scales_col.shape == (batch1, hidden_dim // group_size, batch2) diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index a9b1c71ef0718..6de5fc9c56010 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math + import pytest import torch import vllm._custom_ops as ops from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - rocm_per_tensor_w8a8_scaled_mm_impl) from vllm.platforms import current_platform DTYPES = [torch.bfloat16, torch.float16] @@ -49,6 +49,7 @@ NKM_FACTORS_WVSPLITK_FP8 = [ (2, 512, 512), (3, 2048, 2048), (4, 4096, 4096), + (4, 16400, 2048), # Extended FP8 dimensions not covered by WVSPLITK (1, 14336, 1024), (2, 24576, 2048), @@ -67,6 +68,9 @@ SEEDS = [0] @torch.inference_mode() def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): torch.manual_seed(seed) + #TODO: Zero-centering the inputs causes errors for LLMM1! + # Without that the numbers quickly saturate, and may + # be giving false matches. A = torch.rand(n, k, dtype=dtype, device="cuda") B = torch.rand(m, k, dtype=dtype, device="cuda") @@ -85,11 +89,51 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) cu_count = current_platform.get_cu_count() - A = torch.rand(n, k, dtype=dtype, device="cuda") - B = torch.rand(m, k, dtype=dtype, device="cuda") + A = torch.rand(n, k, dtype=dtype, device="cuda") - .5 + B = torch.rand(m, k, dtype=dtype, device="cuda") - .5 - ref_out = torch.matmul(A, B.t()) - out = ops.wvSplitK(B, A, cu_count) + ref_out = torch.nn.functional.linear(A, B) + out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count) + + assert torch.allclose(out, ref_out, rtol=0.01) + + +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="only test for rocm") +def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed): + torch.manual_seed(seed) + cu_count = current_platform.get_cu_count() + + xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas + A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier + B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier + BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5 + + ref_out = torch.nn.functional.linear(A, B, BIAS) + out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) + + assert torch.allclose(out, ref_out, rtol=0.01) + + +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="only test for rocm") +def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): + torch.manual_seed(seed) + cu_count = current_platform.get_cu_count() + + xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas + A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier + B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier + BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - .5 + + ref_out = torch.nn.functional.linear(A, B, BIAS) + out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) assert torch.allclose(out, ref_out, rtol=0.01) @@ -103,8 +147,8 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) - A = torch.rand(n, k, device="cuda") - B = torch.rand(m, k, device="cuda") + A = torch.rand(n, k, device="cuda") - 0.5 + B = torch.rand(m, k, device="cuda") - 0.5 A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) @@ -123,27 +167,27 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.skipif( not (current_platform.is_rocm() and current_platform.supports_fp8()), reason="only test for rocm fp8") -def test_rocm_per_tensor_w8a8_scaled_mm_impl(n, k, m, dtype, seed, use_bias): +def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) - A = torch.rand(n, k, device="cuda") - B = torch.rand(m, k, device="cuda") + xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas + A = (torch.rand(n, k, device="cuda") - .5) * xavier + B = (torch.rand(m, k, device="cuda") - .5) * xavier + BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5 A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) - bias = torch.rand(1, m, dtype=dtype, device="cuda") if use_bias else None - - output = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a, - scale_b, bias) ref_out = torch._scaled_mm(A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, - bias=bias) - assert torch.allclose(output, ref_out, rtol=0.01) + bias=BIAS) + out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, + current_platform.get_cu_count(), BIAS) + + assert torch.allclose(out, ref_out, rtol=0.01) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 92ce10a9efc0b..200b6ecd58528 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -17,8 +17,6 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.layernorm import (RMSNorm, dispatch_rocm_rmsnorm_func, fused_add_rms_norm, rms_norm) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul) from vllm.platforms import current_platform RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] @@ -111,34 +109,6 @@ def test_enabled_ops_invalid(env: str): RMSNorm(1024).enabled() -@pytest.mark.skipif( - not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(), - reason="AITER is a feature exclusive for ROCm and FP8_FNUZ") -@pytest.mark.parametrize("use_cutlass", [True, False]) -@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"]) -def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str, - use_rocm_aiter_gemm_w8a8_blockscale: str, - monkeypatch): - - monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", - use_rocm_aiter_gemm_w8a8_blockscale) - - use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool( - int(use_rocm_aiter_gemm_w8a8_blockscale))) - block_scale_func = dispatch_w8a8_blockscale_func( - use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported) - if use_cutlass: - assert block_scale_func == cutlass_scaled_mm - elif current_platform.is_rocm() and int(use_rocm_aiter) and int( - use_rocm_aiter_gemm_w8a8_blockscale): - assert block_scale_func == ( - torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale) - else: - assert block_scale_func == w8a8_block_fp8_matmul - - @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index c0ab3fbb10622..af8c7ec3b4822 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -18,6 +18,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + W8A8BlockFp8LinearOp) from vllm.model_executor.layers.quantization.utils.quant_utils import ( cutlass_fp4_supported) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -742,3 +745,35 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt, perplexity = llm.generate_prompt_perplexity([prompt])[0] print(perplexity) assert perplexity <= exp_perplexity + + +def test_compressed_tensors_fp8_block_enabled(vllm_runner): + model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK" + with vllm_runner(model_path) as llm: + + fp8_dtype = current_platform.fp8_dtype() + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + assert isinstance(qkv_proj.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8) + assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear, + W8A8BlockFp8LinearOp) + + assert qkv_proj.weight.dtype is fp8_dtype + assert qkv_proj.weight_scale.dtype is torch.float32 + assert len(qkv_proj.weight.shape) == 2 + assert len(qkv_proj.weight_scale.shape) == 2 + + input_quant_op = \ + qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op + assert isinstance(input_quant_op, QuantFP8) + assert input_quant_op._forward_method == input_quant_op.forward_cuda + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index c74dbb3ebb17e..7d7a46910be89 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -5,11 +5,12 @@ import pytest import torch from tests.v1.attention.test_attention_backends import BATCH_SPECS -from tests.v1.attention.utils import create_common_attn_metadata +from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm.v1.attention.backends.utils import (UBatchSlice, _make_metadata_with_slice, slice_query_start_locs, split_attn_metadata) +from vllm.v1.worker.ubatch_utils import create_ubatch_slices @pytest.fixture @@ -155,3 +156,83 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): assert results[1].num_reqs == mid_point assert results[1].num_actual_tokens == mid_point assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point)) + + +@pytest.mark.parametrize( + "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs", + [ + # Split in the middle of request 1 + ([32, 40], [8, 8], 12, 2, 1), + # Split inside the first request + ([32, 40], [8, 8], 4, 1, 2), + ], +) +def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point, + expected_first_reqs, + expected_second_reqs): + """Test splitting a prefill across ubatches""" + import numpy as np + + device = torch.device("cpu") + batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens) + common = create_common_attn_metadata(batch_spec, + block_size=16, + device=device) + + num_scheduled_tokens = np.array(query_lens, dtype=np.int32) + qsl_np = common.query_start_loc_cpu.numpy() + num_tokens = common.num_actual_tokens + + ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point) + assert len(ubatch_slices) == 2 + + first_meta = _make_metadata_with_slice(ubatch_slices[0], common) + second_meta = _make_metadata_with_slice(ubatch_slices[1], common) + + # Token counts match the split + assert first_meta.num_actual_tokens == split_point + assert second_meta.num_actual_tokens == num_tokens - split_point + + # Number of requests per ubatch + assert first_meta.num_reqs == expected_first_reqs + assert second_meta.num_reqs == expected_second_reqs + + # Identify which request is split and how many tokens are in the first chunk + split_req_idx = int(np.searchsorted(qsl_np, split_point, side="right") - 1) + tokens_in_first_chunk = split_point - int(qsl_np[split_req_idx]) + orig_q_lens = (common.query_start_loc_cpu[1:] - + common.query_start_loc_cpu[:-1]) + + # Check query length continuity: first-chunk + second-chunk == original qlen + # First ubatch last request query length + qlen_first_last = int(first_meta.query_start_loc_cpu[-1] - + first_meta.query_start_loc_cpu[-2]) + # Second ubatch first request query length + qlen_second_first = int(second_meta.query_start_loc_cpu[1] - + second_meta.query_start_loc_cpu[0]) + assert qlen_first_last == tokens_in_first_chunk + assert qlen_first_last + qlen_second_first == int( + orig_q_lens[split_req_idx]) + + # Check seq_lens adjustments + # Context lengths per original request + context_lens = [s - q for s, q in zip(seq_lens, query_lens)] + + # First ubatch: last request's seq_len should be + # context + tokens_in_first_chunk + expected_seqlen = context_lens[split_req_idx] + tokens_in_first_chunk + assert int(first_meta.seq_lens[-1]) == expected_seqlen + + # For full preceding requests in first ubatch, seq_lens should match + # originals + for i in range(first_meta.num_reqs - 1): + assert int(first_meta.seq_lens[i]) == seq_lens[i] + + # Second ubatch: first request (continuation) seq_len should be full + # original + assert int(second_meta.seq_lens[0]) == seq_lens[split_req_idx] + # Any following full requests in second ubatch should match originals + for j in range(1, second_meta.num_reqs): + # Map to original request index + orig_idx = split_req_idx + j + assert int(second_meta.seq_lens[j]) == seq_lens[orig_idx] diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 4b0f3b2d9967e..e2c686928cea1 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -5,6 +5,7 @@ from __future__ import annotations import json +from dataclasses import fields from enum import Enum from typing import TYPE_CHECKING, Any @@ -21,7 +22,8 @@ from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager -from vllm.sampling_params import SamplingParams, StructuredOutputsParams +from vllm.sampling_params import (GuidedDecodingParams, SamplingParams, + StructuredOutputsParams) if TYPE_CHECKING: from vllm.config import TokenizerMode @@ -89,6 +91,26 @@ def _load_json(s: str, backend: str) -> str: return json.loads(s) +def test_guided_decoding_deprecated(): + with pytest.warns(DeprecationWarning, + match="GuidedDecodingParams is deprecated.*"): + guided_decoding = GuidedDecodingParams(json_object=True) + + structured_outputs = StructuredOutputsParams(json_object=True) + assert fields(guided_decoding) == fields(structured_outputs) + + with pytest.warns(DeprecationWarning, + match="guided_decoding is deprecated.*"): + sp1 = SamplingParams(guided_decoding=guided_decoding) + + with pytest.warns(DeprecationWarning, + match="guided_decoding is deprecated.*"): + sp2 = SamplingParams.from_optional(guided_decoding=guided_decoding) + + assert sp1 == sp2 + assert sp1.structured_outputs == guided_decoding + + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize( "model_name, backend, tokenizer_mode, speculative_config", diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index e7f6b68fc3f77..5096f9fd647bd 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -532,9 +532,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): # Mock runner for attention metadata building proposer.runner = mock.MagicMock() proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builders = [ + proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \ attn_metadata_builder - ] + proposer._get_attention_metadata_builder = mock.MagicMock( + return_value=attn_metadata_builder) result = proposer.propose(target_token_ids=target_token_ids, target_positions=target_positions, @@ -659,9 +660,10 @@ def test_propose_tree(spec_token_tree): # Mock runner for attention metadata building. proposer.runner = mock.MagicMock() proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builders = [ + proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \ attn_metadata_builder - ] + proposer._get_attention_metadata_builder = mock.MagicMock( + return_value=attn_metadata_builder) # Setup inputs for the proposer. target_token_ids = torch.randint(0, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 712295aa92886..a108542e14368 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1447,17 +1447,24 @@ def LLMM1(a: torch.Tensor, b: torch.Tensor, return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) -def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor: - return torch.ops._rocm_C.wvSplitK(a, b, cu_count) +def wvSplitK(a: torch.Tensor, + b: torch.Tensor, + cu_count: int, + bias: torch.Tensor = None) -> torch.Tensor: + return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count) -def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype, - scale_a: torch.Tensor, scale_b: torch.Tensor, - cu_count: int) -> torch.Tensor: +def wvSplitKQ(a: torch.Tensor, + b: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + cu_count: int, + bias: torch.Tensor = None) -> torch.Tensor: out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device) - torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count) + torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count) return out diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py new file mode 100644 index 0000000000000..0b0c706626af3 --- /dev/null +++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import triton +import triton.language as tl + +from vllm.platforms import current_platform + + +@triton.jit +def reshape_and_cache_kernel_flash( + key_ptr, # [num_tokens, num_heads, head_size] + value_ptr, # [num_tokens, num_heads, head_size] + key_cache_ptr, # [num_blocks, block_size, num_heads, head_size] + value_cache_ptr, # [num_blocks, block_size, num_heads, head_size] + slot_mapping_ptr, # [num_tokens] + k_scale, # float32 + v_scale, # float32 + # strides + key_stride: tl.int64, + value_stride: tl.int64, + block_stride: tl.int64, + page_stride: tl.int64, + num_heads: tl.constexpr, + head_size: tl.constexpr, + block_size: tl.constexpr, + # FP8 flags + FP8_KV_CACHE: tl.constexpr, + # tune parameters + TILE_SIZE: tl.constexpr, +): + + token_idx = tl.program_id(axis=0) + slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64) + if slot_idx < 0: + # Padding token that should be ignored. + return + + tile_i = tl.program_id(axis=1) + tile_offs = tl.arange(0, TILE_SIZE) + tile_pos = tile_i * TILE_SIZE + tile_offs + + block_idx = slot_idx // block_size + block_offset = slot_idx % block_size + + src_key_idx = token_idx * key_stride + src_value_idx = token_idx * value_stride + + tgt_idx = block_idx * block_stride + block_offset * page_stride + + # [TILE_SIZE] + key_load = tl.load(key_ptr + src_key_idx + tile_pos, + mask=tile_pos < (num_heads * head_size)) + if FP8_KV_CACHE: + if key_load.dtype.is_fp8(): + key_tile = key_load + else: + # tl.store will do the correct implicit cast to fp8, + # based on the key_cache_ptr.dtype.element_ty + key_tile = key_load / tl.load(k_scale) + else: + key_tile = key_load + + # [TILE_SIZE] + value_load = tl.load(value_ptr + src_value_idx + tile_pos, + mask=tile_pos < (num_heads * head_size)) + if FP8_KV_CACHE: + if value_load.dtype.is_fp8(): + value_tile = value_load + else: + # tl.store will do the correct implicit cast to fp8, + # based on the value_cache_ptr.dtype.element_ty + value_tile = value_load / tl.load(v_scale) + else: + value_tile = value_load + + tl.store( + key_cache_ptr + tgt_idx + tile_pos, + key_tile, + mask=tile_pos < (num_heads * head_size), + ) + tl.store( + value_cache_ptr + tgt_idx + tile_pos, + value_tile, + mask=tile_pos < (num_heads * head_size), + ) + return + + +def triton_reshape_and_cache_flash( + key: torch.Tensor, # [num_tokens, num_heads, head_size] + value: torch.Tensor, # [num_tokens, num_heads, head_size] + # [num_blocks, block_size, num_heads, head_size] + key_cache: torch.Tensor, + # [num_blocks, block_size, num_heads, head_size] + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, # [num_tokens] + kv_cache_dtype: str, # "auto", "fp8" + k_scale: torch.Tensor, # float32 + v_scale: torch.Tensor, # float32 +): + num_tokens = key.shape[0] + num_heads = key.shape[1] + head_size = key.shape[2] + block_size = key_cache.shape[1] + n = num_heads * head_size + + key_stride = key.stride()[0] + value_stride = value.stride()[0] + block_stride = key_cache.stride()[0] + page_stride = key_cache.stride()[1] + + head_stride = key_cache.stride()[2] + assert head_stride == head_size, "only continous heads are supported" + + assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \ + f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}." + kv_cache_torch_dtype = current_platform.fp8_dtype() if \ + kv_cache_dtype.startswith("fp8") else key_cache.dtype + + if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith( + "fp8"): + # to avoid erounous implicit cast in triton kernel (tl.store to uint8) + # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4) + key_cache = key_cache.view(kv_cache_torch_dtype) + value_cache = value_cache.view(kv_cache_torch_dtype) + assert kv_cache_dtype != torch.uint8, "explicit fp8 cast and store to "\ + "uint8 is not supported by triton reshape_and_cache_flash" + + FP8_KV_CACHE = kv_cache_dtype.startswith("fp8") + assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [ + torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8, + torch.float8_e4m3fnuz], \ + "unsupported dtype of KV cache tensor, got "\ + "{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \ + "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz." + + # heuristics instead of autotuning + TILE_SIZE = min(2048, triton.next_power_of_2(n)) + if torch.version.hip: + num_stages = 4 + num_warps = 8 + else: # cuda + num_stages = 10 + num_warps = 16 + if torch.cuda.get_device_capability(key.device)[0] < 9: + TILE_SIZE = min(512, TILE_SIZE) + + # TODO(ngl): maybe replace with static launch grid to avoid overhead if + # using cudagraphs + grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"])) + + reshape_and_cache_kernel_flash[grid]( + key_ptr=key, + value_ptr=value, + key_cache_ptr=key_cache, + value_cache_ptr=value_cache, + slot_mapping_ptr=slot_mapping, + k_scale=k_scale, + v_scale=v_scale, + # strides + key_stride=key_stride, + value_stride=value_stride, + block_stride=block_stride, + page_stride=page_stride, + num_heads=num_heads, + head_size=head_size, + block_size=block_size, + # FP8 flags + FP8_KV_CACHE=FP8_KV_CACHE, + # autotune parameters + TILE_SIZE=TILE_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index e233f959c0a4a..befb7736d75af 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -12,6 +12,8 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import validate_cudagraph_capturing_enabled from vllm.config import CUDAGraphMode, VllmConfig +from vllm.distributed.device_communicators.pynccl_allocator import ( + set_graph_pool_id) from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import init_logger from vllm.platforms import current_platform @@ -154,6 +156,10 @@ class CUDAGraphWrapper: stack.enter_context( patch("torch.cuda.empty_cache", lambda: None)) + if self.graph_pool is not None: + set_graph_pool_id(self.graph_pool) + else: + set_graph_pool_id(current_platform.graph_pool_handle()) # mind-exploding: carefully manage the reference and memory. with torch.cuda.graph(cudagraph, pool=self.graph_pool): # `output` is managed by pytorch's cudagraph pool diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 92fc68f8927ca..d786d3e289b33 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -509,8 +509,15 @@ class VllmConfig: if self.compilation_config.cudagraph_mode is None: if envs.VLLM_USE_V1 and self.compilation_config.level \ == CompilationLevel.PIECEWISE: + # default to full and piecewise for most models self.compilation_config.cudagraph_mode = \ - CUDAGraphMode.PIECEWISE + CUDAGraphMode.FULL_AND_PIECEWISE + + # pooling model does not support full cudagraphs + if self.model_config is not None and \ + self.model_config.pooler_config is not None: + self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.PIECEWISE else: self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE @@ -638,11 +645,13 @@ class VllmConfig: if self.parallel_config.enable_dbo: a2a_backend = envs.VLLM_ALL2ALL_BACKEND - assert a2a_backend == "deepep_low_latency", \ - "Microbatching currently only supports the deepep_low_latency "\ - f"all2all backend. {a2a_backend} is not supported. To fix set "\ - "the VLLM_ALL2ALL_BACKEND environment variable to "\ - "deepep_low_latency and install the DeepEP kerenls." + assert a2a_backend in \ + ["deepep_low_latency", "deepep_high_throughput"], \ + "Microbatching currently only supports the deepep_low_latency and "\ + f"deepep_high_throughput all2all backend. {a2a_backend} is not "\ + "supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\ + "variable to deepep_low_latency or deepep_high_throughput and "\ + "install the DeepEP kernels." if not self.instance_id: self.instance_id = random_uuid()[:5] @@ -685,6 +694,23 @@ class VllmConfig: # local attention. self.scheduler_config.disable_hybrid_kv_cache_manager = True + def has_blocked_weights(): + if self.quant_config is not None: + if hasattr(self.quant_config, "weight_block_size"): + return self.quant_config.weight_block_size is not None + elif hasattr(self.quant_config, "has_blocked_weights"): + return self.quant_config.has_blocked_weights() + return False + + # Enable quant_fp8 CUDA ops (TODO disable in follow up) + # On H100 the CUDA kernel is faster than + # native implementation + # https://github.com/vllm-project/vllm/issues/25094 + if has_blocked_weights(): + custom_ops = self.compilation_config.custom_ops + if "none" not in custom_ops and "-quant_fp8" not in custom_ops: + custom_ops.append("+quant_fp8") + def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: # remove the sizes that not multiple of tp_size when diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 34fa7fcfe7e87..0441745e8b36e 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -228,15 +228,14 @@ class CompilationConfig: The mode of the cudagraph: - NONE, no cudagraph capture. - - PIECEWISE. (v1 default) + - PIECEWISE. - FULL. - FULL_DECODE_ONLY. - - FULL_AND_PIECEWISE. + - FULL_AND_PIECEWISE. (v1 default) PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph incompatible ops (i.e. some attention ops) outside the cudagraph for general flexibility. - This is the default mode. FULL mode: Capture full cudagraph for all batches. Can be good for small models or workloads with small prompts; not supported by many backends. @@ -249,7 +248,7 @@ class CompilationConfig: FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and piecewise cudagraph for prefill and mixed prefill-decode batches. - This is like the most performant mode for most models. + This is the most performant mode for most models and is the default. Currently, the cudagraph mode is only used for the v1 engine. Note that the cudagraph logic is generally orthogonal to the diff --git a/vllm/config/model.py b/vllm/config/model.py index 33e5d3ea04a48..d8a8fe20fd030 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1003,6 +1003,7 @@ class ModelConfig: self.quantization = quantization_override break + quant_method = quant_method if quant_method != "" else None # Verify quantization configurations. if self.quantization is None: self.quantization = quant_method diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index a84d882430166..f80eb1adc7fd3 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -139,12 +139,18 @@ class ParallelConfig: """Disable the custom all-reduce kernel and fall back to NCCL.""" enable_dbo: bool = False - """Enable microbatching for the model executor.""" + """Enable dual batch overlap for the model executor.""" dbo_decode_token_threshold: int = 32 - """The threshold for microbatching. If the number of tokens in the - request is greater than this threshold, microbatching will be used. - Otherwise, the request will be processed in a single batch.""" + """The threshold for dual batch overlap for batches only containing decodes. + If the number of tokens in the request is greater than this threshold, + microbatching will be used. Otherwise, the request will be processed in a + single batch.""" + dbo_prefill_token_threshold: int = 512 # TODO(lucas): tune + """The threshold for dual batch overlap for batches that contain one or more + prefills. If the number of tokens in the request is greater than this + threshold, microbatching will be used. Otherwise, the request will be + processed in a single batch.""" ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 149df73d8667b..ae18429f62518 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any +from typing import Any, Optional import torch import torch.distributed as dist +import vllm.envs as envs from vllm.distributed import get_dp_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger @@ -200,12 +201,12 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): def _make_all2all_kwargs(self) -> dict[Any, Any]: # Defaults for internode and intranode are taken from DeepEP tests. - num_nvl_bytes = 1024 * 1024 * 1024 + num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024 num_rdma_bytes = None num_qps_per_rank = None if self.internode: - num_rdma_bytes = 1024 * 1024 * 1024 + num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024 num_qps_per_rank = self.num_sms // 2 else: num_rdma_bytes = 0 @@ -230,13 +231,18 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): logger.debug("DeepEP all2all args %s", buffer_kwargs) handle: deep_ep.Buffer = self.handle_cache.get_or_create( buffer_kwargs, deep_ep.Buffer) - # It is dangerous to set num sms outside this function. num_sms is not - # a part of the hash-key that identifies this object. If we are in a - # situation where we make objects with different num_sms, the hash key - # in get_or_create must be updated. - handle.set_num_sms(self.num_sms) return handle + def set_num_sms(self, num_sms: int): + import deep_ep + + # Right now the buffers are sized for only what the kernels were + # created with. So we can only reduce the number of SMS used + # but not increase it. + if num_sms > self.num_sms: + num_sms = self.num_sms + deep_ep.Buffer.set_num_sms(num_sms) + class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): """ @@ -265,7 +271,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): import deep_ep # Defaults for internode and intranode are taken from DeepEP tests. - num_nvl_bytes = 1024 * 1024 * 1024 + num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024 num_qps_per_rank = num_local_experts num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank, @@ -291,3 +297,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): handle: deep_ep.Buffer = self.handle_cache.get_or_create( buffer_kwargs, deep_ep.Buffer) return handle + + # DeepEP LL uses RDMA so no SMs are used for communication + def max_sms_used(self) -> Optional[int]: + return 0 \ No newline at end of file diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py index 805a88854b77c..87e0f8e1a9677 100644 --- a/vllm/distributed/device_communicators/all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -10,8 +10,9 @@ import sys import tempfile from collections.abc import Sequence from itertools import product -from typing import Optional +from typing import Any, Optional +import torch import torch.distributed as dist import torch.multiprocessing as mp @@ -56,6 +57,30 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = { } } +NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = { + "min_world_size": 4, + "thresholds": { + 4: 2 * MiB, # 2 MB + 8: 1 * MiB, # 1 MB + }, + "always_use_above_world_size": 8 # Always use symm mem for world_size > 8 +} + + +def should_nccl_symm_mem_allreduce(world_size: int, + input_tensor: torch.Tensor) -> bool: + from vllm.distributed.device_communicators.pynccl_allocator import ( + is_symmetric_memory_enabled) + if not is_symmetric_memory_enabled(): + return False + if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]: + return False + threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size) + if threshold is not None and input_tensor.nbytes >= threshold: + return True + return (world_size + > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"]) + def producer(batch_src: Sequence[int], producer_queue, diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 01f59b44a0e69..586441c917830 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -60,6 +60,12 @@ class All2AllManagerBase: # and reuse it for the same config. raise NotImplementedError + def set_num_sms(self, num_sms: int): + pass + + def max_sms_used(self) -> Optional[int]: + return None # None means it could use the whole GPU + def dispatch(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): raise NotImplementedError diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index b2bf3bc3cc2ed..b20e79f577c35 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -7,6 +7,12 @@ import torch from torch.distributed import ProcessGroup import vllm.envs as envs +from vllm.distributed.device_communicators.all_reduce_utils import ( + should_nccl_symm_mem_allreduce) +from vllm.distributed.device_communicators.pynccl import ( + register_nccl_symmetric_ops) +from vllm.distributed.device_communicators.pynccl_allocator import ( + is_symmetric_memory_enabled) from vllm.logger import init_logger from vllm.platforms import current_platform @@ -24,18 +30,21 @@ class CudaCommunicator(DeviceCommunicatorBase): unique_name: str = ""): super().__init__(cpu_group, device, device_group, unique_name) if "tp" not in unique_name: - # only tp uses custom allreduce + # custom allreduce or torch symm mem can be used only by tp use_custom_allreduce = False + use_torch_symm_mem = False else: from vllm.distributed.parallel_state import ( _ENABLE_CUSTOM_ALL_REDUCE) use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE + use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM # ep does not use pynccl use_pynccl = "ep" not in unique_name self.use_pynccl = use_pynccl self.use_custom_allreduce = use_custom_allreduce + self.use_torch_symm_mem = use_torch_symm_mem # lazy import to avoid documentation build error from vllm.distributed.device_communicators.custom_all_reduce import ( @@ -53,11 +62,13 @@ class CudaCommunicator(DeviceCommunicatorBase): group=self.cpu_group, device=self.device, ) + if is_symmetric_memory_enabled(): + register_nccl_symmetric_ops(self.pynccl_comm) self.ca_comm: Optional[CustomAllreduce] = None self.qr_comm: Optional[QuickAllReduce] = None self.symm_mem_comm: Optional[SymmMemCommunicator] = None - if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): + if use_torch_symm_mem and current_platform.is_cuda(): self.symm_mem_comm = SymmMemCommunicator( group=self.cpu_group, device=self.device, @@ -107,6 +118,13 @@ class CudaCommunicator(DeviceCommunicatorBase): raise ValueError(f"Unknown all2all backend: {all2all_backend}") def all_reduce(self, input_): + # since currently we perform copy input -> symm_input -> out-of-place AR + # return symm_output, we don't need to check if input is symmetric + if self.pynccl_comm is not None and \ + should_nccl_symm_mem_allreduce(self.pynccl_comm.world_size,input_): + out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_) + if out is not None: + return out # always try quick reduce first, then custom allreduce, # and then pynccl. (quick reduce just for ROCM MI3*) qr_comm = self.qr_comm diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 3e4d0d250af94..75de85e1b0aba 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -17,6 +17,39 @@ from vllm.utils import current_stream logger = init_logger(__name__) +_NCCL_SYMM_OPS_REGISTERED = False + + +def register_nccl_symmetric_ops(pynccl_comm): + from vllm.distributed.device_communicators.pynccl_allocator import ( + nccl_symm_mem_context) + from vllm.utils import direct_register_custom_op + + global _NCCL_SYMM_OPS_REGISTERED + if _NCCL_SYMM_OPS_REGISTERED: + return + _NCCL_SYMM_OPS_REGISTERED = True + + def all_reduce_symmetric_with_copy_impl( + input_tensor: torch.Tensor) -> torch.Tensor: + with nccl_symm_mem_context(pynccl_comm): + symm_input = torch.empty_like(input_tensor) + symm_output = torch.empty_like(input_tensor) + symm_input.copy_(input_tensor) + symm_output = pynccl_comm.all_reduce(symm_input, symm_output) + return symm_output + + def all_reduce_symmetric_with_copy_fake( + input_tensor: torch.Tensor) -> torch.Tensor: + return torch.empty_like(input_tensor) + + direct_register_custom_op( + op_name="all_reduce_symmetric_with_copy", + op_func=all_reduce_symmetric_with_copy_impl, + mutates_args=[], + fake_impl=all_reduce_symmetric_with_copy_fake, + ) + class PyNcclCommunicator: @@ -67,6 +100,7 @@ class PyNcclCommunicator: self.available = True self.disabled = False + self.nccl_version = self.nccl.ncclGetRawVersion() logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) if self.rank == 0: @@ -109,6 +143,7 @@ class PyNcclCommunicator: def all_reduce(self, in_tensor: torch.Tensor, + out_tensor: torch.Tensor = None, op: ReduceOp = ReduceOp.SUM, stream=None) -> torch.Tensor: if self.disabled: @@ -120,7 +155,8 @@ class PyNcclCommunicator: f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {in_tensor.device}") - out_tensor = torch.empty_like(in_tensor) + if out_tensor is None: + out_tensor = torch.empty_like(in_tensor) if stream is None: stream = current_stream() @@ -288,3 +324,18 @@ class PyNcclCommunicator: def group_end(self): self.nccl.ncclGroupEnd() + + def register_comm_window(self, tensor: torch.Tensor): + return self.nccl.ncclCommWindowRegister( + self.comm, + buffer_type(tensor.data_ptr()), + tensor.numel() * tensor.element_size(), + 1, + ) + + def register_comm_window_raw(self, ptr: int, size: int): + return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), + size, 1) + + def deregister_comm_window(self, window): + return self.nccl.ncclCommWindowDeregister(self.comm, window) diff --git a/vllm/distributed/device_communicators/pynccl_allocator.py b/vllm/distributed/device_communicators/pynccl_allocator.py new file mode 100644 index 0000000000000..bc874c1e197e7 --- /dev/null +++ b/vllm/distributed/device_communicators/pynccl_allocator.py @@ -0,0 +1,186 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import atexit +import contextlib +import tempfile +from typing import Any, Optional + +import torch +from packaging import version +from torch.cuda.memory import CUDAPluggableAllocator +from torch.utils.cpp_extension import load_inline + +from vllm import envs +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import find_nccl_include_paths + +logger = init_logger(__name__) + +nccl_allocator_source = """ +#include +extern "C" { + +void* nccl_alloc_plug(size_t size, int device, void* stream) { + void* ptr; + ncclResult_t err = ncclMemAlloc(&ptr, size); + return ptr; + +} + +void nccl_free_plug(void* ptr, size_t size, int device, void* stream) { + ncclResult_t err = ncclMemFree(ptr); +} + +} +""" + +_allocator = None +_allocator_wrapper = None +_mem_pool = None +_registered_base_addrs = set() +_graph_pool_id = None +_nccl_allocator_failed_to_compile = False +_cached_pool_snapshot = None + + +def is_symmetric_memory_enabled(): + global _nccl_allocator_failed_to_compile + return envs.VLLM_USE_NCCL_SYMM_MEM and not _nccl_allocator_failed_to_compile + + +def is_symmetric_memory_tensor(tensor: torch.Tensor): + if not is_symmetric_memory_enabled() or _cached_pool_snapshot is None: + return False + for segment in _cached_pool_snapshot: + for block in segment["blocks"]: + if block["address"] == tensor.untyped_storage().data_ptr(): + return True + return False + + +def set_graph_pool_id(graph_pool_id): + global _graph_pool_id + _graph_pool_id = graph_pool_id + + +def compile_nccl_allocator(): + global _allocator, _allocator_wrapper, _nccl_allocator_failed_to_compile + if not current_platform.is_cuda(): + _nccl_allocator_failed_to_compile = True + return + try: + out_dir = tempfile.gettempdir() + nccl_allocator_libname = "nccl_allocator" + nccl_include_paths = find_nccl_include_paths() + load_inline( + name=nccl_allocator_libname, + cpp_sources=nccl_allocator_source, + with_cuda=True, + extra_ldflags=["-lnccl"], + verbose=envs.VLLM_LOGGING_LEVEL == "DEBUG", + is_python_module=False, + build_directory=out_dir, + extra_include_paths=nccl_include_paths, + ) + _allocator_wrapper = CUDAPluggableAllocator( + f"{out_dir}/{nccl_allocator_libname}.so", + "nccl_alloc_plug", + "nccl_free_plug", + ) + _allocator = _allocator_wrapper.allocator() + except Exception as e: + _nccl_allocator_failed_to_compile = True + logger.warning( + "Failed to compile NCCL memory allocator. " + "Symmetric memory will be disabled. " + "This is expected if NCCL headers are not available. " + "optionally set VLLM_NCCL_INCLUDE_PATH to point to a directory " + "containing the NCCL header. " + "Error: %s", str(e)) + + +def get_nccl_mem_pool(): + global _mem_pool, _nccl_allocator_failed_to_compile + if _mem_pool is None and not _nccl_allocator_failed_to_compile: + compile_nccl_allocator() + if _allocator is not None: + _mem_pool = torch.cuda.MemPool(_allocator) + return _mem_pool + + +def _cleanup_nccl_mem_pool(): + global _mem_pool + _mem_pool = None + + +def _cleanup_nccl_allocator_wrapper(): + global _allocator_wrapper + _allocator_wrapper = None + + +atexit.register(_cleanup_nccl_mem_pool) +atexit.register(_cleanup_nccl_allocator_wrapper) + + +class nccl_symm_mem_context: + + def __init__( + self, + pynccl_comm: PyNcclCommunicator, + disabled: bool = False, + ): + self.disabled = (disabled or not is_symmetric_memory_enabled() + or pynccl_comm.world_size == 1 + or not current_platform.is_cuda() + or get_nccl_mem_pool() is None or version.parse( + torch.__version__) < version.parse("2.8.0.a0")) + if self.disabled: + self.pynccl_comm: Optional[PyNcclCommunicator] = None + self._mem_pool_ctx: contextlib.AbstractContextManager[ + Any] = contextlib.nullcontext() + self.is_graph_capture = None + self.device = None + else: + self.pynccl_comm = pynccl_comm + self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool()) + self.is_graph_capture = torch.cuda.is_current_stream_capturing() + self.device = torch.cuda.current_device() + + def __enter__(self): + if self.disabled: + return self + assert ( + self.pynccl_comm + is not None), "Symmetric memory requires pynccl to be initalized" + assert ( + self.pynccl_comm.nccl_version >= 22703 + ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory" + if self.is_graph_capture: + assert ( + _graph_pool_id + is not None), "graph_pool_id is not set under graph capture" + # Pause graph memory pool to use symmetric memory with cuda graph + torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id) + self._mem_pool_ctx.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.disabled: + return + global _cached_pool_snapshot + global _registered_base_addrs + self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb) + _pool = get_nccl_mem_pool() + assert _pool is not None + _cached_pool_snapshot = _pool.snapshot() + assert self.pynccl_comm is not None + for segment in _cached_pool_snapshot: + if segment["address"] not in _registered_base_addrs: + self.pynccl_comm.register_comm_window_raw( + segment["address"], segment["total_size"]) + _registered_base_addrs.add(segment["address"]) + if self.is_graph_capture: + torch._C._cuda_beginAllocateCurrentThreadToPool( + self.device, _graph_pool_id) diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index a930b63bc26ff..c3e99e177e2d5 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -41,6 +41,7 @@ logger = init_logger(__name__) ncclResult_t = ctypes.c_int ncclComm_t = ctypes.c_void_p +ncclWindow_t = ctypes.c_void_p class ncclUniqueId(ctypes.Structure): @@ -222,6 +223,24 @@ class NCCLLibrary: Function("ncclGroupStart", ncclResult_t, []), # ncclResult_t ncclGroupEnd(); Function("ncclGroupEnd", ncclResult_t, []), + # ncclResult_t ncclCommWindowRegister( + # ncclComm_t comm, void* buff, size_t size, + # ncclWindow_t* win, int winFlags); + Function( + "ncclCommWindowRegister", + ncclResult_t, + [ + ncclComm_t, + buffer_type, + ctypes.c_size_t, + ctypes.POINTER(ncclWindow_t), + ctypes.c_int, + ], + ), + # ncclResult_t ncclCommWindowDeregister( + # ncclComm_t comm, ncclWindow_t win); + Function("ncclCommWindowDeregister", ncclResult_t, + [ncclComm_t, ncclWindow_t]), ] # class attribute to store the mapping from the path to the library @@ -271,10 +290,14 @@ class NCCLLibrary: error_str = self.ncclGetErrorString(result) raise RuntimeError(f"NCCL error: {error_str}") - def ncclGetVersion(self) -> str: + def ncclGetRawVersion(self) -> int: version = ctypes.c_int() self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) - version_str = str(version.value) + # something like 21903 + return version.value + + def ncclGetVersion(self) -> str: + version_str = str(self.ncclGetRawVersion()) # something like 21903 --> "2.19.3" major = version_str[0].lstrip("0") minor = version_str[1:3].lstrip("0") @@ -375,6 +398,17 @@ class NCCLLibrary: def ncclGroupEnd(self) -> None: self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) + def ncclCommWindowRegister(self, comm: ncclComm_t, buff: buffer_type, + size: int, win_flags: int) -> ncclWindow_t: + window = ncclWindow_t() + self.NCCL_CHECK(self._funcs["ncclCommWindowRegister"]( + comm, buff, size, ctypes.byref(window), win_flags)) + return window + + def ncclCommWindowDeregister(self, comm: ncclComm_t, + window: ncclWindow_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window)) + __all__ = [ "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8c7a1b413cdb7..556a490ffa109 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -330,6 +330,8 @@ class EngineArgs: enable_dbo: bool = ParallelConfig.enable_dbo dbo_decode_token_threshold: int = \ ParallelConfig.dbo_decode_token_threshold + dbo_prefill_token_threshold: int = \ + ParallelConfig.dbo_prefill_token_threshold eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb expert_placement_strategy: ExpertPlacementStrategy = \ @@ -698,6 +700,9 @@ class EngineArgs: parallel_group.add_argument( "--dbo-decode-token-threshold", **parallel_kwargs["dbo_decode_token_threshold"]) + parallel_group.add_argument( + "--dbo-prefill-token-threshold", + **parallel_kwargs["dbo_prefill_token_threshold"]) parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) parallel_group.add_argument("--eplb-config", @@ -1316,6 +1321,7 @@ class EngineArgs: enable_expert_parallel=self.enable_expert_parallel, enable_dbo=self.enable_dbo, dbo_decode_token_threshold=self.dbo_decode_token_threshold, + dbo_prefill_token_threshold=self.dbo_prefill_token_threshold, enable_eplb=self.enable_eplb, eplb_config=self.eplb_config, expert_placement_strategy=self.expert_placement_strategy, diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py index 57e4bb1e1da52..0c1c9c3192fc0 100644 --- a/vllm/entrypoints/harmony_utils.py +++ b/vllm/entrypoints/harmony_utils.py @@ -317,7 +317,8 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: ) output_items.append(response_item) elif recipient is not None and (recipient.startswith("python") - or recipient.startswith("browser")): + or recipient.startswith("browser") + or recipient.startswith("container")): for content in message.content: reasoning_item = ResponseReasoningItem( id=f"rs_{random_uuid()}", diff --git a/vllm/envs.py b/vllm/envs.py index ee5efff8bcd92..50d58c5468f97 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -182,15 +182,19 @@ if TYPE_CHECKING: VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False - VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False + VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER" + VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024 + VLLM_DBO_COMM_SMS: int = 20 GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = [] VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None + VLLM_USE_NCCL_SYMM_MEM: bool = False + VLLM_NCCL_INCLUDE_PATH: Optional[str] = None def get_default_cache_root(): @@ -1366,7 +1370,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # Whether to use pytorch symmetric memory for allreduce "VLLM_ALLREDUCE_USE_SYMM_MEM": - lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))), + lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))), # Allows vllm to find tuned config under customized folder "VLLM_TUNED_CONFIG_FOLDER": @@ -1392,6 +1396,15 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", "VLLM_OBJECT_STORAGE_SHM_BUFFER"), + # The size in MB of the buffers (NVL and RDMA) used by DeepEP + "VLLM_DEEPEP_BUFFER_SIZE_MB": + lambda: int(os.getenv("VLLM_DEEPEP_BUFFER_SIZE_MB", "1024")), + + # The number of SMs to allocate for communication kernels when running DBO + # the rest of the SMs on the device will be allocated to compute + "VLLM_DBO_COMM_SMS": + lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")), + # Valid values are container,code_interpreter,web_search_preview # ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter "GPT_OSS_SYSTEM_TOOL_MCP_LABELS": @@ -1399,6 +1412,15 @@ environment_variables: dict[str, Callable[[], Any]] = { ["container", "code_interpreter", "web_search_preview"]), + + # Flag to enable NCCL symmetric memory allocation and registration + "VLLM_USE_NCCL_SYMM_MEM": + lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))), + + # NCCL header path + "VLLM_NCCL_INCLUDE_PATH": + lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None), + } # --8<-- [end:env-vars-definition] diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index 6cf5815ef12da..ed294b0aedaf4 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -24,11 +24,12 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): super().__init__() self.base_layer = base_layer self.input_size = self.base_layer.input_size + # Ensure tp_size and tp_rank consistency with the base_layer. + self.tp_size = self.base_layer.tp_size + self.tp_rank = self.base_layer.tp_rank self.device = _get_lora_device(self.base_layer) self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None - self.output_slices: tuple[int, ...] - self.tp_size: int self.output_size: int self.n_slices: int diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index fa4eb272a69fe..6284576446c8f 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -8,9 +8,7 @@ import torch.nn as nn from transformers import PretrainedConfig from vllm.config.lora import LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.distributed import tensor_model_parallel_all_gather from vllm.distributed.utils import divide from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -85,7 +83,6 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): # inconsistent when TP is greater than 1. self.is_merged_col_linear = type( base_layer) is MergedColumnParallelLinear - self.tp_size = get_tensor_model_parallel_world_size() self.output_size = self.base_layer.output_size_per_partition # There is only one LoRA layer self.n_slices = 1 @@ -97,22 +94,20 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): # Applicable to cases where the base_layer is # MergedColumnParallelLinear. if self.is_merged_col_linear: - tp_rank = get_tensor_model_parallel_rank() shard_size = self.output_size // 2 offset = lora_b.shape[0] // 2 - left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) * + left_weight = lora_b[self.tp_rank * shard_size:(self.tp_rank + 1) * shard_size, :] - right_weight = lora_b[offset + tp_rank * shard_size:offset + - (tp_rank + 1) * shard_size, :] + right_weight = lora_b[offset + self.tp_rank * shard_size:offset + + (self.tp_rank + 1) * shard_size, :] lora_b = torch.cat([left_weight, right_weight], dim=0) # Applicable to cases where the base_layer is # ColumnParallelLinear. else: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() shard_size = self.output_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size lora_b = lora_b[start_idx:end_idx, :] return lora_b @@ -120,10 +115,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): # TODO: Fix the slicing logic of bias. if bias is None: return bias - tensor_model_parallel_rank = get_tensor_model_parallel_rank() shard_size = self.output_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size bias = bias[start_idx:end_idx] return bias @@ -144,7 +138,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): # Matrix multiply. output_parallel = self.apply(input_, bias) - if self.base_layer.gather_output: + if self.base_layer.gather_output and self.tp_size > 1: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) else: @@ -185,8 +179,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): QKVParallelLinear]) -> None: super().__init__(base_layer) # There are two LoRA layers - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() # the output_sizes in MergedColumnParallelLinear is not sharded by tp # we need to divide it by the tp_size to get correct slices size output_sizes = self.base_layer.output_sizes @@ -341,9 +333,9 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): self.n_slices = 1 def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() - self.q_shard_id = tp_rank - self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas lora_b_q = lora_b[self.q_proj_shard_size * self.q_shard_id:self.q_proj_shard_size * (self.q_shard_id + 1), :] @@ -397,8 +389,6 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): super().__init__(base_layer) # There are three LoRA layer. self.n_slices = len(self.base_layer.output_sizes) - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() self.q_proj_shard_size = (self.base_layer.num_heads * self.base_layer.head_size) @@ -461,9 +451,8 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): # Therefore, the sharding of `lora_a` only needs to correspond with the # gather operation. def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() shard_size = self.lora_a_stacked[0].shape[2] - start_idx = tp_rank * shard_size + start_idx = self.tp_rank * shard_size lora_a = lora_a[start_idx:start_idx + shard_size, :] return lora_a @@ -547,9 +536,8 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): """ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() shard_size = self.lora_a_stacked[0].shape[2] - start_idx = tp_rank * shard_size + start_idx = self.tp_rank * shard_size lora_a = lora_a[start_idx:start_idx + shard_size, :] return lora_a diff --git a/vllm/lora/layers/replicated_linear.py b/vllm/lora/layers/replicated_linear.py index 3356297c1537a..18a8f13ed9427 100644 --- a/vllm/lora/layers/replicated_linear.py +++ b/vllm/lora/layers/replicated_linear.py @@ -18,7 +18,6 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): def __init__(self, base_layer: ReplicatedLinear) -> None: super().__init__(base_layer, ) # To ensure interface compatibility, set to 1 always. - self.tp_size = 1 self.output_size = self.base_layer.output_size self.n_slices = 1 diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py index cac2c92136dca..d468655e629ae 100644 --- a/vllm/lora/layers/row_parallel_linear.py +++ b/vllm/lora/layers/row_parallel_linear.py @@ -8,9 +8,7 @@ import torch.nn as nn from transformers import PretrainedConfig from vllm.config.lora import LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, +from vllm.distributed import (split_tensor_along_last_dim, tensor_model_parallel_all_reduce) # yapf: disable from vllm.model_executor.layers.linear import RowParallelLinear @@ -25,12 +23,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): def __init__(self, base_layer: RowParallelLinear) -> None: super().__init__(base_layer) - self.tp_size = get_tensor_model_parallel_world_size() # reset input_size self.input_size = self.base_layer.input_size_per_partition self.output_size = self.base_layer.output_size - - self.tp_rank = get_tensor_model_parallel_rank() # There is only one LoRA layer. self.n_slices = 1 @@ -68,12 +63,12 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): else: # TODO: simplify code below splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.base_layer.tp_size) + input_, num_partitions=self.tp_size) input_parallel = splitted_input[self.tp_rank].contiguous() # Matrix multiply. output_parallel = self.apply(input_parallel) - if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + if self.base_layer.reduce_results and self.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) else: output_ = output_parallel @@ -154,8 +149,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): buffer, x, self.lora_a_stacked, 1.0) if not current_platform.can_update_inplace(): buffer = shrunk_buffer - - buffer = tensor_model_parallel_all_reduce(buffer) + if self.tp_size>1: + buffer = tensor_model_parallel_all_reduce(buffer) # following S-LoRA, allows the fusing of all_gather and all_reduce # by adding the column partitioned lora output to a slice of output diff --git a/vllm/lora/lora_weights.py b/vllm/lora/lora_weights.py index e3198fb3d3ae4..90e18217d28be 100644 --- a/vllm/lora/lora_weights.py +++ b/vllm/lora/lora_weights.py @@ -48,11 +48,11 @@ class LoRALayerWeights: @property def input_dim(self) -> int: - return self.lora_a.shape[0] + return self.lora_a.shape[1] @property def output_dim(self) -> int: - return self.lora_b.shape[1] + return self.lora_b.shape[0] @property def is_packed(self) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index a250a6218715e..9e9a9afc18a03 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -12,6 +12,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) from vllm.utils import round_up +from vllm.v1.worker.ubatching import ( + dbo_current_ubatch_id, dbo_enabled, dbo_switch_to_comm, + dbo_switch_to_compute, dbo_switch_to_compute_sync, + dbo_yield_and_switch_from_comm_to_compute, + dbo_yield_and_switch_from_compute_to_comm) class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): @@ -46,9 +51,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): self.async_prepare = True # The dispatch function returns a handle that the combine function - # requires. We store the handle here so it is available to the - # combine function. - self.handle = None + # requires. Under DBO microbatching we must track one handle per + # micro-batch to avoid races between threads. + self.handles = [None, None] # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164 self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160] @@ -89,6 +94,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): has_scales = token_scales is not None + # We yield before launching the dispatch kernel since the dispatch + # kernel will block the CPU so we want to queue up all the compute + # for the other ubatch before the dispatch kernel starts. + dbo_yield_and_switch_from_compute_to_comm() + (num_tokens_per_rank, num_tokens_per_rdma_rank, dispatch_expert_num_tokens, is_token_in_rank, event) = self.buffer.get_dispatch_layout( @@ -104,7 +114,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ( token_data, expert_topk_ids, expert_topk_weights, - expert_num_tokens_per_expert_list, self.handle, event + expert_num_tokens_per_expert_list, handle, event ) = self.buffer.dispatch( x=token_data, handle=None, @@ -119,9 +129,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_alignment=1, config=self._get_dispatch_config(), previous_event=None, - async_finish=self.async_prepare, + async_finish=self.async_prepare and not dbo_enabled(), allocate_on_comm_stream=False) + # record the handle for this ubatch + a2a_idx = dbo_current_ubatch_id() + self.handles[a2a_idx] = handle + + dbo_switch_to_compute_sync() + return lambda: self._receiver( event, has_scales, @@ -146,7 +162,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): a1_scale: Optional[torch.Tensor], quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - if self.async_prepare: + if event.event is not None: event.current_stream_wait() if has_scales: @@ -207,7 +223,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[Callable, mk.ReceiverType]: + ) -> mk.ReceiverType: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -233,14 +249,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): a1q_scale = None a1_post_scale = quant_config.a1_scale - return (lambda *args: None, - self._do_dispatch(tokens=a1q, - token_scales=a1q_scale, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts, - a1_scale=a1_post_scale, - quant_config=quant_config)) + return self._do_dispatch(tokens=a1q, + token_scales=a1q_scale, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, + num_experts=num_experts, + a1_scale=a1_post_scale, + quant_config=quant_config) def prepare( self, @@ -252,10 +267,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - (_, receiver) = self.prepare_async(a1, topk_weights, topk_ids, - num_experts, expert_map, - apply_router_weight_on_input, - quant_config) + receiver = self.prepare_async(a1, topk_weights, topk_ids, num_experts, + expert_map, apply_router_weight_on_input, + quant_config) return receiver() def _finalize( @@ -269,7 +283,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): do_async: bool, ) -> Optional[Callable]: - assert self.handle is not None + a2a_idx = dbo_current_ubatch_id() + handle = self.handles[a2a_idx] + assert handle is not None # fused_expert_output can have 0 tokens - This happens when none of the # tokens from the all2all reach this EP rank. @@ -283,25 +299,35 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk_ids=topk_ids, apply_router_weight_on_input=apply_router_weight_on_input, ) - + dbo_yield_and_switch_from_compute_to_comm() combined_x, _, event = self.buffer.combine( x=fused_expert_output, - handle=self.handle, + handle=handle, topk_weights=None, config=self._get_combine_config(), previous_event=None, - async_finish=do_async, + async_finish=do_async and not dbo_enabled(), allocate_on_comm_stream=False) + dbo_switch_to_compute() + if do_async: def _receiver(): - event.current_stream_wait() + if event.event is not None: + event.current_stream_wait() + dbo_switch_to_comm() # Respect inplace outputs. output.copy_(combined_x, non_blocking=True) - return lambda: _receiver() + # TODO(lucas): refactor the modular kernel so this will be + # handled there + dbo_yield_and_switch_from_comm_to_compute() + + return _receiver else: + # TODO(lucas): support this case with the refactored modular kernel + assert not dbo_enabled() # Respect inplace outputs. output.copy_(combined_x, non_blocking=True) return None diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 101fc8798c427..a9554291db69c 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -206,7 +206,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, do_async: bool, - ) -> Optional[Callable]: + ) -> tuple[Callable, Callable]: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") @@ -233,7 +233,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return_recv_hook=do_recv_hook, out=output) - return recv_hook + return recv_hook, lambda: None def finalize_async( self, @@ -243,8 +243,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> Callable: - recv_hook = self._finalize( + ) -> tuple[Callable, Callable]: + return self._finalize( output, fused_expert_output, topk_weights, @@ -253,8 +253,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): weight_and_reduce_impl, do_async=True, ) - assert recv_hook is not None - return recv_hook def finalize( self, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 5fce24018e647..4ba14196682a5 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable _resize_cache, count_expert_num_tokens) from vllm.utils import cdiv -from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook, +from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled, + dbo_maybe_run_recv_hook, dbo_register_recv_hook, dbo_yield) # @@ -223,7 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC): expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[Callable, ReceiverType]: + ) -> Union[tuple[Callable, ReceiverType], ReceiverType]: """ Perform any quantization (and/or) dispatching needed for this kernel but do not wait for results from other workers. @@ -239,10 +240,21 @@ class FusedMoEPrepareAndFinalize(ABC): - apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching. - Returns a callback that when invoked waits for results from other - workers and has the same return signature as `prepare`, e.g. + Returns a callback or a hook callback pair that when invoked waits for + results from other workers and has the same return signature as + `prepare`, if a hook is returned this is more lightweight check that + the recv is complete without doing extra work (used by DBO, will be + refactored in the very near future) + + e.g. - receiver = obj.prepare_async(...) + ret = obj.prepare_async(...) + + if isinstance(ret, tuple): + hook, receiver = ret + hook() + + if hook is not None: a, a_scales, expert_meta, topk_ids, topk_weights = receiver() is equivalent to: @@ -284,7 +296,7 @@ class FusedMoEPrepareAndFinalize(ABC): topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: TopKWeightAndReduce, - ) -> Callable: + ) -> Union[tuple[Callable, Callable], Callable]: """ Perform any combine plus apply weights and perform a reduction on the fused experts output but do not wait for results from other workers. @@ -298,11 +310,17 @@ class FusedMoEPrepareAndFinalize(ABC): - weight_and_reduce_impl: An optional TopKWeightAndReduce implementation. - Returns a callback that when invoked waits for results from other - workers and has the same return signature as `finalize`, e.g. + Returns a callback or a hook callback pair that when invoked waits for + results from other workers and has the same return signature as + `finalize`, if a hook is returned this is more lightweight check that + the recv is complete without doing extra work (used by DBO, will be + refactored in the very near future) - receiver = obj.finalize_async(output, ...) + ret = obj.finalize_async(output, ...) ... output not valid yet ... + if isinstance(ret, tuple): + hook, receiver = ret + hook() receiver() ... output valid here ... @@ -600,9 +618,23 @@ class FusedMoEModularKernel(torch.nn.Module): layer due to any layer specific state that may be used by the component objects. """ - fused_out_buffer = SharedResizableBuffer() - workspace13_buffer = SharedResizableBuffer() - workspace2_buffer = SharedResizableBuffer() + + class SharedBuffers: + + def __init__(self) -> None: + self.fused_out = SharedResizableBuffer() + self.workspace13 = SharedResizableBuffer() + self.workspace2 = SharedResizableBuffer() + + # Persistent buffers that are shared across `FusedMoEModularKernel` + # instances (layers), to save memory and allocattions. + # + # We have two sets of buffers to support dual batch overlap (DBO) where each + # microbatch (ubatch) should use its own set of buffers to avoid + # cross-ubatch contimination. + # NOTE that memory is lazily allocated for these buffers, meaning that if + # DBO isn't being used, the second SharedBuffers will be empty. + shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()] def __init__( self, @@ -647,14 +679,18 @@ class FusedMoEModularKernel(torch.nn.Module): a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, expert_tokens_meta) + # select per-ubatch buffers to avoid cross-ubatch reuse under DBO + ubatch_idx = dbo_current_ubatch_id() + buffers = self.shared_buffers[ubatch_idx] + # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. - workspace13 = self.workspace13_buffer.get(workspace13_shape, - device=a1.device, - dtype=workspace_dtype) - workspace2 = self.workspace2_buffer.get(workspace2_shape, - device=a1.device, - dtype=workspace_dtype) + workspace13 = buffers.workspace13.get(workspace13_shape, + device=a1.device, + dtype=workspace_dtype) + workspace2 = buffers.workspace2.get(workspace2_shape, + device=a1.device, + dtype=workspace_dtype) assert fused_out is None or fused_out.shape == fused_out_shape, ( f"fused_out {fused_out.shape} but expected {fused_out_shape}") @@ -733,9 +769,11 @@ class FusedMoEModularKernel(torch.nn.Module): (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, expert_tokens_meta) - fused_out = self.fused_out_buffer.get(fused_out_shape, - device=a1q.device, - dtype=a1.dtype) + ubatch_idx = dbo_current_ubatch_id() + buffers = self.shared_buffers[ubatch_idx] + fused_out = buffers.fused_out.get(fused_out_shape, + device=a1q.device, + dtype=a1.dtype) def slice_input_tensors( chunk_idx: int @@ -868,6 +906,7 @@ class FusedMoEModularKernel(torch.nn.Module): if not self.prepare_finalize.supports_async(): # We shouldn't be running an a2a kernel that doesn't # support async prepare/finalize + # TODO(lucas): enable in follow-up assert not dbo_enabled() (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, @@ -883,7 +922,7 @@ class FusedMoEModularKernel(torch.nn.Module): else: # Overlap shared expert compute with all2all dispatch. dbo_maybe_run_recv_hook() - hook, receiver = self.prepare_finalize.prepare_async( + prepare_ret = self.prepare_finalize.prepare_async( a1, topk_weights, topk_ids, @@ -893,13 +932,21 @@ class FusedMoEModularKernel(torch.nn.Module): self.fused_experts.quant_config, ) - # If DBO is being used, register the hook with the ubatch context - # and call it in dbo_maybe_run_recv_hook instead of passing it to - # the receiver. - dbo_register_recv_hook(hook) - dbo_yield() - if not dbo_enabled(): - hook() + # TODO(lucas): refactor this in the alternative schedules followup + # currently unpack if we have hook + receiver pair or just + # receiver (see finalize_async docstring) + hook, receiver = prepare_ret \ + if isinstance(prepare_ret, tuple) else (None, prepare_ret) + + if hook is not None: + if dbo_enabled(): + # If DBO is being used, register the hook with the ubatch + # context and call it in dbo_maybe_run_recv_hook instead of + # passing it to the receiver. + dbo_register_recv_hook(hook) + dbo_yield() + else: + hook() (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, _expert_topk_weights) = receiver() @@ -952,7 +999,7 @@ class FusedMoEModularKernel(torch.nn.Module): if self.shared_experts is not None: shared_output = self.shared_experts(a1) else: - recv_hook = self.prepare_finalize.finalize_async( + finalize_ret = self.prepare_finalize.finalize_async( output, fused_out, topk_weights, @@ -964,11 +1011,23 @@ class FusedMoEModularKernel(torch.nn.Module): if self.shared_experts is not None: shared_output = self.shared_experts(a1) - assert recv_hook is not None - dbo_register_recv_hook(recv_hook) - dbo_yield() - if not dbo_enabled(): - recv_hook() + # TODO(lucas): refactor this in the alternative schedules followup + # currently unpack if we have hook + receiver pair or just + # receiver (see finalize_async docstring) + hook, receiver = finalize_ret \ + if isinstance(finalize_ret, tuple) else (None, finalize_ret) + + if hook is not None: + if dbo_enabled(): + # If DBO is being used, register the hook with the ubatch + # context and call it in dbo_maybe_run_recv_hook instead of + # passing it to the receiver. + dbo_register_recv_hook(hook) + dbo_yield() + else: + hook() + + receiver() if self.shared_experts is None: return output diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index d6550dd16892f..3f771ea2abd1a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -644,6 +644,14 @@ class CompressedTensorsConfig(QuantizationConfig): # If no matches, return None return None + def has_blocked_weights(self) -> bool: + for scheme in self.target_scheme_map.values(): + weight_quant = scheme.get("weights") + if (weight_quant is not None + and weight_quant.strategy == QuantizationStrategy.BLOCK): + return True + return False + @staticmethod def supports_cutlass_24( weight_quant: Optional[QuantizationArgs], diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 85adae32f4cdc..10f9085be4d12 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -13,6 +13,7 @@ from compressed_tensors.quantization import (ActivationOrdering, import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, @@ -31,6 +32,9 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, select_nvfp4_gemm_impl) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + expert_weight_is_col_major, get_col_major_tma_aligned_tensor, + requant_weight_ue8m0_inplace) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_moe_marlin_supports_layer, marlin_make_workspace_new, marlin_moe_permute_scales) @@ -45,6 +49,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used logger = init_logger(__name__) @@ -505,10 +510,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): self.weight_quant.strategy == QuantizationStrategy.CHANNEL and self.input_quant.strategy == QuantizationStrategy.TOKEN) if not (per_tensor or per_channel): - raise ValueError( - "For FP8 Fused MoE layers, we require per tensor " - "or channelwise, dynamic per token quantization. Found " - f"{self.weight_quant}, {self.input_quant}") + assert self.weight_quant.strategy == QuantizationStrategy.BLOCK + self.weight_block_size = self.weight_quant.block_structure + assert self.weight_quant.dynamic is not None + else: + self.weight_block_size = None + self.block_quant = self.weight_block_size is not None self.static_input_scales = not self.input_quant.dynamic if self.static_input_scales and per_channel: @@ -519,7 +526,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization self.use_marlin = (not current_platform.has_device_capability(89) - or envs.VLLM_TEST_FORCE_FP8_MARLIN) + or envs.VLLM_TEST_FORCE_FP8_MARLIN + and not self.block_quant) # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False @@ -531,8 +539,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): # cutlass path self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100( self.weight_quant, self.input_quant) - self.use_cutlass = (quant_config._is_fp8_w8a8_sm90( - self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100) + self.use_cutlass = not self.block_quant and ( + quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant) + or self.is_fp8_w8a8_sm100) self.disable_expert_map = False def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -547,6 +556,31 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): params_dtype = torch.float8_e4m3fn + if self.block_quant: + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + tp_size = get_tensor_model_parallel_world_size() + block_n, block_k = ( + self.weight_block_size[0], + self.weight_block_size[1], + ) + # NOTE: To ensure proper alignment of the block-wise quantization + # scales, the output_size of the weights for both the gate and up + # layers must be divisible by block_n. + # Required by column parallel or enabling merged weights + if intermediate_size_per_partition % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_n = {block_n}.") + if (tp_size > 1 + and intermediate_size_per_partition % block_k != 0): + # Required by row parallel + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}.") + # WEIGHTS w13_weight = torch.nn.Parameter(torch.empty( num_experts, @@ -602,6 +636,27 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) + elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, + 2 * + ((intermediate_size_per_partition + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, (hidden_size + block_n - 1) // block_n, + (intermediate_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + # INPUT_SCALES if self.static_input_scales: w13_input_scale = torch.nn.Parameter(torch.ones( @@ -706,6 +761,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): del layer.w2_input_scale if self.use_cutlass: + assert self.weight_quant.strategy != QuantizationStrategy.BLOCK device = layer.w13_weight.device # ab_strides1 and c_strides2 are the same self.ab_strides1_c_strides2 = torch.full( @@ -724,6 +780,29 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): device=device, dtype=torch.int64) + if is_deep_gemm_e8m0_used() and self.block_quant: + assert layer.weight_block_size is not None + # Re-quantise the expert weights so their scales are UE8M0. + block_sz = tuple(layer.weight_block_size) + requant_weight_ue8m0_inplace( + layer.w13_weight.data, + layer.w13_weight_scale.data, + block_sz, + ) + requant_weight_ue8m0_inplace( + layer.w2_weight.data, + layer.w2_weight_scale.data, + block_sz, + ) + + # Ensure column-major TMA alignment expected by DeepGEMM. + if expert_weight_is_col_major(layer.w13_weight_scale): + layer.w13_weight_scale = get_col_major_tma_aligned_tensor( + layer.w13_weight_scale) + if expert_weight_is_col_major(layer.w2_weight_scale): + layer.w2_weight_scale = get_col_major_tma_aligned_tensor( + layer.w2_weight_scale) + def maybe_make_prepare_finalize( self) -> Optional[mk.FusedMoEPrepareAndFinalize]: if self.use_marlin or self.rocm_aiter_moe_enabled: @@ -777,9 +856,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): return experts # triton path - from vllm.model_executor.layers.fused_moe import TritonExperts - from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) + from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 + BatchedTritonOrDeepGemmExperts) + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) assert not self.rocm_aiter_moe_enabled and not self.use_marlin @@ -790,14 +870,16 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): assert max_num_tokens_per_rank is not None logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) - return BatchedTritonExperts( + return BatchedTritonOrDeepGemmExperts( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), quant_config=self.moe_quant_config, ) else: - logger.debug("TritonExperts(%s)", self.__class__.__name__) - return TritonExperts(self.moe_quant_config) + logger.debug("TritonOrDeepGemmExperts(%s)", + self.__class__.__name__) + return TritonOrDeepGemmExperts(self.moe_quant_config, + allow_deep_gemm=True) def get_fused_moe_quant_config( self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: @@ -816,6 +898,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): a2_scale=layer.w2_input_scale, per_act_token_quant=per_act_token, per_out_ch_quant=per_channel_quant, + block_shape=layer.weight_block_size, ) def apply( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index d42ae22c51393..fa0816959fcda 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -11,7 +11,7 @@ from torch.nn import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - apply_fp8_block_linear, check_aiter_fp8_linear_support, + W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy, @@ -41,16 +41,30 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): self.strategy = weight_quant.strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme - self.act_q_group_shape = GroupShape.PER_TENSOR \ - if is_static_input_scheme else GroupShape.PER_TOKEN - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_q_group_shape) self.weight_block_size = self.weight_quant.block_structure + if self.weight_block_size is not None: + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) + else: + self.act_q_group_shape = GroupShape.PER_TENSOR \ + if is_static_input_scheme else GroupShape.PER_TOKEN + self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() + if self.weight_block_size is not None: + assert not self.is_static_input_scheme + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(*self.weight_block_size), + act_quant_group_shape=self.act_q_group_shape, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, + ) + else: + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.is_static_input_scheme, + act_quant_group_shape=self.act_q_group_shape) + @classmethod def get_min_capability(cls) -> int: # lovelace and up @@ -141,13 +155,14 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - if layer.weight_block_size is not None: - return apply_fp8_block_linear( - layer, + if self.weight_block_size is not None: + return self.w8a8_block_fp8_linear.apply( input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, bias=bias, - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported) + ) return self.fp8_linear.apply(input=x, weight=layer.weight, diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py index d26a932eddb2c..c2b3ccf19fca8 100644 --- a/vllm/model_executor/layers/quantization/deepgemm.py +++ b/vllm/model_executor/layers/quantization/deepgemm.py @@ -43,7 +43,7 @@ def prepare_block_fp8_matmul_inputs( return M, N, K, C -def w8a8_block_fp8_matmul_deepgemm( +def w8a8_deepgemm_block_scaled_mm( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -59,7 +59,7 @@ def w8a8_block_fp8_matmul_deepgemm( return C -def w8a8_block_fp8_matmul_deepgemm_fake( +def w8a8_deepgemm_block_scaled_mm_fake( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -73,9 +73,9 @@ def w8a8_block_fp8_matmul_deepgemm_fake( direct_register_custom_op( - op_name="w8a8_block_fp8_matmul_deepgemm", - op_func=w8a8_block_fp8_matmul_deepgemm, + op_name="w8a8_deepgemm_block_scaled_mm", + op_func=w8a8_deepgemm_block_scaled_mm, mutates_args=[], - fake_impl=w8a8_block_fp8_matmul_deepgemm_fake, + fake_impl=w8a8_deepgemm_block_scaled_mm_fake, dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index aec9c79f1ea82..c4951712baa78 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -31,12 +31,12 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - apply_fp8_block_linear, check_aiter_fp8_linear_support, + W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support, create_fp8_input_scale, create_fp8_scale_parameter, - create_fp8_weight_parameter, get_col_major_tma_aligned_tensor, - maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, - process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace, - validate_fp8_block_shape) + create_fp8_weight_parameter, expert_weight_is_col_major, + get_col_major_tma_aligned_tensor, maybe_post_process_fp8_weight_block, + process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy, + requant_weight_ue8m0_inplace, validate_fp8_block_shape) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, prepare_moe_fp8_layer_for_marlin) @@ -64,12 +64,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"] logger = init_logger(__name__) -def _is_col_major(x: torch.Tensor) -> bool: - assert x.dim() == 3 - b, m, n = x.shape - return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m - - class Fp8Config(QuantizationConfig): """Config class for FP8.""" @@ -240,15 +234,28 @@ class Fp8LinearMethod(LinearMethodBase): self.weight_block_size = self.quant_config.weight_block_size self.block_quant = self.weight_block_size is not None self.act_q_static = self.quant_config.activation_scheme == "static" - # Use per-token quantization for better perf if dynamic and cutlass - if not self.act_q_static and cutlass_fp8_supported(): - self.act_q_group_shape = GroupShape.PER_TOKEN + if self.weight_block_size: + self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) else: - self.act_q_group_shape = GroupShape.PER_TENSOR + # Use per-token quantization for better perf if dynamic and cutlass + if not self.act_q_static and cutlass_fp8_supported(): + self.act_q_group_shape = GroupShape.PER_TOKEN + else: + self.act_q_group_shape = GroupShape.PER_TENSOR - self.fp8_linear = Fp8LinearOp( - act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape) + if self.block_quant: + assert not self.act_q_static + assert self.weight_block_size is not None + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(*self.weight_block_size), + act_quant_group_shape=self.act_q_group_shape, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, + ) + else: + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.act_q_static, + act_quant_group_shape=self.act_q_group_shape) def create_weights( self, @@ -397,12 +404,15 @@ class Fp8LinearMethod(LinearMethodBase): bias=bias) if self.block_quant: - return apply_fp8_block_linear( - layer, + assert self.weight_block_size is not None + + return self.w8a8_block_fp8_linear.apply( input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, bias=bias, - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported) + ) return self.fp8_linear.apply(input=x, weight=layer.weight, @@ -660,10 +670,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): - if _is_col_major(layer.w13_weight_scale_inv): + if expert_weight_is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = \ get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv) - if _is_col_major(layer.w2_weight_scale_inv): + if expert_weight_is_col_major(layer.w2_weight_scale_inv): layer.w2_weight_scale_inv = \ get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv) @@ -811,10 +821,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) # Ensure column-major TMA alignment expected by DeepGEMM. - if _is_col_major(layer.w13_weight_scale_inv): + if expert_weight_is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( layer.w13_weight_scale_inv) - if _is_col_major(layer.w2_weight_scale_inv): + if expert_weight_is_col_major(layer.w2_weight_scale_inv): layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( layer.w2_weight_scale_inv) diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index 31182f40b48f6..ece3e5817116f 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -27,11 +27,14 @@ class QuantFP8(CustomOp): This CustomOp supports both static and dynamic quantization. """ - def __init__(self, - static: bool, - group_shape: GroupShape, - num_token_padding: Optional[int] = None, - column_major_scales: bool = False): + def __init__( + self, + static: bool, + group_shape: GroupShape, + num_token_padding: Optional[int] = None, + column_major_scales: bool = False, + use_ue8m0: Optional[bool] = None, # for Torch compile + ): """ :param static: static or dynamic quantization :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR, @@ -46,6 +49,7 @@ class QuantFP8(CustomOp): self.group_shape = group_shape self.num_token_padding = num_token_padding self.column_major_scales = column_major_scales + self.use_ue8m0 = use_ue8m0 self.is_group_quant = group_shape.is_per_group() if self.is_group_quant: @@ -70,7 +74,8 @@ class QuantFP8(CustomOp): x, group_size=self.group_size, column_major_scales=self.column_major_scales, - dtype=_FP8_DTYPE) + dtype=_FP8_DTYPE, + use_ue8m0=self.use_ue8m0) assert (scale is not None) == self.static assert scale_ub is None or (not self.static and self.group_shape @@ -137,7 +142,10 @@ class QuantFP8(CustomOp): x_grouped = x.view(-1, num_groups, self.group_size) absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float() - scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) + scales_raw = absmax / _FP8_MAX + if self.use_ue8m0: + scales_raw = torch.exp2(torch.ceil(torch.log2(scales_raw))) + scales = (scales_raw).clamp(min=_FP8_MIN_SCALING_FACTOR) x_scaled = x_grouped / scales x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) @@ -151,6 +159,6 @@ class QuantFP8(CustomOp): scales = scales.reshape(orig_shape[:-1] + (num_groups, )) if self.column_major_scales: - scales = scales.transpose(-2, -1).contiguous() + scales = scales.transpose(-2, -1).contiguous().transpose(-1, -2) return x_quant, scales diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index fc12483de0c0e..2098086bf2401 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -13,8 +13,9 @@ import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( - group_broadcast) + GroupShape, group_broadcast) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED) from vllm.model_executor.parameter import (BlockQuantScaleParameter, @@ -24,6 +25,7 @@ from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv, direct_register_custom_op from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, + is_deep_gemm_supported, should_use_deepgemm_for_fp8_linear) logger = init_logger(__name__) @@ -35,6 +37,8 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz +# We need to pass in the is_hopper flag as argument because the function +# current_platform.is_device_capability() is not supported by Torch compiler. def cutlass_scaled_mm( A: torch.Tensor, B: torch.Tensor, @@ -42,15 +46,17 @@ def cutlass_scaled_mm( Bs: torch.Tensor, block_size: list[int], output_dtype: torch.dtype = torch.float16, + is_hopper: Optional[bool] = None, ) -> torch.Tensor: + if is_hopper is None: + is_hopper = current_platform.is_device_capability(90) return ops.cutlass_scaled_mm( A, B.T, out_dtype=output_dtype, scale_a=As, # SM90 block FP8 requires row-major scale_b, which we do ahead of time - scale_b=Bs if block_size is not None - and current_platform.is_device_capability(90) else Bs.T) + scale_b=Bs if block_size is not None and is_hopper else Bs.T) def rocm_aiter_gemm_w8a8_blockscale_impl( @@ -98,122 +104,189 @@ if current_platform.is_rocm(): aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) -def dispatch_w8a8_blockscale_func( - use_cutlass: bool, use_aiter_and_is_supported: bool -) -> Callable[[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - list[int], - torch.dtype, -], torch.Tensor]: - if use_cutlass: - return cutlass_scaled_mm - if (use_aiter_and_is_supported): - return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale - return w8a8_block_fp8_matmul +# TODO we should be able to change the type of block_size to GroupShape +# after we resolve GroupShape compilation issue +# https://github.com/vllm-project/vllm/issues/25270 +def _w8a8_triton_block_scaled_mm_func( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + return w8a8_triton_block_scaled_mm(qx, weight, x_scale, weight_scale, + block_size, output_dtype) + + +def _w8a8_triton_block_scaled_mm_fake( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + return torch.empty((qx.size(0), weight.size(0)), + dtype=output_dtype, + device=qx.device) + + +direct_register_custom_op( + "w8a8_triton_block_scaled_mm_func", + _w8a8_triton_block_scaled_mm_func, + mutates_args=[], + fake_impl=_w8a8_triton_block_scaled_mm_fake, + dispatch_key="CUDA", +) # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 -def apply_w8a8_block_fp8_linear( - input: torch.Tensor, - weight: torch.Tensor, - block_size: list[int], - weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, - use_aiter_and_is_supported: bool = False, -) -> torch.Tensor: - assert input_scale is None - # View input as 2D matrix for fp8 methods - input_2d = input.view(-1, input.shape[-1]) - output_shape = [*input.shape[:-1], weight.shape[0]] - output_dtype = input.dtype +class W8A8BlockFp8LinearOp: + """ + This class executes a Blocked FP8 linear layer using cutlass if supported + and torch.scaled_mm otherwise. + """ - if should_use_deepgemm_for_fp8_linear(output_dtype, weight): + def __init__( + self, + weight_group_shape: GroupShape, + act_quant_group_shape: GroupShape, + cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, + use_aiter_and_is_supported: bool = False, + ): + self.weight_group_shape = weight_group_shape + self.act_quant_group_shape = act_quant_group_shape + self.is_deep_gemm_supported = is_deep_gemm_supported() + self.is_hopper = current_platform.is_device_capability(90) + # Get the correct blockscale mul and input quant operations. + # We can't use _dispatch_w8a8_blockscale_op to figure out if we want + # to use deepgemm because we don't know the shape of weights (and + # whether deepgemm supports it) at the init time. + self.w8a8_blockscale_op, self.input_quant_op = \ + self._dispatch_w8a8_blockscale_op( + cutlass_block_fp8_supported, use_aiter_and_is_supported) + self.deepgemm_input_quant_op = (QuantFP8( + False, + self.act_quant_group_shape, + column_major_scales=True, + use_ue8m0=is_deep_gemm_e8m0_used()) if self.is_deep_gemm_supported + else None) + + def apply( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] + output_dtype = input.dtype - q_input, x_scale = per_token_group_quant_fp8( - input_2d, - block_size[1], - column_major_scales=True, - ) + if should_use_deepgemm_for_fp8_linear(output_dtype, weight, + self.is_deep_gemm_supported): + output = self._run_deepgemm(input, weight, weight_scale) + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + output = self.w8a8_blockscale_op(input_2d, weight, weight_scale) + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + def _run_deepgemm( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: # ensure DeepGEMM-backed custom op is registered before use import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 - output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( + assert self.deepgemm_input_quant_op is not None + q_input, x_scale = self.deepgemm_input_quant_op(input_2d) + return torch.ops.vllm.w8a8_deepgemm_block_scaled_mm( q_input, weight, x_scale, weight_scale, - block_size, - output_dtype=output_dtype) - if bias is not None: - output += bias - return output.to(dtype=output_dtype).view(*output_shape) + self.weight_group_shape, + output_dtype=input_2d.dtype) - w8a8_blockscale_func = dispatch_w8a8_blockscale_func( - cutlass_block_fp8_supported, use_aiter_and_is_supported) - if cutlass_block_fp8_supported: - num_pad = 0 - if current_platform.is_device_capability(90): - # pad first dimension to be divisible by 4 due to - # cutlass blockwise gemm limitation for hopper - num_pad = 4 - (input_2d.shape[0] % 4) - if num_pad > 0: - input_2d = torch.nn.functional.pad(input_2d, - (0, 0, 0, num_pad), - "constant", 0) - q_input, x_scale = per_token_group_quant_fp8(input_2d, - block_size[1], - column_major_scales=True) - output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, - block_size, input.dtype) - if num_pad > 0: - output = output[:-num_pad] - else: - if use_aiter_and_is_supported: - q_input, x_scale = aiter_per1x128_quant( - input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + def _run_cutlass( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + assert self.input_quant_op is not None + if self.is_hopper: + # We pad unconditionally (even if shape is already divisible by 4) + # to support dynamic shape for input_2d.shape[0] in torch.compile + x = torch.nn.functional.pad(input_2d, + (0, 0, 0, -input_2d.shape[0] % 4)) else: - q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=False) + x = input_2d - output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, - block_size, input.dtype) + q_input, x_scale = self.input_quant_op(x) + output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale, + list(self.weight_group_shape), + input_2d.dtype, self.is_hopper) + output = output[0:input_2d.shape[0], ...] + return output - if bias is not None: - output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) + def _run_aiter( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + assert self.act_quant_group_shape == GroupShape(1, 128) + q_input, x_scale = aiter_per1x128_quant( + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( + q_input, weight, x_scale, weight_scale, self.weight_group_shape, + input_2d.dtype) + def _run_triton( + self, + input_2d: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + assert self.input_quant_op is not None + q_input, x_scale = self.input_quant_op(input_2d) + return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( + q_input, weight, x_scale, weight_scale, self.weight_group_shape, + input_2d.dtype) -def apply_w8a8_block_fp8_linear_fake( - input: torch.Tensor, - weight: torch.Tensor, - block_size: list[int], - weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, - use_aiter_and_is_supported: bool = False, -) -> torch.Tensor: - output_shape = [*input.shape[:-1], weight.shape[0]] - return torch.empty(output_shape, dtype=input.dtype, device=input.device) - - -if not current_platform.is_cpu(): - direct_register_custom_op( - op_name="apply_w8a8_block_fp8_linear", - op_func=apply_w8a8_block_fp8_linear, - mutates_args=[], - fake_impl=apply_w8a8_block_fp8_linear_fake, - ) + def _dispatch_w8a8_blockscale_op( + self, + use_cutlass: bool, + use_aiter_and_is_supported: bool, + ) -> tuple[Callable[[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], torch.Tensor], Optional[QuantFP8]]: + if use_cutlass: + return self._run_cutlass, (QuantFP8(False, + self.act_quant_group_shape, + column_major_scales=True, + use_ue8m0=False)) + if use_aiter_and_is_supported: + return self._run_aiter, None + return self._run_triton, (QuantFP8(False, + self.act_quant_group_shape, + column_major_scales=False, + use_ue8m0=False)) def input_to_float8( @@ -465,7 +538,7 @@ def per_token_group_quant_fp8( @triton.jit -def _w8a8_block_fp8_matmul( +def _w8a8_triton_block_scaled_mm( # Pointers to inputs and output A, B, @@ -590,7 +663,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, return None -def w8a8_block_fp8_matmul( +def w8a8_triton_block_scaled_mm( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -650,7 +723,7 @@ def w8a8_block_fp8_matmul( return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - _w8a8_block_fp8_matmul[grid]( + _w8a8_triton_block_scaled_mm[grid]( A, B, C, @@ -997,20 +1070,7 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module, layer.weight_scale.data.T.contiguous(), requires_grad=False) -def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor, - bias: Optional[torch.Tensor], - cutlass_block_fp8_supported: bool, - use_aiter_and_is_supported: bool) -> torch.Tensor: - """Apply block-wise FP8 linear operation.""" - assert layer.weight_block_size is not None - - return torch.ops.vllm.apply_w8a8_block_fp8_linear( - input=input, - weight=layer.weight, - block_size=layer.weight_block_size, - weight_scale=layer.weight_scale, - input_scale=layer.input_scale, - bias=bias, - cutlass_block_fp8_supported=cutlass_block_fp8_supported, - use_aiter_and_is_supported=use_aiter_and_is_supported, - ) +def expert_weight_is_col_major(x: torch.Tensor) -> bool: + assert x.dim() == 3 + b, m, n = x.shape + return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 8cda1789e6c97..6ed482db4700e 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -178,10 +178,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: from vllm.platforms.rocm import on_mi3xx - if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( - ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0 and bias is None: + if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx() and \ + qinput.shape[0] == 1 and \ + qinput.shape[1] % 16 == 0 and \ + ((bias is None) or (bias.dtype == out_dtype)) : output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, - current_platform.get_cu_count()) + current_platform.get_cu_count(), bias) else: output = torch._scaled_mm(qinput, weight, diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index a1675ffbaa950..d7a65d43c2107 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -100,7 +100,7 @@ def rocm_unquantized_gemm_impl( k = weight.shape[1] use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \ x.dtype in [torch.float16, torch.bfloat16] \ - and k % 8 == 0 and bias is None) + and k % 8 == 0) if use_skinny is not True: return torch.nn.functional.linear(x, weight, bias) @@ -111,9 +111,9 @@ def rocm_unquantized_gemm_impl( cu_count = current_platform.get_cu_count() if m > 8 and 0 < n <= 4: - out = ops.wvSplitK(weight, x_view, cu_count) + out = ops.wvSplitK(weight, x_view, cu_count, bias) return out.view(*x.shape[:-1], weight.shape[0]) - elif m % 4 == 0 and n == 1 and k <= 8192: + elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None: out = ops.LLMM1(weight, x_view, 4) return out.view(*x.shape[:-1], weight.shape[0]) return torch.nn.functional.linear(x, weight, bias) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index ce3d23763ed64..aa7bcf5b65ada 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -266,24 +266,24 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig): if structured_outputs_config.reasoning_parser == "": structured_outputs_config.reasoning_parser = "openai_gptoss" - # Increase the max capture size from 512 to 1024 for performance. + # Increase the max capture size from 512 to 992 for performance. # NOTE(woosuk): This will increase the number of CUDA graphs - # from 67 to 83. + # from 67 to 81. scheduler_config = vllm_config.scheduler_config if len(scheduler_config.cuda_graph_sizes) == 1: max_capture_size = scheduler_config.cuda_graph_sizes[0] # FIXME(woosuk): When using full cuda graph with FA3, the max # supported size is 992. - if max_capture_size < 1024: + if max_capture_size < 992: cuda_graph_sizes = [1, 2, 4] # Step size 8 for small batch sizes cuda_graph_sizes += [i for i in range(8, 256, 8)] # Step size 16 for larger batch sizes - cuda_graph_sizes += [i for i in range(256, 1025, 16)] + cuda_graph_sizes += [i for i in range(256, 993, 16)] scheduler_config.cuda_graph_sizes = cuda_graph_sizes logger.info( "Overriding max cuda graph capture size to " - "%d for performance.", 1024) + "%d for performance.", 992) class MambaModelConfig(VerifyAndUpdateConfig): diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index dfae3c3ea5437..2ff2d54a83aa8 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -134,6 +134,11 @@ class EagleLlamaForCausalLM(LlamaForCausalLM): nn.Module.__init__(self) self.config = vllm_config. \ speculative_config.draft_model_config.hf_config + # Ensure draft_vocab_size is set + # default to the base vocab size when absent + if getattr(self.config, "draft_vocab_size", None) is None: + base_vocab_size = getattr(self.config, "vocab_size", None) + self.config.draft_vocab_size = base_vocab_size target_layer_num = vllm_config.model_config.get_num_layers( vllm_config.parallel_config) self.model = LlamaModel(vllm_config=vllm_config, diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index fb10af6c53c90..b99a1547918ee 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -203,6 +203,11 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): nn.Module.__init__(self) self.config = vllm_config. \ speculative_config.draft_model_config.hf_config + # Ensure draft_vocab_size is set + # default to the base vocab size when absent + if getattr(self.config, "draft_vocab_size", None) is None: + base_vocab_size = getattr(self.config, "vocab_size", None) + self.config.draft_vocab_size = base_vocab_size target_layer_num = vllm_config.model_config.get_num_layers( vllm_config.parallel_config) diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index 4d1829cd228cd..f6df85a50238c 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -53,9 +53,9 @@ def _extract_data_from_fused_moe_module( """ assert isinstance(m, FusedMoE) w13 = m.w13_weight - w13_s = m.w13_weight_scale_inv + w13_s = getattr(m, "w13_weight_scale_inv", m.w13_weight_scale) w2 = m.w2_weight - w2_s = m.w2_weight_scale_inv + w2_s = getattr(m, "w2_weight_scale_inv", m.w2_weight_scale) num_topk = m.top_k assert isinstance(w13, torch.Tensor) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index efe70d019ccc6..f424682f9dfa0 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Sampling parameters for text generation.""" import copy +import warnings from dataclasses import field from enum import Enum, IntEnum from functools import cached_property @@ -59,6 +60,19 @@ class StructuredOutputsParams: f"but multiple are specified: {self.__dict__}") +@dataclass +class GuidedDecodingParams(StructuredOutputsParams): + + def __post_init__(self): + warnings.warn( + "GuidedDecodingParams is deprecated. This will be removed in " + "v0.12.0 or v1.0.0, which ever is soonest. Please use " + "StructuredOutputsParams instead.", + DeprecationWarning, + stacklevel=2) + return super().__post_init__() + + class RequestOutputKind(Enum): # Return entire output so far in every RequestOutput CUMULATIVE = 0 @@ -179,6 +193,8 @@ class SamplingParams( # Fields used to construct logits processors structured_outputs: Optional[StructuredOutputsParams] = None """Parameters for configuring structured outputs.""" + guided_decoding: Optional[GuidedDecodingParams] = None + """Deprecated alias for structured_outputs.""" logit_bias: Optional[dict[int, float]] = None """If provided, the engine will construct a logits processor that applies these logit biases.""" @@ -227,6 +243,7 @@ class SamplingParams( ge=-1)]] = None, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, structured_outputs: Optional[StructuredOutputsParams] = None, + guided_decoding: Optional[GuidedDecodingParams] = None, logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, allowed_token_ids: Optional[list[int]] = None, extra_args: Optional[dict[str, Any]] = None, @@ -238,6 +255,15 @@ class SamplingParams( int(token): min(100.0, max(-100.0, bias)) for token, bias in logit_bias.items() } + if guided_decoding is not None: + warnings.warn( + "guided_decoding is deprecated. This will be removed in " + "v0.12.0 or v1.0.0, which ever is soonest. Please use " + "structured_outputs instead.", + DeprecationWarning, + stacklevel=2) + structured_outputs = guided_decoding + guided_decoding = None return SamplingParams( n=1 if n is None else n, @@ -334,6 +360,16 @@ class SamplingParams( # eos_token_id is added to this by the engine self._all_stop_token_ids.update(self.stop_token_ids) + if self.guided_decoding is not None: + warnings.warn( + "guided_decoding is deprecated. This will be removed in " + "v0.12.0 or v1.0.0, which ever is soonest. Please use " + "structured_outputs instead.", + DeprecationWarning, + stacklevel=2) + self.structured_outputs = self.guided_decoding + self.guided_decoding = None + def _verify_args(self) -> None: if not isinstance(self.n, int): raise ValueError(f"n must be an int, but is of " diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 3399d00fbabbd..5d165f1662383 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -1383,6 +1383,38 @@ def find_nccl_library() -> str: return so_file +def find_nccl_include_paths() -> Optional[list[str]]: + """ + We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH` + environment variable, or we find the library file brought by + nvidia-nccl-cuXX. load_inline by default uses + torch.utils.cpp_extension.include_paths + """ + paths: list[str] = [] + inc = envs.VLLM_NCCL_INCLUDE_PATH + if inc and os.path.isdir(inc): + paths.append(inc) + + try: + import importlib.util + spec = importlib.util.find_spec("nvidia.nccl") + if spec and getattr(spec, "submodule_search_locations", None): + for loc in spec.submodule_search_locations: + inc_dir = os.path.join(loc, "include") + if os.path.exists(os.path.join(inc_dir, "nccl.h")): + paths.append(inc_dir) + except Exception: + pass + + seen = set() + out: list[str] = [] + for p in paths: + if p and p not in seen: + out.append(p) + seen.add(p) + return out or None + + prev_set_stream = torch.cuda.set_stream _current_stream_tls = threading.local() diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 4083193d7650e..2f533ca0639fc 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -9,7 +9,7 @@ from __future__ import annotations import functools import importlib import os -from typing import Any, Callable, NoReturn +from typing import Any, Callable, NoReturn, Optional import torch @@ -172,9 +172,13 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): return 1 - sim -def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, - weight: torch.Tensor): - return (is_deep_gemm_supported() and output_dtype == torch.bfloat16 +def should_use_deepgemm_for_fp8_linear( + output_dtype: torch.dtype, + weight: torch.Tensor, + supports_deep_gemm: Optional[bool] = None): + if supports_deep_gemm is None: + supports_deep_gemm = is_deep_gemm_supported() + return (supports_deep_gemm and output_dtype == torch.bfloat16 and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 722c23f150cd3..f9fbd05efc67d 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -8,6 +8,8 @@ import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger @@ -291,7 +293,13 @@ class TritonAttentionImpl(AttentionImpl): if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. - ops.reshape_and_cache_flash( + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + # triton kernel does not support uint8 kv_cache + # (because some explicit casts (e.g. float8_e4m3fnuz) + # are not supported) + triton_reshape_and_cache_flash( key, value, key_cache, @@ -303,8 +311,9 @@ class TritonAttentionImpl(AttentionImpl): ) if self.kv_cache_dtype.startswith("fp8"): - key_cache = key_cache.view(self.fp8_dtype) - value_cache = value_cache.view(self.fp8_dtype) + if key_cache.dtype != self.fp8_dtype: + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) num_tokens, num_heads, head_size = query.shape assert layer._q_scale_float == 1.0, \ "A non 1.0 q_scale is not currently supported." diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 6ef489f5a7a28..f837439f953e8 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -107,19 +107,57 @@ def _make_metadata_with_slice( the requests included in ubatch_slice """ + assert not ubatch_slice.is_empty(), ( + f"Ubatch slice {ubatch_slice} is empty") + request_slice = ubatch_slice.request_slice token_slice = ubatch_slice.token_slice + start_locs = attn_metadata.query_start_loc_cpu + first_req = request_slice.start + first_tok = token_slice.start + last_req = request_slice.stop - 1 + last_tok = token_slice.stop - 1 + + assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], \ + "Token slice start outside of first request" + assert start_locs[last_req] <= last_tok < start_locs[last_req+1], \ + "Token slice end outside of last request" + + # If the "middle" request has tokens in both ubatches, we have to split it. + # If ubatch_slice is the first ubatch then we will be splitting the last + # request. If it's the second microbatch, then we will be splitting the + # first request + splits_first_request = first_tok > start_locs[first_req] + splits_last_request = last_tok < start_locs[last_req + 1] - 1 + + query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice) query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc, request_slice) + assert len(query_start_loc) >= 2, ( f"query_start_loc must have at least 2 elements, " f"got {len(query_start_loc)}") - query_start_loc_cpu = slice_query_start_locs( - attn_metadata.query_start_loc_cpu, request_slice) + if splits_first_request: + tokens_skipped = first_tok - start_locs[first_req] + query_start_loc[1:] -= tokens_skipped + query_start_loc_cpu[1:] -= tokens_skipped seq_lens = attn_metadata.seq_lens[request_slice] seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] + + if splits_last_request: + tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop + query_start_loc[-1] -= tokens_skipped + query_start_loc_cpu[-1] -= tokens_skipped + + # Make sure we don't modify the seq_lens tensors + # (not cudagraph compatible) + seq_lens = seq_lens.clone() + seq_lens_cpu = seq_lens_cpu.clone() + seq_lens[-1] -= tokens_skipped + seq_lens_cpu[-1] -= tokens_skipped + max_seq_len = int(seq_lens_cpu.max()) num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[ request_slice] @@ -167,6 +205,7 @@ def split_attn_metadata( for ubatch_slice in ubatch_slices: results.append( _make_metadata_with_slice(ubatch_slice, common_attn_metadata)) + return results @@ -696,7 +735,6 @@ def split_decodes_and_prefills( return num_reqs, 0, num_tokens, 0 first_prefill = is_prefill.int().argmax(dim=-1).item() - assert torch.all(query_lens[first_prefill:] > decode_threshold) assert torch.all(query_lens[:first_prefill] <= decode_threshold) num_decodes = first_prefill num_prefills = num_reqs - num_decodes diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index dc97d5c8f39d4..a9e0a38fe3417 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -9,6 +9,7 @@ import numpy as np import torch import torch.nn as nn +from vllm.attention.backends.abstract import AttentionMetadataBuilder from vllm.attention.layer import Attention from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) @@ -30,7 +31,6 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -from vllm.v1.worker.ubatching import dbo_current_ubatch_id logger = init_logger(__name__) @@ -78,6 +78,8 @@ class EagleProposer: self.is_multimodal_model = vllm_config.model_config \ .is_multimodal_model + self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None + self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.vllm_config.model_config.enforce_eager) @@ -118,7 +120,7 @@ class EagleProposer: with_numpy=True) # Determine allowed attention backends once during initialization. - self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] + self.allowed_attn_types: tuple[type, ...] if current_platform.is_rocm(): rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend @@ -191,11 +193,12 @@ class EagleProposer: assert self.runner is not None - # FIXME: need to consider multiple kv_cache_groups - ubatch_id = dbo_current_ubatch_id() - attn_metadata_builder = \ - self.runner.attn_groups[0][0].metadata_builders[ubatch_id] - attn_metadata = attn_metadata_builder.build_for_drafting( + # Select the correct attention metadata builders for EAGLE layers. + # Get the attention metadata builders once and reuse for later. + builder = (self._get_attention_metadata_builder() + if self.attn_metadata_builder is None else + self.attn_metadata_builder) + attn_metadata = builder.build_for_drafting( common_attn_metadata=common_attn_metadata, draft_index=0) # At this moment, we assume all eagle layers belong to the same KV @@ -329,11 +332,9 @@ class EagleProposer: exceeds_max_model_len, PADDING_SLOT_ID) # Rebuild attention metadata - attn_metadata_builder = \ - self.runner.attn_groups[0][0].metadata_builders[ubatch_id] - attn_metadata = attn_metadata_builder\ - .build_for_drafting(common_attn_metadata=common_attn_metadata, - draft_index=token_index + 1) + attn_metadata = builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, + draft_index=token_index + 1) for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata @@ -538,9 +539,8 @@ class EagleProposer: hidden_states: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, ) -> list[torch.Tensor]: - ubatch_id = dbo_current_ubatch_id() tree_attn_metadata_builder = \ - self.runner.attn_groups[0][0].metadata_builders[ubatch_id] + self.runner.attn_groups[0][0].get_metadata_builder() assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder) @@ -854,10 +854,24 @@ class EagleProposer: # share lm_head with the target model if needed # some model definition do not define lm_head explicitly # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM - if self.vllm_config.speculative_config.method != "eagle3" and \ - hasattr(target_language_model, "lm_head"): - logger.info("Loading EAGLE LM head weights from the target model.") - self.model.lm_head = target_language_model.lm_head + if self.vllm_config.speculative_config.method != "eagle3": + if hasattr(target_language_model, "lm_head"): + logger.info( + "Loading EAGLE LM head weights from the target model.") + self.model.lm_head = target_language_model.lm_head + else: + if (hasattr(self.model, "lm_head") + and hasattr(target_language_model, "lm_head") + and self.model.lm_head.weight.shape + == target_language_model.lm_head.weight.shape): + logger.info("Assuming the EAGLE head shares the same lm_head" + " with the target model.") + del self.model.lm_head + self.model.lm_head = target_language_model.lm_head + else: + logger.info( + "The EAGLE head's lm_head will be loaded separately" + " from the target model.") @torch.inference_mode() def dummy_run( @@ -880,6 +894,31 @@ class EagleProposer: inputs_embeds=inputs_embeds, ) + def _get_attention_metadata_builder( + self) -> list[AttentionMetadataBuilder]: + """Find and return the attention metadata builders for EAGLE layers. + + Returns: + The metadata builders for EAGLE layers. + + Raises: + AssertionError: If no metadata builders are found for EAGLE layers. + """ + builder = None + chosen_layer = self.attn_layer_names[0] + + for kv_cache_group in self.runner.attn_groups: + for attn_group in kv_cache_group: + if chosen_layer in attn_group.layer_names: + builder = attn_group.get_metadata_builder() + break + if builder is not None: + break + + assert builder is not None, ( + "Failed to find attention metadata builder for EAGLE layers.") + return builder + def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 89b9a3c34f2ac..f785824958147 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -96,7 +96,8 @@ from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split +from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds, + ubatch_split) from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices from vllm.v1.worker.utils import is_residual_scattered_for_sp @@ -1032,7 +1033,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_tokens_padded = num_tokens_unpadded + self.get_local_padding( num_tokens_unpadded) ubatch_slices, num_tokens_after_padding = \ - ubatch_split(max_num_scheduled_tokens, + ubatch_split(num_scheduled_tokens, num_tokens_unpadded, num_tokens_padded, self.vllm_config) @@ -1176,9 +1177,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): encoder_seq_lens=encoder_seq_lens, ) - if self.speculative_config and \ - spec_decode_common_attn_metadata is None: - spec_decode_common_attn_metadata = common_attn_metadata + if (self.speculative_config + and spec_decode_common_attn_metadata is None): + if isinstance(self.drafter, EagleProposer): + if (self.drafter.attn_layer_names[0] + in kv_cache_group_spec.layer_names): + spec_decode_common_attn_metadata = common_attn_metadata + else: + spec_decode_common_attn_metadata = common_attn_metadata for attn_group in self.attn_groups[kv_cache_group_id]: # Prepare for cascade attention if enabled & beneficial. @@ -1206,7 +1212,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ubatch_slices, common_attn_metadata) for ubid, common_attn_metadata in enumerate( common_attn_metadata_list): - assert common_attn_metadata.max_query_len == 1 attn_metadata_i = (attn_group.get_metadata_builder( ubatch_id=ubid).build( common_prefix_len=common_prefix_len, @@ -2182,9 +2187,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) = self._preprocess(scheduler_output, intermediate_tensors, ubatch_slices, num_tokens_after_padding) - if ubatch_slices is not None: - num_input_tokens = num_input_tokens // 2 - uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( num_scheduled_tokens @@ -2194,6 +2196,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cudagraph_runtime_mode, batch_descriptor = \ self.cudagraph_dispatcher.dispatch(batch_descriptor) + # This is currently to get around the assert in the DPMetadata + # where it wants `num_tokens_across_dp` to align with `num_tokens` + if ubatch_slices is not None: + num_input_tokens = ubatch_slices[0].num_tokens + # Run the model. # Use persistent buffers for CUDA graphs. with (set_forward_context( @@ -2360,7 +2367,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, - aux_hidden_states: Optional[torch.Tensor], + aux_hidden_states: Optional[list[torch.Tensor]], spec_decode_metadata: Optional[SpecDecodeMetadata], common_attn_metadata: CommonAttentionMetadata, ) -> Union[list[list[int]], torch.Tensor]: @@ -2380,6 +2387,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: indices = [] offset = 0 + assert spec_decode_metadata is not None for num_draft, tokens in zip( spec_decode_metadata.num_draft_tokens, sampled_token_ids): @@ -2430,6 +2438,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # TODO(woosuk): Support M-RoPE. target_positions = self.positions.gpu[:num_scheduled_tokens] if self.use_aux_hidden_state_outputs: + assert aux_hidden_states is not None target_hidden_states = torch.cat( [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1) @@ -2455,6 +2464,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # TODO(woosuk): Support M-RoPE. target_positions = self.positions.gpu[token_indices] if self.use_aux_hidden_state_outputs: + assert aux_hidden_states is not None target_hidden_states = torch.cat( [h[token_indices] for h in aux_hidden_states], dim=-1) else: @@ -2821,7 +2831,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, force_attention: bool = False, uniform_decode: bool = False, - allow_microbatching: bool = False, + allow_microbatching: bool = True, skip_eplb: bool = False, is_profile: bool = False, create_mixed_batch: bool = False, @@ -2847,32 +2857,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run """ - ubatch_enabled = self.parallel_config.enable_dbo - num_tokens_across_dp = None - num_pad = 0 - should_ubatch = False - if ubatch_enabled: - should_ubatch = num_tokens >= \ - self.parallel_config.dbo_decode_token_threshold and \ - allow_microbatching - - (should_ubatch, num_tokens_across_dp) = get_dp_padding_ubatch( - num_tokens, num_tokens, should_ubatch, self.vllm_config) - - # Currently the dummy run should only be ubatching during - # cuda graph capture, meaning all DP ranks should already - # have the same batch size - if num_tokens_across_dp is not None: - assert int(num_tokens_across_dp[0]) == num_tokens // 2 - assert cudagraph_runtime_mode in { CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL } - if not should_ubatch: - num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) - num_tokens += num_pad - # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). This means that we are using # different graphs and/or modes for mixed prefill-decode batches vs. @@ -2888,10 +2876,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # for GQA/MQA. max_query_len = self.uniform_decode_query_len if uniform_decode else \ num_tokens - if allow_microbatching: - assert self.uniform_decode_query_len == 1 - assert uniform_decode is True - assert max_query_len == 1 # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -2916,7 +2900,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert not create_mixed_batch num_reqs = cdiv(num_tokens, max_query_len) assert num_reqs <= max_num_reqs, \ - "Do not capture num_reqs > max_num_reqs for uniform batch" + f"Do not capture num_reqs {num_reqs} > max_num_reqs " \ + f"{max_num_reqs} for uniform batch. Num tokens: " \ + f"{num_tokens}, max_query_len: {max_query_len}" num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: num_scheduled_tokens_list[-1] = num_tokens % max_query_len @@ -2930,20 +2916,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) ubatch_slices = None + num_tokens_after_padding = None + # We currently only microbatch if the number of tokens is # over a certain threshold. - if should_ubatch: - # We only support decode-only cudagraphs - assert num_reqs == num_tokens - assert num_tokens % 2 == 0 - ubatch_slices = [ - UBatchSlice(slice(0, num_reqs // 2), slice(0, - num_tokens // 2)), - UBatchSlice(slice(num_reqs // 2, num_reqs), - slice(num_tokens // 2, num_tokens)) - ] + if self.parallel_config.enable_dbo and allow_microbatching: + ubatch_slices, num_tokens_after_padding = ubatch_split( + num_scheduled_tokens, + total_num_scheduled_tokens, + total_num_scheduled_tokens, + self.vllm_config, + ) + + # If we failed to microbatch, currently need to resynchronize + # TODO(lucas,sage): we should be able to avoid this second sync by + # refactoring `get_dp_padding_ubatch` and `get_dp_padding` into + # a single `coordinate_batch_across_dp` function. + if num_tokens_after_padding is None: + num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) + num_tokens_after_padding = num_tokens + num_pad + else: + num_tokens_across_dp = num_tokens_after_padding + num_tokens_after_padding = int(num_tokens_after_padding[0].item()) attn_metadata: Optional[PerLayerAttnMetadata] = None @@ -2960,12 +2957,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # TODO(luka) better system for describing dummy batches seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] else: - # Make sure max_model_len is used at the graph capture time. - seq_lens = self.max_model_len + seq_lens = max_query_len self.seq_lens.np[:num_reqs] = seq_lens self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() + cum_num_tokens, _ = self._get_cumsum_and_arange( + num_scheduled_tokens) + self.query_start_loc.np[1:num_reqs + 1] = cum_num_tokens + self.query_start_loc.copy_to_gpu() + for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): common_attn_metadata = CommonAttentionMetadata( @@ -3060,7 +3061,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, - num_tokens=num_tokens, + num_tokens=num_tokens_after_padding, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, @@ -3395,56 +3396,51 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): desc="Capturing CUDA graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", cudagraph_runtime_mode.name)) - enable_dbo = self.parallel_config.enable_dbo - # DBO Only supports running Full cudagraphs with uniform - # decode lengths - if enable_dbo and uniform_decode: - for num_tokens in compilation_cases: - # If the number of tokens is greater than the microbatching - # threshold, don't generate a microbatched cudagraph - if (num_tokens - < self.parallel_config.dbo_decode_token_threshold): - continue - # Warmup + # We skip EPLB here since we don't want to record dummy metrics + for num_tokens in compilation_cases: + # We currently only capture ubatched graphs when its a FULL + # cudagraph and for uniform decode batches. + capture_ubatched_graph = self.parallel_config.enable_dbo \ + and cudagraph_runtime_mode == CUDAGraphMode.FULL \ + and uniform_decode \ + and check_ubatch_thresholds( + config=self.vllm_config.parallel_config, + num_tokens=num_tokens, + uniform_decode=uniform_decode, + ) + + # Currently we capture both microbatched and non-microbatched + # graphs when capture_ubatched_graph is True, this is because + # occasionally we will be forced out of microbatching due to other + # DP ranks not microbatching (usually caused by an empty second + # microbatch; once we resolve this, we can remove the + # non-microbatched graph capture). + allow_microbatching_options = [True, False] if \ + capture_ubatched_graph else [False] + for allow_microbatching in allow_microbatching_options: for _ in range( self.compilation_config.cudagraph_num_of_warmups): + # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. + # But be careful, warm up with `NONE`is orthogonal to + # if we want to warm up attention or not. This is + # different from the case where `FULL` implies capture + # attention while `PIECEWISE` implies no attention. force_attention = ( cudagraph_runtime_mode == CUDAGraphMode.FULL) self._dummy_run(num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, force_attention=force_attention, - uniform_decode=True, - allow_microbatching=True, - skip_eplb=True) - - # Graph Capture + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False) self._dummy_run(num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True, - allow_microbatching=True, - skip_eplb=True) - # We skip EPLB here since we don't want to record dummy metrics - for num_tokens in compilation_cases: - for _ in range(self.compilation_config.cudagraph_num_of_warmups): - # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. - # But be careful, warm up with `NONE`is orthogonal to - # if we want to warm up attention or not. This is - # different from the case where `FULL` implies capture - # attention while `PIECEWISE` implies no attention. - force_attention = ( - cudagraph_runtime_mode == CUDAGraphMode.FULL) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, + cudagraph_runtime_mode=cudagraph_runtime_mode, uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, skip_eplb=True, remove_lora=False) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=uniform_decode, - skip_eplb=True, - remove_lora=False) self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: @@ -3500,24 +3496,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_groups: list[AttentionGroup] = [] for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): - attn_metadata_builders = [] - attn_metadata_builders.append(attn_backend.get_builder_cls()( - kv_cache_spec, + attn_group = AttentionGroup.create_with_metadata_builders( + attn_backend, layer_names, + kv_cache_spec, self.vllm_config, self.device, - )) - if self.parallel_config.enable_dbo: - attn_metadata_builders.append( - attn_backend.get_builder_cls()( - kv_cache_spec, - layer_names, - self.vllm_config, - self.device, - )) - attn_group = AttentionGroup(attn_backend, - attn_metadata_builders, - layer_names, kv_cache_spec) + num_metadata_builders=1 + if not self.parallel_config.enable_dbo else 2, + ) + attn_groups.append(attn_group) return attn_groups @@ -3562,6 +3550,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): CUDAGraphMode.FULL_DECODE_ONLY logger.warning(msg) + # check that if we are doing decode full-cudagraphs it is supported + if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and min_cg_support == AttentionCGSupport.NEVER): + msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_builder_name} backend (support: " + f"{min_cg_support})") + if (self.compilation_config.level == CompilationLevel.PIECEWISE and + (self.compilation_config.splitting_ops_contain_attention() + or self.compilation_config.use_inductor_graph_partition)): + msg += "; setting cudagraph_mode=PIECEWISE because "\ + "attention is compiled piecewise" + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.PIECEWISE + else: + msg += "; setting cudagraph_mode=NONE because "\ + "attention is not compiled piecewise" + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.NONE + logger.warning(msg) + # check that if we are doing spec-decode + decode full-cudagraphs it is # supported if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 5012ad0483c84..d636e7af72ea1 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -1,25 +1,30 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import dataclasses import threading +from dataclasses import dataclass from typing import Any, Callable, Optional import torch +import vllm.envs as envs from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.config import CUDAGraphMode, VllmConfig +from vllm.distributed import get_ep_group +from vllm.distributed.device_communicators.pynccl_allocator import ( + set_graph_pool_id) from vllm.forward_context import (create_forward_context, get_forward_context, override_forward_context) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils import has_deep_gemm from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts logger = init_logger(__name__) -@dataclasses.dataclass +@dataclass class UbatchMetadata: context: UBatchContext input_ids: torch.Tensor @@ -29,13 +34,55 @@ class UbatchMetadata: num_tokens: int -@dataclasses.dataclass +@dataclass class CUDAGraphMetaData: cudagraph: torch.cuda.CUDAGraph ubatch_metadata: UbatchMetadata outputs: Optional[Any] = None +class SMControlContextManager: + + def __init__(self, comm_sms: int, set_comm_sms: Callable[[int], None], + set_compute_sms: Callable[[int], None]): + """ + Context manager for controlling SM (Streaming Multiprocessor) + allocation. Upon entering the context, it sets the number of SMs + allocated for communication and computation to comm_sms and + total_sms - comm_sms respectively. Upon exiting, it restores the + allocation to use all available SMs (i.e. total_sms). + + Args: + comm_sms (int): The number of SMs to allocate for communication. + (The remainder will be used for computation.) + set_comm_sms (Callable[[int], None]): + A function that sets the number of SMs for communication. + set_compute_sms (Callable[[int], None]): + A function that sets the number of SMs for computation. + """ + + assert current_platform.is_cuda(), \ + "SM control is currently only supported on CUDA" + + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + total_sms = props.multi_processor_count + + assert comm_sms < total_sms + self.total_sms = total_sms + self.compute_sms = total_sms - comm_sms + self.comm_sms = comm_sms + self.set_comm_sms = set_comm_sms + self.set_compute_sms = set_compute_sms + + def __enter__(self): + self.set_comm_sms(self.comm_sms) + self.set_compute_sms(self.compute_sms) + + def __exit__(self, exc_type, exc_value, traceback): + self.set_comm_sms(self.total_sms) + self.set_compute_sms(self.total_sms) + + class UBatchWrapper: def __init__(self, runnable: Callable, vllm_config: VllmConfig, @@ -56,6 +103,35 @@ class UBatchWrapper: runnable, vllm_config, runtime_mode=runtime_mode) self.graph_pool = current_platform.get_global_graph_pool() + self.sm_control = self._create_sm_control_context(vllm_config) + + @staticmethod + def _create_sm_control_context(vllm_config: VllmConfig): + comm_sms = envs.VLLM_DBO_COMM_SMS + + set_comm_sms = lambda sms: None + if vllm_config.parallel_config.enable_expert_parallel: + # Currently only DeepEP highthroughput supports SM control so this + # only affects that case. + all2all_manager = get_ep_group( + ).device_communicator.all2all_manager + + if all2all_manager.max_sms_used() is not None: + comm_sms = min(comm_sms, all2all_manager.max_sms_used()) + + if comm_sms > 0: + set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms) + + # TODO(lucas): support other kernels besides DeepGEMM + set_compute_sms = lambda sms: None + if has_deep_gemm() and comm_sms > 0: + import deep_gemm as dg + set_compute_sms = lambda sms: dg.set_num_sms(sms) + + return SMControlContextManager(comm_sms=comm_sms, + set_comm_sms=set_comm_sms, + set_compute_sms=set_compute_sms) + def __getattr__(self, key: str): # allow accessing the attributes of the runnable. if hasattr(self.runnable, key): @@ -132,6 +208,10 @@ class UBatchWrapper: cudagraph=torch.cuda.CUDAGraph(), ubatch_metadata=ubatch_metadata, ) + if self.graph_pool is not None: + set_graph_pool_id(self.graph_pool) + else: + set_graph_pool_id(current_platform.graph_pool_handle()) with torch.cuda.graph(cudagraph_metadata.cudagraph, stream=compute_stream, pool=self.graph_pool): @@ -282,8 +362,8 @@ class UBatchWrapper: dp_metadata=dp_metadata, batch_descriptor=batch_descriptor, cudagraph_runtime_mode=CUDAGraphMode.NONE) - - return self._capture_ubatches(ubatch_metadata, self.model) + with self.sm_control: + return self._capture_ubatches(ubatch_metadata, self.model) elif num_tokens in self.cudagraphs: cudagraph_metadata = self.cudagraphs[num_tokens] cudagraph_metadata.cudagraph.replay() @@ -300,4 +380,5 @@ class UBatchWrapper: dp_metadata=dp_metadata, batch_descriptor=batch_descriptor, cudagraph_runtime_mode=CUDAGraphMode.NONE) - return self._run_ubatches(ubatch_metadata, self.model) + with self.sm_control: + return self._run_ubatches(ubatch_metadata, self.model) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index ec58fa43099c3..0b40ac6a7d629 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -386,11 +386,13 @@ class Worker(WorkerBase): f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} " f"GiB for CUDAGraph memory. Replace gpu_memory_utilization " f"config with `--kv-cache-memory=" - f"{kv_cache_memory_bytes_to_requested_limit}` to fit into " - f"requested memory, or `--kv-cache-memory=" - f"{kv_cache_memory_bytes_to_gpu_limit}` to fully " + f"{kv_cache_memory_bytes_to_requested_limit}` " + f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit " + f"into requested memory, or `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_gpu_limit}` " + f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully " f"utilize gpu memory. Current kv cache memory in use is " - f"{int(self.available_kv_cache_memory_bytes)} bytes.") + f"{GiB(self.available_kv_cache_memory_bytes)} GiB.") logger.debug(msg) diff --git a/vllm/v1/worker/ubatch_splitting.py b/vllm/v1/worker/ubatch_splitting.py index 650f0ec5138db..30acb14ff58a7 100644 --- a/vllm/v1/worker/ubatch_splitting.py +++ b/vllm/v1/worker/ubatch_splitting.py @@ -3,9 +3,10 @@ from typing import Optional +import numpy as np import torch -from vllm.config import VllmConfig +from vllm.config import ParallelConfig, VllmConfig from vllm.forward_context import DPMetadata from vllm.logger import init_logger from vllm.utils import round_up @@ -29,6 +30,16 @@ def should_ubatch_with_num_tokens( dp_size, dp_rank) +def check_ubatch_thresholds(config: ParallelConfig, num_tokens: int, + uniform_decode: bool) -> bool: + if not config.enable_dbo: + return False + if uniform_decode: + return num_tokens >= config.dbo_decode_token_threshold + else: + return num_tokens >= config.dbo_prefill_token_threshold + + def get_dp_padding_ubatch( num_tokens_unpadded: int, num_tokens_padded: int, should_attempt_ubatching: bool, @@ -95,9 +106,37 @@ def get_dp_padding_ubatch( dtype=torch.int32) return should_ubatch, num_tokens_after_padding +def create_ubatch_slices(num_scheduled_tokens: np.ndarray, split_point: int) \ + -> UBatchSlices: + # TODO(lucas): Refactor the gpu_model_runner.py so we can pass + # in cu_num_tokens directly (i.e. query_start_loc) + cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32) + np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:]) + + first_ubatch_token_slice = slice(0, split_point) + second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1]) + + # Determine request slices using exclusive stop semantics + # First ubatch includes requests whose tokens overlap [0, split_point) + first_ubatch_req_stop = int( + np.searchsorted(cu_num_tokens, split_point, side="left")) + first_ubatch_req_slice = slice(0, first_ubatch_req_stop) + + # Second ubatch starts at the request that contains the split_point + # or the request starting exactly at split_point (if on boundary) + second_ubatch_req_start = int( + np.searchsorted(cu_num_tokens, split_point, side="right") - 1) + second_ubatch_req_slice = slice(second_ubatch_req_start, + len(cu_num_tokens) - 1) + + return [ + UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice), + UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice) + ] + def ubatch_split( - max_num_scheduled_tokens: int, + num_scheduled_tokens_per_request: np.ndarray, num_tokens_unpadded: int, num_tokens_padded: int, vllm_config: VllmConfig, @@ -122,17 +161,20 @@ def ubatch_split( return (None, None) # Check preconditions for microbatching - should_attempt_ubatching = \ - parallel_config.enable_dbo and \ - num_tokens_unpadded >= \ - parallel_config.dbo_decode_token_threshold \ - and max_num_scheduled_tokens == 1 + should_attempt_ubatching = check_ubatch_thresholds( + parallel_config, + num_tokens_unpadded, + vllm_config, + ) # Don't microbatch unless every other DP worker is also microbatching - num_tokens_after_padding = None - (should_ubatch, num_tokens_after_padding) = get_dp_padding_ubatch( - num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching, - vllm_config) + should_ubatch, num_tokens_after_padding = get_dp_padding_ubatch( + num_tokens_unpadded, + num_tokens_padded, + should_attempt_ubatching, + vllm_config, + ) + if not should_ubatch: return (None, None) @@ -141,15 +183,9 @@ def ubatch_split( # to the second ubatch in pad_out_ubatch_slice after attention # metadata creation assert num_tokens_after_padding is not None - total_num_tokens_per_ubatch = int(num_tokens_after_padding[0].item()) - padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch) - padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch, - num_tokens_unpadded) + token_split_point = int(num_tokens_after_padding[0].item()) - # Note there's an assumption here that there's 1 token per request - ubatch_slices = [ - UBatchSlice(padded_first_ubatch_slice, padded_first_ubatch_slice), - UBatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice) - ] + ubatch_slices = create_ubatch_slices(num_scheduled_tokens_per_request, + token_split_point) return (ubatch_slices, num_tokens_after_padding) diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index 6716d171cc701..33d58aa948434 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -10,6 +10,14 @@ class UBatchSlice: request_slice: slice token_slice: slice + def is_empty(self) -> bool: + return self.request_slice.start == self.request_slice.stop \ + or self.token_slice.start == self.token_slice.stop + + @property + def num_tokens(self) -> int: + return self.token_slice.stop - self.token_slice.start + UBatchSlices: TypeAlias = list[UBatchSlice] diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 9aeaa9909dc81..c26cb07123a53 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -51,8 +51,8 @@ class UBatchContext: self.cpu_wait_event.wait() self.cpu_wait_event.clear() self._restore_context() - # Assume we start on the compute stream - assert current_stream() == self.compute_stream + # Assume we want to start on the compute stream + self.update_stream(self.compute_stream) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -62,17 +62,15 @@ class UBatchContext: self.maybe_run_recv_hook() self.cpu_signal_event.set() self.cpu_wait_event.clear() - self.current_stream = self.compute_stream - torch.cuda.set_stream(self.current_stream) return False def _restore_context(self): forward_context._forward_context = self.forward_context - torch.cuda.set_stream(self.current_stream) def update_stream(self, stream): self.current_stream = stream - torch.cuda.set_stream(self.current_stream) + if current_stream() != self.current_stream: + torch.cuda.set_stream(self.current_stream) def _signal_comm_done(self): self.gpu_comm_done_event.record(self.comm_stream) @@ -99,9 +97,20 @@ class UBatchContext: self.cpu_wait_event.clear() self._restore_context() + def switch_to_comm(self): + self.update_stream(self.comm_stream) + + def switch_to_compute(self): + self.update_stream(self.compute_stream) + def switch_to_comm_sync(self): self._signal_compute_done() self.update_stream(self.comm_stream) + self._wait_compute_done() + + def switch_to_compute_sync(self): + self._signal_comm_done() + self.update_stream(self.compute_stream) self._wait_comm_done() def maybe_run_recv_hook(self): @@ -112,8 +121,7 @@ class UBatchContext: def yield_(self): self.current_stream = current_stream() self._cpu_yield() - if self.current_stream != current_stream(): - self.update_stream(self.current_stream) + self.update_stream(self.current_stream) def yield_and_switch_from_compute_to_comm(self): assert current_stream() == self.compute_stream @@ -153,15 +161,20 @@ def _register_ubatch_function(func): return wrapper +dbo_maybe_run_recv_hook = _register_ubatch_function( + UBatchContext.maybe_run_recv_hook) +dbo_yield = _register_ubatch_function(UBatchContext.yield_) dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function( UBatchContext.yield_and_switch_from_compute_to_comm) dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function( UBatchContext.yield_and_switch_from_comm_to_compute) -dbo_yield = _register_ubatch_function(UBatchContext.yield_) -dbo_maybe_run_recv_hook = _register_ubatch_function( - UBatchContext.maybe_run_recv_hook) +dbo_switch_to_comm = _register_ubatch_function(UBatchContext.switch_to_comm) +dbo_switch_to_compute = _register_ubatch_function( + UBatchContext.switch_to_compute) dbo_switch_to_comm_sync = _register_ubatch_function( UBatchContext.switch_to_comm_sync) +dbo_switch_to_compute_sync = _register_ubatch_function( + UBatchContext.switch_to_compute_sync) def dbo_register_recv_hook(recv_hook): diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index af922f9979d1a..553d33e27203d 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -130,15 +130,32 @@ class MultiModalBudget: @dataclass class AttentionGroup: backend: type[AttentionBackend] + # When ubatching is enabled we will have a metadata builder for each ubatch + # so that if they use internal persistant buffers for cudagraphs, and they + # won't have to worry about conflicting with the other ubatches. metadata_builders: list[AttentionMetadataBuilder] layer_names: list[str] kv_cache_spec: KVCacheSpec + @staticmethod + def create_with_metadata_builders( + backend: type[AttentionBackend], + layer_names: list[str], + kv_cache_spec: KVCacheSpec, + vllm_config: VllmConfig, + device: torch.device, + num_metadata_builders: int = 1, + ) -> 'AttentionGroup': + metadata_builders = [ + backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, + device) + for _ in range(num_metadata_builders) + ] + return AttentionGroup(backend, metadata_builders, layer_names, + kv_cache_spec) + def get_metadata_builder(self, - ubatch_id: Optional[int] = None - ) -> AttentionMetadataBuilder: - if ubatch_id is None: - return self.metadata_builders[0] + ubatch_id: int = 0) -> AttentionMetadataBuilder: assert len(self.metadata_builders) > ubatch_id return self.metadata_builders[ubatch_id]