Merge branch 'main' into woosuk/model-runner-v2

This commit is contained in:
Woosuk Kwon 2025-09-23 21:08:15 +00:00
commit 704def253c
74 changed files with 2380 additions and 673 deletions

View File

@ -164,6 +164,7 @@ steps:
- tests/v1/test_internal_lb_dp.py - tests/v1/test_internal_lb_dp.py
- tests/v1/test_hybrid_lb_dp.py - tests/v1/test_hybrid_lb_dp.py
- tests/v1/engine/test_engine_core_client.py - tests/v1/engine/test_engine_core_client.py
- tests/distributed/test_symm_mem_allreduce.py
commands: commands:
# test with torchrun tp=2 and external_dp=2 # test with torchrun tp=2 and external_dp=2
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
@ -188,6 +189,7 @@ steps:
- pytest -v -s compile/test_basic_correctness.py - pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py - pytest -v -s distributed/test_pynccl.py
- pytest -v -s distributed/test_events.py - pytest -v -s distributed/test_events.py
- pytest -v -s distributed/test_symm_mem_allreduce.py
# TODO: create a dedicated test section for multi-GPU example tests # TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests # when we have multiple distributed example tests
- pushd ../examples/offline_inference - pushd ../examples/offline_inference
@ -329,6 +331,8 @@ steps:
- python3 offline_inference/basic/classify.py - python3 offline_inference/basic/classify.py
- python3 offline_inference/basic/embed.py - python3 offline_inference/basic/embed.py
- python3 offline_inference/basic/score.py - python3 offline_inference/basic/score.py
- python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048
- python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048
- label: Platform Tests (CUDA) # 4min - label: Platform Tests (CUDA) # 4min
timeout_in_minutes: 15 timeout_in_minutes: 15
@ -1037,3 +1041,4 @@ steps:
num_gpus: 2 num_gpus: 2
commands: commands:
- pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py
- pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py

View File

@ -103,10 +103,15 @@ start_server() {
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 \ VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 \
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 & vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
fi fi
local server_pid=$!
# wait for 10 minutes... # wait for 10 minutes...
server_started=0 server_started=0
for i in {1..60}; do for i in {1..60}; do
# Check whether the server process is still alive. `kill -0` delivers no signal;
# it only succeeds if the process exists, and we always have permission to
# signal the server process that we launched ourselves.
kill -0 $server_pid 2> /dev/null || break
RESPONSE=$(curl -s -X GET "http://0.0.0.0:8004/health" -w "%{http_code}" -o /dev/stdout) RESPONSE=$(curl -s -X GET "http://0.0.0.0:8004/health" -w "%{http_code}" -o /dev/stdout)
STATUS_CODE=$(echo "$RESPONSE" | tail -n 1) STATUS_CODE=$(echo "$RESPONSE" | tail -n 1)
if [[ "$STATUS_CODE" -eq 200 ]]; then if [[ "$STATUS_CODE" -eq 200 ]]; then
@ -118,7 +123,7 @@ start_server() {
done done
if (( ! server_started )); then if (( ! server_started )); then
echo "server did not start within 10 minutes. Please check server log at $vllm_log". echo "server did not start within 10 minutes or crashed. Please check server log at $vllm_log".
return 1 return 1
else else
return 0 return 0

View File

@ -17,7 +17,7 @@ from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul, w8a8_triton_block_scaled_mm,
) )
from vllm.utils import FlexibleArgumentParser, cdiv from vllm.utils import FlexibleArgumentParser, cdiv
@ -158,7 +158,7 @@ def bench_fp8(
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16) a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
), ),
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul( "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm(
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128) a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
), ),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm( "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(

View File

@ -7,6 +7,10 @@ Benchmark script for device communicators:
CustomAllreduce (oneshot, twoshot), PyNcclCommunicator, CustomAllreduce (oneshot, twoshot), PyNcclCommunicator,
and SymmMemCommunicator (multimem, two-shot). and SymmMemCommunicator (multimem, two-shot).
For NCCL symmetric memory, you need to set the environment variables
NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1; otherwise NCCL does
not use the fast NVLS implementation for all-reduce.
Usage: Usage:
torchrun --nproc_per_node=<N> benchmark_device_communicators.py [options] torchrun --nproc_per_node=<N> benchmark_device_communicators.py [options]
@ -26,7 +30,13 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator,
register_nccl_symmetric_ops,
)
from vllm.distributed.device_communicators.pynccl_allocator import (
set_graph_pool_id,
)
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@ -98,6 +108,7 @@ class CommunicatorBenchmark:
) )
if not self.pynccl_comm.disabled: if not self.pynccl_comm.disabled:
logger.info("Rank %s: PyNcclCommunicator initialized", self.rank) logger.info("Rank %s: PyNcclCommunicator initialized", self.rank)
register_nccl_symmetric_ops(self.pynccl_comm)
else: else:
logger.info("Rank %s: PyNcclCommunicator disabled", self.rank) logger.info("Rank %s: PyNcclCommunicator disabled", self.rank)
self.pynccl_comm = None self.pynccl_comm = None
@ -194,6 +205,15 @@ class CommunicatorBenchmark:
None, # no env variable needed None, # no env variable needed
) )
) )
communicators.append(
(
"pynccl-symm",
lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
lambda t: True, # Always available if initialized
nullcontext(),
None, # no env variable needed
)
)
if self.symm_mem_comm_multimem is not None: if self.symm_mem_comm_multimem is not None:
comm = self.symm_mem_comm_multimem comm = self.symm_mem_comm_multimem
@ -271,7 +291,9 @@ class CommunicatorBenchmark:
# Capture the graph using context manager # Capture the graph using context manager
with context: with context:
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph): graph_pool = torch.cuda.graph_pool_handle()
set_graph_pool_id(graph_pool)
with torch.cuda.graph(graph, pool=graph_pool):
for _ in range(CUDA_GRAPH_CAPTURE_CYCLES): for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
allreduce_fn(graph_input) allreduce_fn(graph_input)

View File

@ -9,6 +9,9 @@ import torch
from tabulate import tabulate from tabulate import tabulate
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import ( from vllm.utils import (
@ -31,6 +34,8 @@ def run_benchmark(
kv_cache_dtype: str, kv_cache_dtype: str,
kv_cache_layout: str, kv_cache_layout: str,
num_iters: int, num_iters: int,
implementation: str,
benchmark_mode: str,
device: str = "cuda", device: str = "cuda",
) -> float: ) -> float:
"""Return latency (seconds) for given num_tokens.""" """Return latency (seconds) for given num_tokens."""
@ -38,6 +43,14 @@ def run_benchmark(
if kv_cache_dtype == "fp8" and head_size % 16: if kv_cache_dtype == "fp8" and head_size % 16:
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.") raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
if implementation not in ("cuda", "triton"):
raise ValueError(
f"Unsupported implementation: {implementation}. "
"Only 'cuda' and 'triton' are supported."
)
if implementation == "triton" and kv_cache_layout == "HND":
return float("nan") # Triton does not support HND layout yet.
current_platform.seed_everything(42) current_platform.seed_everything(42)
torch.set_default_device(device) torch.set_default_device(device)
@ -65,27 +78,49 @@ def run_benchmark(
cache_layout=kv_cache_layout, cache_layout=kv_cache_layout,
) )
key_cache, value_cache = key_caches[0], value_caches[0] key_cache, value_cache = key_caches[0], value_caches[0]
# to free unused memory
del key_caches, value_caches
# compute per-kernel scaling factors for fp8 conversion (if used). # compute per-kernel scaling factors for fp8 conversion (if used).
k_scale = (key.amax() / 64.0).to(torch.float32) k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32) v_scale = (value.amax() / 64.0).to(torch.float32)
if implementation == "cuda":
function_under_test = lambda: ops.reshape_and_cache_flash(
key, # noqa: F821
value, # noqa: F821
key_cache, # noqa: F821
value_cache, # noqa: F821
slot_mapping, # noqa: F821
kv_cache_dtype,
k_scale,
v_scale,
)
else:
function_under_test = lambda: triton_reshape_and_cache_flash(
key, # noqa: F821
value, # noqa: F821
key_cache, # noqa: F821
value_cache, # noqa: F821
slot_mapping, # noqa: F821
kv_cache_dtype,
k_scale,
v_scale,
)
if benchmark_mode == "cudagraph":
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
function_under_test()
torch.cuda.synchronize()
function_under_test = lambda: g.replay()
def run_cuda_benchmark(n_iters: int) -> float: def run_cuda_benchmark(n_iters: int) -> float:
nonlocal key, value, key_cache, value_cache, slot_mapping nonlocal key, value, key_cache, value_cache, slot_mapping
torch.cuda.synchronize() torch.cuda.synchronize()
start = time.perf_counter() start = time.perf_counter()
for _ in range(n_iters): for _ in range(n_iters):
ops.reshape_and_cache_flash( function_under_test()
key, torch.cuda.synchronize()
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)
torch.cuda.synchronize()
end = time.perf_counter() end = time.perf_counter()
return (end - start) / n_iters return (end - start) / n_iters
@ -116,10 +151,16 @@ def main(args):
kv_cache_dtype=args.kv_cache_dtype, kv_cache_dtype=args.kv_cache_dtype,
kv_cache_layout=layout, kv_cache_layout=layout,
num_iters=args.iters, num_iters=args.iters,
implementation=args.implementation,
benchmark_mode=args.mode,
device="cuda", device="cuda",
) )
rows.append([n_tok, layout, f"{lat * 1e6:.3f}"]) rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])
print(
f"Benchmark results for implementation {args.implementation}"
f" (measuring with {args.mode}):"
)
print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"])) print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))
@ -151,6 +192,21 @@ if __name__ == "__main__":
) )
parser.add_argument("--iters", type=int, default=100) parser.add_argument("--iters", type=int, default=100)
parser.add_argument(
"--implementation",
type=str,
choices=["cuda", "triton"],
default="cuda",
)
parser.add_argument(
"--mode",
type=str,
choices=["cudagraph", "no_graph"],
default="cudagraph",
)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
get_col_major_tma_aligned_tensor, get_col_major_tma_aligned_tensor,
per_token_group_quant_fp8, per_token_group_quant_fp8,
w8a8_block_fp8_matmul, w8a8_triton_block_scaled_mm,
) )
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8 from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8
@ -59,7 +59,7 @@ def benchmark_shape(m: int,
# === vLLM Triton Implementation === # === vLLM Triton Implementation ===
def vllm_triton_gemm(): def vllm_triton_gemm():
return w8a8_block_fp8_matmul(A_vllm, return w8a8_triton_block_scaled_mm(A_vllm,
B_vllm, B_vllm,
A_scale_vllm, A_scale_vllm,
B_scale_vllm, B_scale_vllm,

View File

@ -418,6 +418,15 @@ __device__ inline T neg_inf() {
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity()); return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
} }
// Returns true when `val` is finite.
//
// The implementation is selected at compile time from the nvcc version
// (major * 10000 + minor * 100, so 120800 corresponds to CUDA 12.8):
//   - CUDA 12.8 and newer: call cuda::std::isfinite on T directly.
//   - older toolkits: convert `val` to float with cuda_cast and test the
//     float with isfinite.
// NOTE(review): presumably the fallback exists because older toolkits lack
// cuda::std::isfinite overloads for the reduced-precision types used as T
// (e.g. half / bfloat16) — confirm against the supported CUDA versions.
template <typename T>
__device__ inline bool is_finite(const T val) {
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
return cuda::std::isfinite(val);
#else
// Widen to float first so the plain isfinite overload can be used.
return isfinite(cuda_cast<float, T>(val));
#endif
}
template <typename T> template <typename T>
__device__ void topk_with_k2(T* output, T const* input, __device__ void topk_with_k2(T* output, T const* input,
cg::thread_block_tile<32> const& tile, cg::thread_block_tile<32> const& tile,
@ -533,7 +542,7 @@ __global__ void group_idx_and_topk_idx_kernel(
// calculate group_idx // calculate group_idx
int32_t target_num_min = WARP_SIZE - n_group + topk_group; int32_t target_num_min = WARP_SIZE - n_group + topk_group;
// The check is necessary to avoid abnormal input // The check is necessary to avoid abnormal input
if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) { if (lane_id < n_group && is_finite(group_scores[lane_id])) {
value = group_scores[lane_id]; value = group_scores[lane_id];
} }
@ -568,11 +577,10 @@ __global__ void group_idx_and_topk_idx_kernel(
int32_t offset = i_group * num_experts_per_group; int32_t offset = i_group * num_experts_per_group;
for (int32_t i = lane_id; i < align_num_experts_per_group; for (int32_t i = lane_id; i < align_num_experts_per_group;
i += WARP_SIZE) { i += WARP_SIZE) {
T candidates = T candidates = (i < num_experts_per_group) &&
(i < num_experts_per_group) && is_finite(scores_with_bias[offset + i])
cuda::std::isfinite(scores_with_bias[offset + i]) ? scores_with_bias[offset + i]
? scores_with_bias[offset + i] : neg_inf<T>();
: neg_inf<T>();
queue.add(candidates, offset + i); queue.add(candidates, offset + i);
} }
if (group_scores[i_group] == topk_group_value) { if (group_scores[i_group] == topk_group_value) {

View File

@ -12,8 +12,8 @@
#include "../vectorization_utils.cuh" #include "../vectorization_utils.cuh"
#include "../../dispatch_utils.h" #include "../../dispatch_utils.h"
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) { __device__ __forceinline__ float GroupReduceMax(float val) {
unsigned mask = 0xffff; unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
val = fmaxf(val, __shfl_xor_sync(mask, val, 8)); val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
val = fmaxf(val, __shfl_xor_sync(mask, val, 4)); val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
@ -86,7 +86,7 @@ __global__ void per_token_group_quant_8bit_kernel(
threads_per_group, // stride in group threads_per_group, // stride in group
scalar_op_cache); // scalar handler scalar_op_cache); // scalar handler
local_absmax = GroupReduceMax(local_absmax, lane_id); local_absmax = GroupReduceMax(local_absmax);
float y_s = local_absmax / max_8bit; float y_s = local_absmax / max_8bit;
if constexpr (SCALE_UE8M0) { if constexpr (SCALE_UE8M0) {

View File

@ -5,11 +5,14 @@
torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
const int64_t rows_per_block); const int64_t rows_per_block);
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const c10::optional<at::Tensor>& in_bias,
const int64_t CuCount); const int64_t CuCount);
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount); const c10::optional<at::Tensor>& in_bias, at::Tensor& out_c,
const at::Tensor& scale_a, const at::Tensor& scale_b,
const int64_t CuCount);
void paged_attention( void paged_attention(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,

View File

@ -292,8 +292,9 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N> int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS) __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, wvSplitK_hf_sml_(const int K, const int M, const int Bx, const int By,
const scalar_t* __restrict__ A, scalar_t* C, const scalar_t* B, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2; constexpr int max_lds_len = LDS_SIZE / 2;
#if defined(__HIP__MI3XX__) #if defined(__HIP__MI3XX__)
@ -484,7 +485,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.x == 63) { if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) { for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]); if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]); C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
} }
} }
@ -529,7 +537,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.x == 63) { if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) { for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); if (BIAS)
sum4[n][i][0] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
} }
} }
@ -541,8 +551,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support #else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N> int UNRL, int N>
__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, __global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
const scalar_t* __restrict__ A, scalar_t* C, const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE UNREACHABLE_CODE
} }
@ -553,8 +565,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N> int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS) __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_(const int K, const int M, const scalar_t* B, wvSplitK_hf_(const int K, const int M, const int Bx, const int By,
const scalar_t* __restrict__ A, scalar_t* C, const scalar_t* B, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2; constexpr int max_lds_len = LDS_SIZE / 2;
#if defined(__HIP__MI3XX__) #if defined(__HIP__MI3XX__)
@ -772,8 +785,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.x == 63) { if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) { for (int i = 0; i < YTILE; i++) {
if (commitColumn[i]) if (commitColumn[i]) {
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]); C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
} }
} }
} }
@ -818,8 +840,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.x == 63) { if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) { for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); if (commitColumn[i]) {
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); if (BIAS)
sum4[n][i][0] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
}
} }
} }
} }
@ -842,8 +868,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support #else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N> int UNRL, int N>
__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B, __global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
const scalar_t* __restrict__ A, scalar_t* C, const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE UNREACHABLE_CODE
} }
@ -854,8 +882,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N> int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS) __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, wvSplitK_hf_big_(const int K, const int M, const int Bx, const int By,
const scalar_t* __restrict__ A, scalar_t* C, const scalar_t* B, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2; constexpr int max_lds_len = LDS_SIZE / 2;
#if defined(__HIP__MI3XX__) #if defined(__HIP__MI3XX__)
@ -1124,8 +1153,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.x == 63) { if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) { for (int i = 0; i < YTILE; i++) {
if (commitColumn[i]) if (commitColumn[i]) {
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]); C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
} }
} }
} }
@ -1166,8 +1204,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.x == 63) { if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) { for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); if (commitColumn[i]) {
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); if (BIAS)
sum4[n][i][0] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
}
} }
} }
} }
@ -1190,8 +1232,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support #else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK, template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N> int UNRL, int N>
__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
const scalar_t* __restrict__ A, scalar_t* C, const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE UNREACHABLE_CODE
} }
@ -1226,11 +1270,20 @@ int mindiv(int N, int div1, int div2) {
return rtn; return rtn;
} }
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const c10::optional<at::Tensor>& in_bias,
const int64_t CuCount) { const int64_t CuCount) {
auto M_in = in_a.size(0); auto M_in = in_a.size(0);
auto K_in = in_a.size(1); auto K_in = in_a.size(1);
auto N_in = in_b.size(0); auto N_in = in_b.size(0);
auto Bx_in =
(in_bias.has_value() && in_bias->numel() > 0)
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
: 1;
auto By_in = (in_bias.has_value() && in_bias->numel() > 0 &&
in_bias->sizes().size() == 2)
? in_bias->size(0)
: 1;
TORCH_CHECK(in_a.dtype() == in_b.dtype()); TORCH_CHECK(in_a.dtype() == in_b.dtype());
TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0"); TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0");
@ -1254,18 +1307,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
CuCount); \ biasf4, c, __wvPrGrp, CuCount); \
} else if (K_in * N_in <= max_lds_len * 1.2) { \ } else if (K_in * N_in <= max_lds_len * 1.2) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
CuCount); \ biasf4, c, __wvPrGrp, CuCount); \
} else { \ } else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \ <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
CuCount); \ biasf4, c, __wvPrGrp, CuCount); \
} \ } \
} }
@ -1273,6 +1326,10 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
using fptype = typename scalar<scalar_t>::type; using fptype = typename scalar<scalar_t>::type;
fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr()); fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr());
const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr()); const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
const fptype* biasf4 =
(in_bias.has_value() && in_bias->numel() > 0)
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
: nullptr;
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr()); fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
switch (N_in) { switch (N_in) {
case 1: case 1:
@ -1300,8 +1357,9 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp, template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N> int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS) __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const fp8_t* B, wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int Bx,
const fp8_t* __restrict__ A, scalar_t* C, const int By, const fp8_t* B, const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A, const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp, const float* __restrict__ s_B, const int _WvPrGrp,
const int CuCount) { const int CuCount) {
@ -1453,7 +1511,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0] * sA * sB); if (y + m >= M) break; // To avoid mem access fault.
sum[n][y][0] *= sA * sB;
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][y][0] +=
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
}
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]); // * sA * sB);
} }
} }
} }
@ -1465,7 +1533,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp, template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N> int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
const fp8_t* B, const fp8_t* __restrict__ A, const int Bx, const int By, const fp8_t* B,
const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS,
scalar_t* C, const float* __restrict__ s_A, scalar_t* C, const float* __restrict__ s_A,
const float* __restrict__ s_B, const float* __restrict__ s_B,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
@ -1477,8 +1547,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp, template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N> int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS) __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B, wvSplitKQ_hf_(const int K, const int Kp, const int M, const int Bx,
const fp8_t* __restrict__ A, scalar_t* C, const int By, const fp8_t* B, const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A, const float* __restrict__ s_B, const float* __restrict__ s_A, const float* __restrict__ s_B,
const int _WvPrGrp, const int CuCount) { const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE; constexpr int max_lds_len = LDS_SIZE;
@ -1626,7 +1697,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) { for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault. if (y + m >= M) break; // To avoid mem access fault.
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0] * sA * sB); sum[n][y][0] *= sA * sB;
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][y][0] +=
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
}
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
} }
} }
} }
@ -1638,16 +1718,19 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp, template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N> int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
const fp8_t* B, const fp8_t* __restrict__ A, const int Bx, const int By, const fp8_t* B,
scalar_t* C, const float* __restrict__ s_A, const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp, const float* __restrict__ s_B, const int _WvPrGrp,
const int CuCount) { const int CuCount) {
UNREACHABLE_CODE UNREACHABLE_CODE
} }
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support #endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
at::Tensor& scale_a, at::Tensor& scale_b, const c10::optional<at::Tensor>& in_bias, at::Tensor& out_c,
const at::Tensor& scale_a, const at::Tensor& scale_b,
const int64_t CuCount) { const int64_t CuCount) {
static c10::ScalarType kFp8Type = is_fp8_ocp() static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn ? c10::ScalarType::Float8_e4m3fn
@ -1656,6 +1739,15 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
auto K_in = in_a.size(1); auto K_in = in_a.size(1);
auto N_in = in_b.size(0); auto N_in = in_b.size(0);
auto Kp_in = in_a.stride(0); auto Kp_in = in_a.stride(0);
auto Bx_in =
(in_bias.has_value() && in_bias->numel() > 0)
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
: 1;
auto By_in = (in_bias.has_value() && in_bias->numel() > 0 &&
in_bias->sizes().size() == 2)
? in_bias->size(0)
: 1;
TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0"); TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0");
TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type); TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type);
TORCH_CHECK(out_c.dtype() == torch::kFloat16 || TORCH_CHECK(out_c.dtype() == torch::kFloat16 ||
@ -1673,13 +1765,15 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \ wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ <<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \ b_ptr, bias_ptr, c_ptr, s_a, s_b, \
__wvPrGrp, CuCount); \
} else { \ } else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \ wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \ <<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \ b_ptr, bias_ptr, c_ptr, s_a, s_b, \
__wvPrGrp, CuCount); \
} \ } \
} }
@ -1691,6 +1785,9 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] { VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] {
auto a_ptr = in_a.data_ptr<fp8_t>(); auto a_ptr = in_a.data_ptr<fp8_t>();
auto b_ptr = in_b.data_ptr<fp8_t>(); auto b_ptr = in_b.data_ptr<fp8_t>();
auto bias_ptr = (in_bias.has_value() && in_bias->numel() > 0)
? reinterpret_cast<fptype*>(in_bias->data_ptr())
: nullptr;
switch (N_in) { switch (N_in) {
case 1: case 1:
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1) WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1)

View File

@ -22,13 +22,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
// Custom gemm op for skinny matrix-matrix multiplication // Custom gemm op for skinny matrix-matrix multiplication
rocm_ops.def( rocm_ops.def(
"wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> " "wvSplitK(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount) -> "
"Tensor"); "Tensor");
rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK); rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);
// wvSplitK for fp8 // wvSplitK for fp8
rocm_ops.def( rocm_ops.def(
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, " "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, "
"Tensor scale_a, "
" Tensor scale_b, int CuCount) -> ()"); " Tensor scale_b, int CuCount) -> ()");
rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ); rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);

View File

@ -49,6 +49,7 @@ def get_custom_mm_prompts(num_prompts):
def parse_args(): def parse_args():
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
add_dataset_parser(parser) add_dataset_parser(parser)
parser.add_argument("--test", action="store_true")
parser.add_argument( parser.add_argument(
"--method", "--method",
type=str, type=str,
@ -60,6 +61,7 @@ def parse_args():
parser.add_argument("--tp", type=int, default=1) parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--enforce-eager", action="store_true") parser.add_argument("--enforce-eager", action="store_true")
parser.add_argument("--enable-chunked-prefill", action="store_true") parser.add_argument("--enable-chunked-prefill", action="store_true")
parser.add_argument("--max-model-len", type=int, default=16384)
parser.add_argument("--temp", type=float, default=0) parser.add_argument("--temp", type=float, default=0)
parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=-1) parser.add_argument("--top-k", type=int, default=-1)
@ -71,8 +73,7 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def main(): def main(args):
args = parse_args()
args.endpoint_type = "openai-chat" args.endpoint_type = "openai-chat"
model_dir = args.model_dir model_dir = args.model_dir
@ -134,7 +135,7 @@ def main():
gpu_memory_utilization=0.8, gpu_memory_utilization=0.8,
speculative_config=speculative_config, speculative_config=speculative_config,
disable_log_stats=False, disable_log_stats=False,
max_model_len=16384, max_model_len=args.max_model_len,
limit_mm_per_prompt={"image": 5}, limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True, disable_chunked_mm_input=True,
) )
@ -198,6 +199,39 @@ def main():
acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0
print(f"acceptance at token {i}: {acceptance_rate:.2f}") print(f"acceptance at token {i}: {acceptance_rate:.2f}")
return acceptance_length
if __name__ == "__main__": if __name__ == "__main__":
main() args = parse_args()
acceptance_length = main(args)
if args.test:
# takes ~30s to run on 1xH100
assert args.method in ["eagle", "eagle3"]
assert args.tp == 1
assert args.num_spec_tokens == 3
assert args.dataset_name == "hf"
assert args.dataset_path == "philschmid/mt-bench"
assert args.num_prompts == 80
assert args.temp == 0
assert args.top_p == 1.0
assert args.top_k == -1
assert args.enable_chunked_prefill
# check acceptance length is within 2% of expected value
rtol = 0.02
expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811
assert (
acceptance_length <= (1 + rtol) * expected_acceptance_length
and acceptance_length >= (1 - rtol) * expected_acceptance_length
), (
f"acceptance_length {acceptance_length} is not "
f"within {rtol * 100}% of {expected_acceptance_length}"
)
print(
f"Test passed! Expected AL: "
f"{expected_acceptance_length}, got {acceptance_length}"
)

View File

@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import typing
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import vllm.envs as envs
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator)
from vllm.distributed.device_communicators.pynccl import (
register_nccl_symmetric_ops)
from vllm.distributed.device_communicators.pynccl_allocator import (
get_nccl_mem_pool, is_symmetric_memory_enabled)
from vllm.distributed.parallel_state import (get_tp_group,
init_distributed_environment,
initialize_model_parallel)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
torch.manual_seed(42)
random.seed(44)
test_size_elements = 4 * 1024 * 1024
def nccl_symm_mem_allreduce_worker(local_rank: int,
                                   world_size: int,
                                   skip_queue=None):
    """Per-rank worker: compare the NCCL symmetric-memory all-reduce op
    against ``torch.distributed.all_reduce``.

    Args:
        local_rank: Index of this process; also used as the CUDA device
            index and as RANK / LOCAL_RANK.
        world_size: Number of spawned processes; used as the tensor-parallel
            size.
        skip_queue: Optional queue used to report a skip reason back to the
            parent process. This worker runs in a child started by
            ``mp.spawn``, where ``pytest.skip`` cannot be used to skip the
            test, so the reason is put on the queue and the worker returns.
            When ``None`` the worker falls back to calling ``pytest.skip``
            directly (previous behavior).
    """

    def _skip(reason: str) -> None:
        # Relay the reason to the parent when a queue is available.
        if skip_queue is None:
            pytest.skip(reason)
        skip_queue.put(reason)

    monkeypatch = pytest.MonkeyPatch()
    with monkeypatch.context() as m:
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        dtype = torch.bfloat16
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        torch.set_default_device(device)
        torch.set_default_dtype(dtype)
        update_environment_variables({
            "RANK": str(local_rank),
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        })

        init_distributed_environment()
        initialize_model_parallel(tensor_model_parallel_size=world_size)

        cuda_communicator = typing.cast(CudaCommunicator,
                                        get_tp_group().device_communicator)
        pynccl_comm = cuda_communicator.pynccl_comm
        if get_nccl_mem_pool() is None:
            _skip("NCCL allocator compilation failed "
                  "(probably missing NCCL headers).")
            return
        if not is_symmetric_memory_enabled():
            _skip("NCCL symmetric memory allreduce is disabled.")
            return

        register_nccl_symmetric_ops(pynccl_comm)
        # Named `inp` rather than `input` to avoid shadowing the builtin.
        inp = torch.randint(1,
                            23, (test_size_elements, ),
                            dtype=dtype,
                            device=device)
        # Keep an untouched copy to feed the reference all-reduce.
        inp_clone = inp.clone()
        output = torch.ops.vllm.all_reduce_symmetric_with_copy(inp)
        assert output is not None

        group = get_tp_group().device_group
        dist.all_reduce(inp_clone, group=group)
        torch.testing.assert_close(output, inp_clone, atol=2.5, rtol=0.1)


@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
)
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
    """Spawn ``world_size`` workers and run the symmetric-memory all-reduce
    check in each; skip when a worker reports the feature is unavailable."""
    import queue  # local import: only needed for queue.Empty below

    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    # Enable SymmMemCommunicator
    monkeypatch.setenv("VLLM_USE_NCCL_SYMM_MEM", "1")
    monkeypatch.setenv("NCCL_NVLS_ENABLE", "1")
    monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1")

    # Workers cannot skip the test themselves; they report a reason here.
    skip_queue = mp.get_context("spawn").Queue()
    mp.spawn(nccl_symm_mem_allreduce_worker,
             args=(world_size, skip_queue),
             nprocs=world_size)
    try:
        skip_reason = skip_queue.get(timeout=1)
    except queue.Empty:
        skip_reason = None
    finally:
        cleanup_dist_env_and_memory()
    if skip_reason is not None:
        pytest.skip(skip_reason)

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import queue
import random import random
import typing import typing
@ -10,26 +11,31 @@ import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.cuda_communicator import ( from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator) CudaCommunicator)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, from vllm.distributed.parallel_state import (get_tp_group,
get_tp_group,
init_distributed_environment, init_distributed_environment,
initialize_model_parallel) initialize_model_parallel)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import update_environment_variables from vllm.utils import update_environment_variables
torch.manual_seed(42) torch.manual_seed(42)
random.seed(44) random.seed(44)
test_size_elements = 4 * 1024 * 1024 test_size_elements = 1024 * 1024
def symm_mem_allreduce_worker(local_rank: int, world_size: int): def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
monkeypatch = pytest.MonkeyPatch() monkeypatch = pytest.MonkeyPatch()
with monkeypatch.context() as m: config = VllmConfig(parallel_config=ParallelConfig(
tensor_parallel_size=world_size))
with monkeypatch.context() as m, set_current_vllm_config(config):
m.delenv("CUDA_VISIBLE_DEVICES", raising=False) m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
dtype = torch.bfloat16 dtype = torch.bfloat16
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
@ -51,22 +57,26 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int):
get_tp_group().device_communicator) get_tp_group().device_communicator)
symm_mem_comm = cuda_communicator.symm_mem_comm symm_mem_comm = cuda_communicator.symm_mem_comm
if symm_mem_comm is None or symm_mem_comm.disabled: if symm_mem_comm is None or symm_mem_comm.disabled:
pytest.skip("SymmMemCommunicator is not available or disabled.") # can't use skip under multiprocessing
q.put("SymmMemCommunicator is not available or disabled.")
return
inp_direct_symm_mem = torch.randint(1, inp_direct_symm_mem = torch.randint(1,
23, (test_size_elements, ), 23, (test_size_elements, ),
dtype=dtype, dtype=dtype,
device=device) device=device)
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem): if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
pytest.skip( # can't use skip under multiprocessing
q.put(
"SymmMemCommunicator isn't used for this world and input size." "SymmMemCommunicator isn't used for this world and input size."
) )
return
original_inp_direct_symm_mem = inp_direct_symm_mem.clone() original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem) out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
assert out_direct_symm_mem is not None assert out_direct_symm_mem is not None
group = get_tensor_model_parallel_group().device_group group = get_tp_group().device_group
dist.all_reduce(original_inp_direct_symm_mem, group=group) dist.all_reduce(original_inp_direct_symm_mem, group=group)
torch.testing.assert_close(out_direct_symm_mem, torch.testing.assert_close(out_direct_symm_mem,
original_inp_direct_symm_mem, original_inp_direct_symm_mem,
@ -100,9 +110,34 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
world_size = tp_size * pipeline_parallel_size world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count(): if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.") pytest.skip("Not enough GPUs to run the test.")
q = mp.get_context('spawn').Queue()
mp.spawn(symm_mem_allreduce_worker,
args=(world_size, q),
nprocs=world_size)
try:
val = q.get(timeout=1)
except queue.Empty:
val = None
finally:
cleanup_dist_env_and_memory()
if val is not None:
pytest.skip(val)
# Enable SymmMemCommunicator
monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")
mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size) @pytest.mark.skipif(
cleanup_dist_env_and_memory() not current_platform.is_cuda(),
reason="SymmMemAllreduce is only available for CUDA platforms.")
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
world_size = 4
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
# Verify that the DataParallel runs without error
engine_args = EngineArgs(model="distilbert/distilgpt2",
enforce_eager=True,
enable_prefix_caching=True,
data_parallel_size=2,
tensor_parallel_size=2,
data_parallel_backend="mp")
LLMEngine.from_engine_args(engine_args)

View File

@ -39,6 +39,8 @@ CUDA_DEVICES = [
# We assume fp8 is always enabled for testing. # We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE = ["auto", "fp8"] KV_CACHE_DTYPE = ["auto", "fp8"]
RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS) @pytest.mark.parametrize("num_layers", NUM_LAYERS)
@ -223,6 +225,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS) @pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
@torch.inference_mode() @torch.inference_mode()
def test_reshape_and_cache_flash( def test_reshape_and_cache_flash(
kv_cache_factory_flashinfer, kv_cache_factory_flashinfer,
@ -236,9 +239,13 @@ def test_reshape_and_cache_flash(
device: str, device: str,
kv_cache_dtype: str, kv_cache_dtype: str,
kv_cache_layout: str, kv_cache_layout: str,
implementation: str,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
assert implementation in ["cuda", "triton"]
if implementation == "triton" and kv_cache_layout == "HND":
pytest.skip("Triton implementation only supports NHD layout.")
# fp8 conversion requires contiguous memory buffer. Reduce the number of # fp8 conversion requires contiguous memory buffer. Reduce the number of
# blocks and tokens to consume less memory. # blocks and tokens to consume less memory.
@ -298,12 +305,20 @@ def test_reshape_and_cache_flash(
cloned_key_cache = key_cache_compact.clone() cloned_key_cache = key_cache_compact.clone()
cloned_value_cache = value_cache_compact.clone() cloned_value_cache = value_cache_compact.clone()
# Call the reshape_and_cache kernel. # Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, if implementation == "cuda":
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
k_scale, v_scale), (key, value, key_cache, value_cache, slot_mapping,
cond=(head_size == HEAD_SIZES[0])) kv_cache_dtype, k_scale, v_scale),
ops.reshape_and_cache_flash(key, value, key_cache, value_cache, cond=(head_size == HEAD_SIZES[0]))
slot_mapping, kv_cache_dtype, k_scale, v_scale) ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale,
v_scale)
elif implementation == "triton":
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash)
triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale,
v_scale)
key_cache_compact = permute_and_compact(key_cache) key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache) value_cache_compact = permute_and_compact(value_cache)

View File

@ -12,7 +12,7 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, get_col_major_tma_aligned_tensor, cutlass_scaled_mm, get_col_major_tma_aligned_tensor,
per_token_group_quant_fp8, w8a8_block_fp8_matmul) per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8 from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8
@ -90,7 +90,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype) out_dtype)
out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
rel_diff = (torch.mean( rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /

View File

@ -20,9 +20,11 @@ from vllm.platforms import current_platform
(8, 513, 64), # Non-divisible (native only) (8, 513, 64), # Non-divisible (native only)
]) ])
@pytest.mark.parametrize("seed", [42]) @pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode() @torch.inference_mode()
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_size: int, seed: int) -> None: group_size: int, seed: int,
use_ue8m0: bool) -> None:
"""Test QuantFP8 group quantization with various configurations. """Test QuantFP8 group quantization with various configurations.
Tests both CUDA and native implementations, column-major scales, Tests both CUDA and native implementations, column-major scales,
@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_shape = GroupShape(1, group_size) group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False, quant_op = QuantFP8(static=False,
group_shape=group_shape, group_shape=group_shape,
column_major_scales=False) column_major_scales=False,
use_ue8m0=use_ue8m0)
# 1. Test native implementation (always available) # 1. Test native implementation (always available)
x_quant_native, scales_native = quant_op.forward_native(x.clone()) x_quant_native, scales_native = quant_op.forward_native(x.clone())
@ -48,9 +51,15 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
# 2. Test column-major scales configuration # 2. Test column-major scales configuration
quant_op_col = QuantFP8(static=False, quant_op_col = QuantFP8(static=False,
group_shape=group_shape, group_shape=group_shape,
column_major_scales=True) column_major_scales=True,
use_ue8m0=use_ue8m0)
_, scales_col = quant_op_col.forward_native(x.clone()) _, scales_col = quant_op_col.forward_native(x.clone())
assert scales_col.shape == (expected_num_groups, batch_size) assert scales_col.shape == (batch_size, expected_num_groups)
assert scales_col.stride(0) == 1
assert scales_col.stride(1) == batch_size
# Test column-major scales consistency
assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
# 3. Test CUDA implementation (only for divisible dimensions) # 3. Test CUDA implementation (only for divisible dimensions)
if is_divisible: if is_divisible:
@ -68,8 +77,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
@pytest.mark.parametrize("seed", [42]) @pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode() @torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int) -> None: def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
group_size = 64 group_size = 64
@ -82,7 +92,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
group_shape = GroupShape(1, group_size) group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False, quant_op = QuantFP8(static=False,
group_shape=group_shape, group_shape=group_shape,
column_major_scales=False) column_major_scales=False,
use_ue8m0=use_ue8m0)
x_quant, scales = quant_op.forward_native(x_3d.clone()) x_quant, scales = quant_op.forward_native(x_3d.clone())
assert x_quant.shape == x_3d.shape assert x_quant.shape == x_3d.shape
@ -91,7 +102,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
# Test column_major_scales with multi-dim # Test column_major_scales with multi-dim
quant_op_col = QuantFP8(static=False, quant_op_col = QuantFP8(static=False,
group_shape=group_shape, group_shape=group_shape,
column_major_scales=True) column_major_scales=True,
use_ue8m0=use_ue8m0)
_, scales_col = quant_op_col.forward_native(x_3d.clone()) _, scales_col = quant_op_col.forward_native(x_3d.clone())
assert scales_col.shape == (batch1, hidden_dim // group_size, batch2) assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)

View File

@ -1,12 +1,12 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import pytest import pytest
import torch import torch
import vllm._custom_ops as ops import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
rocm_per_tensor_w8a8_scaled_mm_impl)
from vllm.platforms import current_platform from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float16] DTYPES = [torch.bfloat16, torch.float16]
@ -49,6 +49,7 @@ NKM_FACTORS_WVSPLITK_FP8 = [
(2, 512, 512), (2, 512, 512),
(3, 2048, 2048), (3, 2048, 2048),
(4, 4096, 4096), (4, 4096, 4096),
(4, 16400, 2048),
# Extended FP8 dimensions not covered by WVSPLITK # Extended FP8 dimensions not covered by WVSPLITK
(1, 14336, 1024), (1, 14336, 1024),
(2, 24576, 2048), (2, 24576, 2048),
@ -67,6 +68,9 @@ SEEDS = [0]
@torch.inference_mode() @torch.inference_mode()
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
torch.manual_seed(seed) torch.manual_seed(seed)
#TODO: Zero-centering the inputs causes errors for LLMM1!
# Without that the numbers quickly saturate, and may
# be giving false matches.
A = torch.rand(n, k, dtype=dtype, device="cuda") A = torch.rand(n, k, dtype=dtype, device="cuda")
B = torch.rand(m, k, dtype=dtype, device="cuda") B = torch.rand(m, k, dtype=dtype, device="cuda")
@ -85,11 +89,51 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed) torch.manual_seed(seed)
cu_count = current_platform.get_cu_count() cu_count = current_platform.get_cu_count()
A = torch.rand(n, k, dtype=dtype, device="cuda") A = torch.rand(n, k, dtype=dtype, device="cuda") - .5
B = torch.rand(m, k, dtype=dtype, device="cuda") B = torch.rand(m, k, dtype=dtype, device="cuda") - .5
ref_out = torch.matmul(A, B.t()) ref_out = torch.nn.functional.linear(A, B)
out = ops.wvSplitK(B, A, cu_count) out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
assert torch.allclose(out, ref_out, rtol=0.01)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
    """Check ``ops.wvSplitK`` with a 1-D bias of length ``m`` against
    ``torch.nn.functional.linear`` on zero-centered random inputs."""
    torch.manual_seed(seed)
    num_cus = current_platform.get_cu_count()

    # Scale the operands by sqrt(2 / k) so the matmul result does not dwarf
    # the bias term (normalize to avoid large output-bias deltas).
    scale = math.sqrt(2 / k)
    rand_kwargs = dict(dtype=dtype, device="cuda")
    activations = (torch.rand(n, k, **rand_kwargs) - .5) * scale
    weights = (torch.rand(m, k, **rand_kwargs) - .5) * scale
    bias = torch.rand(m, **rand_kwargs) - .5

    expected = torch.nn.functional.linear(activations, weights, bias)
    # The kernel is called with the weight first and a 2-D view of the input.
    flat_activations = activations.view(-1, activations.size(-1))
    actual = ops.wvSplitK(weights, flat_activations, num_cus, bias)

    assert torch.allclose(actual, expected, rtol=0.01)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - .5
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
assert torch.allclose(out, ref_out, rtol=0.01) assert torch.allclose(out, ref_out, rtol=0.01)
@ -103,8 +147,8 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed) torch.manual_seed(seed)
A = torch.rand(n, k, device="cuda") A = torch.rand(n, k, device="cuda") - 0.5
B = torch.rand(m, k, device="cuda") B = torch.rand(m, k, device="cuda") - 0.5
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
@ -123,27 +167,27 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8) @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif( @pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()), not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8") reason="only test for rocm fp8")
def test_rocm_per_tensor_w8a8_scaled_mm_impl(n, k, m, dtype, seed, use_bias): def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed) torch.manual_seed(seed)
A = torch.rand(n, k, device="cuda") xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
B = torch.rand(m, k, device="cuda") A = (torch.rand(n, k, device="cuda") - .5) * xavier
B = (torch.rand(m, k, device="cuda") - .5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
bias = torch.rand(1, m, dtype=dtype, device="cuda") if use_bias else None
output = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a,
scale_b, bias)
ref_out = torch._scaled_mm(A, ref_out = torch._scaled_mm(A,
B.t(), B.t(),
out_dtype=dtype, out_dtype=dtype,
scale_a=scale_a, scale_a=scale_a,
scale_b=scale_b, scale_b=scale_b,
bias=bias) bias=BIAS)
assert torch.allclose(output, ref_out, rtol=0.01) out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
current_platform.get_cu_count(), BIAS)
assert torch.allclose(out, ref_out, rtol=0.01)

View File

@ -17,8 +17,6 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
from vllm.model_executor.layers.layernorm import (RMSNorm, from vllm.model_executor.layers.layernorm import (RMSNorm,
dispatch_rocm_rmsnorm_func, dispatch_rocm_rmsnorm_func,
fused_add_rms_norm, rms_norm) fused_add_rms_norm, rms_norm)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform from vllm.platforms import current_platform
RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
@ -111,34 +109,6 @@ def test_enabled_ops_invalid(env: str):
RMSNorm(1024).enabled() RMSNorm(1024).enabled()
@pytest.mark.skipif(
not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
reason="AITER is a feature exclusive for ROCm and FP8_FNUZ")
@pytest.mark.parametrize("use_cutlass", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
use_rocm_aiter_gemm_w8a8_blockscale: str,
monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
use_rocm_aiter_gemm_w8a8_blockscale)
use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
int(use_rocm_aiter_gemm_w8a8_blockscale)))
block_scale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)
if use_cutlass:
assert block_scale_func == cutlass_scaled_mm
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_gemm_w8a8_blockscale):
assert block_scale_func == (
torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
else:
assert block_scale_func == w8a8_block_fp8_matmul
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)

View File

@ -18,6 +18,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported) cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@ -742,3 +745,35 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt,
perplexity = llm.generate_prompt_perplexity([prompt])[0] perplexity = llm.generate_prompt_perplexity([prompt])[0]
print(perplexity) print(perplexity)
assert perplexity <= exp_perplexity assert perplexity <= exp_perplexity
def test_compressed_tensors_fp8_block_enabled(vllm_runner):
    """Load a block-quantized FP8 checkpoint and inspect its first qkv_proj.

    Verifies the layer is wired to the compressed-tensors W8A8 FP8 scheme
    with a block-FP8 linear op, that weights/scales have the expected dtypes
    and rank, and that the input quantization op dispatches to the CUDA
    implementation. Finally checks that greedy generation yields output.
    """
    with vllm_runner("RedHatAI/Qwen3-0.6B-FP8-BLOCK") as llm:
        expected_weight_dtype = current_platform.fp8_dtype()

        def check_model(model):
            proj = model.model.layers[0].self_attn.qkv_proj

            # Quantization method / scheme wiring.
            assert isinstance(proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(proj.scheme, CompressedTensorsW8A8Fp8)
            block_linear = proj.scheme.w8a8_block_fp8_linear
            assert isinstance(block_linear, W8A8BlockFp8LinearOp)

            # Parameter dtypes and ranks.
            assert proj.weight.dtype is expected_weight_dtype
            assert proj.weight_scale.dtype is torch.float32
            assert len(proj.weight.shape) == 2
            assert len(proj.weight_scale.shape) == 2

            # Activation quantization must use the CUDA custom op.
            quant_op = block_linear.input_quant_op
            assert isinstance(quant_op, QuantFP8)
            assert quant_op._forward_method == quant_op.forward_cuda

        llm.apply_model(check_model)

        generated = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert generated

View File

@ -5,11 +5,12 @@ import pytest
import torch import torch
from tests.v1.attention.test_attention_backends import BATCH_SPECS from tests.v1.attention.test_attention_backends import BATCH_SPECS
from tests.v1.attention.utils import create_common_attn_metadata from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.v1.attention.backends.utils import (UBatchSlice, from vllm.v1.attention.backends.utils import (UBatchSlice,
_make_metadata_with_slice, _make_metadata_with_slice,
slice_query_start_locs, slice_query_start_locs,
split_attn_metadata) split_attn_metadata)
from vllm.v1.worker.ubatch_utils import create_ubatch_slices
@pytest.fixture @pytest.fixture
@ -155,3 +156,83 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
assert results[1].num_reqs == mid_point assert results[1].num_reqs == mid_point
assert results[1].num_actual_tokens == mid_point assert results[1].num_actual_tokens == mid_point
assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point)) assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
@pytest.mark.parametrize(
    "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
    [
        # Split in the middle of request 1
        ([32, 40], [8, 8], 12, 2, 1),
        # Split inside the first request
        ([32, 40], [8, 8], 4, 1, 2),
    ],
)
def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point,
                                       expected_first_reqs,
                                       expected_second_reqs):
    """Check metadata produced when a prefill request straddles two ubatches.

    The batch is cut at token ``split_point``; the request containing that
    token contributes its leading tokens to the first ubatch and the rest to
    the second. Token counts, request counts, per-request query lengths and
    seq_lens of both halves are validated against the unsplit metadata.
    """
    import numpy as np

    common = create_common_attn_metadata(
        BatchSpec(seq_lens=seq_lens, query_lens=query_lens),
        block_size=16,
        device=torch.device("cpu"),
    )
    start_locs = common.query_start_loc_cpu.numpy()
    total_tokens = common.num_actual_tokens

    slices = create_ubatch_slices(np.array(query_lens, dtype=np.int32),
                                  split_point)
    assert len(slices) == 2

    head_meta = _make_metadata_with_slice(slices[0], common)
    tail_meta = _make_metadata_with_slice(slices[1], common)

    # Token counts match the split
    assert head_meta.num_actual_tokens == split_point
    assert tail_meta.num_actual_tokens == total_tokens - split_point

    # Number of requests per ubatch
    assert head_meta.num_reqs == expected_first_reqs
    assert tail_meta.num_reqs == expected_second_reqs

    # Locate the request that is cut and the size of its leading chunk.
    cut_req = int(np.searchsorted(start_locs, split_point, side="right") - 1)
    leading_tokens = split_point - int(start_locs[cut_req])
    full_q_lens = (common.query_start_loc_cpu[1:] -
                   common.query_start_loc_cpu[:-1])

    # Query lengths of the two chunks must add up to the original length.
    head_last_qlen = int(head_meta.query_start_loc_cpu[-1] -
                         head_meta.query_start_loc_cpu[-2])
    tail_first_qlen = int(tail_meta.query_start_loc_cpu[1] -
                          tail_meta.query_start_loc_cpu[0])
    assert head_last_qlen == leading_tokens
    assert head_last_qlen + tail_first_qlen == int(full_q_lens[cut_req])

    # First ubatch: the cut request only sees its context plus the leading
    # chunk of new tokens.
    context_len = seq_lens[cut_req] - query_lens[cut_req]
    assert int(head_meta.seq_lens[-1]) == context_len + leading_tokens

    # Requests fully contained in the first ubatch keep their seq_lens.
    for req in range(head_meta.num_reqs - 1):
        assert int(head_meta.seq_lens[req]) == seq_lens[req]

    # Second ubatch: the continuation sees the full original sequence.
    assert int(tail_meta.seq_lens[0]) == seq_lens[cut_req]

    # Remaining whole requests map back to the original ones in order.
    for offset in range(1, tail_meta.num_reqs):
        assert int(tail_meta.seq_lens[offset]) == seq_lens[cut_req + offset]

View File

@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
from dataclasses import fields
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@ -21,7 +22,8 @@ from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import (GuidedDecodingParams, SamplingParams,
StructuredOutputsParams)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import TokenizerMode from vllm.config import TokenizerMode
@ -89,6 +91,26 @@ def _load_json(s: str, backend: str) -> str:
return json.loads(s) return json.loads(s)
def test_guided_decoding_deprecated():
    """The deprecated guided-decoding API warns and maps to structured outputs.

    Constructing ``GuidedDecodingParams`` and passing ``guided_decoding`` to
    ``SamplingParams`` (directly or via ``from_optional``) must each emit a
    ``DeprecationWarning``; the resulting params must be equal and carry the
    value in ``structured_outputs``.
    """
    kwarg_warning = "guided_decoding is deprecated.*"

    with pytest.warns(DeprecationWarning,
                      match="GuidedDecodingParams is deprecated.*"):
        legacy_params = GuidedDecodingParams(json_object=True)

    # The deprecated class must expose exactly the same dataclass fields.
    replacement_params = StructuredOutputsParams(json_object=True)
    assert fields(legacy_params) == fields(replacement_params)

    with pytest.warns(DeprecationWarning, match=kwarg_warning):
        via_ctor = SamplingParams(guided_decoding=legacy_params)

    with pytest.warns(DeprecationWarning, match=kwarg_warning):
        via_factory = SamplingParams.from_optional(
            guided_decoding=legacy_params)

    assert via_ctor == via_factory
    assert via_ctor.structured_outputs == legacy_params
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name, backend, tokenizer_mode, speculative_config", "model_name, backend, tokenizer_mode, speculative_config",

View File

@ -532,9 +532,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
# Mock runner for attention metadata building # Mock runner for attention metadata building
proposer.runner = mock.MagicMock() proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()]) proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builders = [ proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder attn_metadata_builder
] proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder)
result = proposer.propose(target_token_ids=target_token_ids, result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions, target_positions=target_positions,
@ -659,9 +660,10 @@ def test_propose_tree(spec_token_tree):
# Mock runner for attention metadata building. # Mock runner for attention metadata building.
proposer.runner = mock.MagicMock() proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()]) proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builders = [ proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder attn_metadata_builder
] proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder)
# Setup inputs for the proposer. # Setup inputs for the proposer.
target_token_ids = torch.randint(0, target_token_ids = torch.randint(0,

View File

@ -1447,17 +1447,24 @@ def LLMM1(a: torch.Tensor, b: torch.Tensor,
return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)
def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor: def wvSplitK(a: torch.Tensor,
return torch.ops._rocm_C.wvSplitK(a, b, cu_count) b: torch.Tensor,
cu_count: int,
bias: torch.Tensor = None) -> torch.Tensor:
return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count)
def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype, def wvSplitKQ(a: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor, b: torch.Tensor,
cu_count: int) -> torch.Tensor: out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
cu_count: int,
bias: torch.Tensor = None) -> torch.Tensor:
out = torch.empty((b.shape[0], a.shape[0]), out = torch.empty((b.shape[0], a.shape[0]),
dtype=out_dtype, dtype=out_dtype,
device=b.device) device=b.device)
torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count) torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count)
return out return out

View File

@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import triton
import triton.language as tl
from vllm.platforms import current_platform
@triton.jit
def reshape_and_cache_kernel_flash(
    key_ptr,  # [num_tokens, num_heads, head_size]
    value_ptr,  # [num_tokens, num_heads, head_size]
    key_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
    value_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
    slot_mapping_ptr,  # [num_tokens]
    k_scale,  # float32
    v_scale,  # float32
    # strides
    key_stride: tl.int64,
    value_stride: tl.int64,
    block_stride: tl.int64,
    page_stride: tl.int64,
    num_heads: tl.constexpr,
    head_size: tl.constexpr,
    block_size: tl.constexpr,
    # FP8 flags
    FP8_KV_CACHE: tl.constexpr,
    # tune parameters
    TILE_SIZE: tl.constexpr,
):
    """Copy one tile of a token's key/value vectors into the paged KV cache.

    Grid axis 0 selects the token, grid axis 1 selects a TILE_SIZE-wide tile
    of that token's flattened ``num_heads * head_size`` elements. The target
    location is derived from the token's entry in ``slot_mapping_ptr``.
    """
    pid_token = tl.program_id(axis=0)
    slot = tl.load(slot_mapping_ptr + pid_token).to(tl.int64)
    if slot < 0:
        # Padding token that should be ignored.
        return

    pid_tile = tl.program_id(axis=1)
    elem_offs = pid_tile * TILE_SIZE + tl.arange(0, TILE_SIZE)
    # The last tile may run past the end of the flattened head dimension.
    in_bounds = elem_offs < (num_heads * head_size)

    # Decompose the slot into (cache block, position inside the block).
    cache_block = slot // block_size
    pos_in_block = slot % block_size
    dst_base = cache_block * block_stride + pos_in_block * page_stride

    # [TILE_SIZE]
    k_in = tl.load(key_ptr + pid_token * key_stride + elem_offs,
                   mask=in_bounds)
    if FP8_KV_CACHE:
        if k_in.dtype.is_fp8():
            k_out = k_in
        else:
            # tl.store will do the correct implicit cast to fp8,
            # based on the key_cache_ptr.dtype.element_ty
            k_out = k_in / tl.load(k_scale)
    else:
        k_out = k_in

    # [TILE_SIZE]
    v_in = tl.load(value_ptr + pid_token * value_stride + elem_offs,
                   mask=in_bounds)
    if FP8_KV_CACHE:
        if v_in.dtype.is_fp8():
            v_out = v_in
        else:
            # tl.store will do the correct implicit cast to fp8,
            # based on the value_cache_ptr.dtype.element_ty
            v_out = v_in / tl.load(v_scale)
    else:
        v_out = v_in

    tl.store(key_cache_ptr + dst_base + elem_offs, k_out, mask=in_bounds)
    tl.store(value_cache_ptr + dst_base + elem_offs, v_out, mask=in_bounds)
    return
def triton_reshape_and_cache_flash(
    key: torch.Tensor,  # [num_tokens, num_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_heads, head_size]
    # [num_blocks, block_size, num_heads, head_size]
    key_cache: torch.Tensor,
    # [num_blocks, block_size, num_heads, head_size]
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,  # [num_tokens]
    kv_cache_dtype: str,  # "auto", "fp8"
    k_scale: torch.Tensor,  # float32
    v_scale: torch.Tensor,  # float32
):
    """Launch the Triton kernel that writes key/value into the paged KV cache.

    Args:
        key: New keys, ``[num_tokens, num_heads, head_size]``.
        value: New values, ``[num_tokens, num_heads, head_size]``.
        key_cache: Destination key cache,
            ``[num_blocks, block_size, num_heads, head_size]``.
        value_cache: Destination value cache, same layout as ``key_cache``.
        slot_mapping: Per-token destination slot, ``[num_tokens]``; negative
            entries are skipped by the kernel.
        kv_cache_dtype: ``"auto"`` or a string starting with ``"fp8"``.
        k_scale: Scale the kernel divides keys by when quantizing to FP8.
        v_scale: Scale the kernel divides values by when quantizing to FP8.

    Raises:
        AssertionError: If the cache layout or dtype is not supported.
    """
    num_tokens = key.shape[0]
    num_heads = key.shape[1]
    head_size = key.shape[2]
    block_size = key_cache.shape[1]
    n = num_heads * head_size

    key_stride = key.stride()[0]
    value_stride = value.stride()[0]
    block_stride = key_cache.stride()[0]
    page_stride = key_cache.stride()[1]

    # The kernel addresses a token's heads as one flat run of
    # num_heads * head_size elements, so heads must be densely packed.
    head_stride = key_cache.stride()[2]
    assert head_stride == head_size, "only continous heads are supported"

    assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \
        f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
    FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
    kv_cache_torch_dtype = current_platform.fp8_dtype() if \
        FP8_KV_CACHE else key_cache.dtype

    if key_cache.dtype != kv_cache_torch_dtype and FP8_KV_CACHE:
        # to avoid erroneous implicit cast in triton kernel (tl.store to uint8)
        # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
        key_cache = key_cache.view(kv_cache_torch_dtype)
        value_cache = value_cache.view(kv_cache_torch_dtype)
    # Check the tensor dtype here: the original compared the dtype *string*
    # against torch.uint8, which is always unequal and never guarded anything.
    assert key_cache.dtype != torch.uint8, "explicit fp8 cast and store to "\
        "uint8 is not supported by triton reshape_and_cache_flash"

    assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
        torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8,
        torch.float8_e4m3fnuz], \
        "unsupported dtype of KV cache tensor, got "\
        f"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \
        "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."

    # heuristics instead of autotuning
    TILE_SIZE = min(2048, triton.next_power_of_2(n))
    if torch.version.hip:
        num_stages = 4
        num_warps = 8
    else:  # cuda
        num_stages = 10
        num_warps = 16
        if torch.cuda.get_device_capability(key.device)[0] < 9:
            TILE_SIZE = min(512, TILE_SIZE)

    # One program per (token, tile of the flattened head dimension).
    # TODO(ngl): maybe replace with static launch grid to avoid overhead if
    #   using cudagraphs
    grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))

    reshape_and_cache_kernel_flash[grid](
        key_ptr=key,
        value_ptr=value,
        key_cache_ptr=key_cache,
        value_cache_ptr=value_cache,
        slot_mapping_ptr=slot_mapping,
        k_scale=k_scale,
        v_scale=v_scale,
        # strides
        key_stride=key_stride,
        value_stride=value_stride,
        block_stride=block_stride,
        page_stride=page_stride,
        num_heads=num_heads,
        head_size=head_size,
        block_size=block_size,
        # FP8 flags
        FP8_KV_CACHE=FP8_KV_CACHE,
        # autotune parameters
        TILE_SIZE=TILE_SIZE,
        num_warps=num_warps,
        num_stages=num_stages,
    )

View File

@ -12,6 +12,8 @@ import vllm.envs as envs
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed.device_communicators.pynccl_allocator import (
set_graph_pool_id)
from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -154,6 +156,10 @@ class CUDAGraphWrapper:
stack.enter_context( stack.enter_context(
patch("torch.cuda.empty_cache", lambda: None)) patch("torch.cuda.empty_cache", lambda: None))
if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool)
else:
set_graph_pool_id(current_platform.graph_pool_handle())
# mind-exploding: carefully manage the reference and memory. # mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool): with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool # `output` is managed by pytorch's cudagraph pool

View File

@ -509,8 +509,15 @@ class VllmConfig:
if self.compilation_config.cudagraph_mode is None: if self.compilation_config.cudagraph_mode is None:
if envs.VLLM_USE_V1 and self.compilation_config.level \ if envs.VLLM_USE_V1 and self.compilation_config.level \
== CompilationLevel.PIECEWISE: == CompilationLevel.PIECEWISE:
# default to full and piecewise for most models
self.compilation_config.cudagraph_mode = \ self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE CUDAGraphMode.FULL_AND_PIECEWISE
# pooling model does not support full cudagraphs
if self.model_config is not None and \
self.model_config.pooler_config is not None:
self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else: else:
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
@ -638,11 +645,13 @@ class VllmConfig:
if self.parallel_config.enable_dbo: if self.parallel_config.enable_dbo:
a2a_backend = envs.VLLM_ALL2ALL_BACKEND a2a_backend = envs.VLLM_ALL2ALL_BACKEND
assert a2a_backend == "deepep_low_latency", \ assert a2a_backend in \
"Microbatching currently only supports the deepep_low_latency "\ ["deepep_low_latency", "deepep_high_throughput"], \
f"all2all backend. {a2a_backend} is not supported. To fix set "\ "Microbatching currently only supports the deepep_low_latency and "\
"the VLLM_ALL2ALL_BACKEND environment variable to "\ f"deepep_high_throughput all2all backend. {a2a_backend} is not "\
"deepep_low_latency and install the DeepEP kerenls." "supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\
"variable to deepep_low_latency or deepep_high_throughput and "\
"install the DeepEP kernels."
if not self.instance_id: if not self.instance_id:
self.instance_id = random_uuid()[:5] self.instance_id = random_uuid()[:5]
@ -685,6 +694,23 @@ class VllmConfig:
# local attention. # local attention.
self.scheduler_config.disable_hybrid_kv_cache_manager = True self.scheduler_config.disable_hybrid_kv_cache_manager = True
def has_blocked_weights():
if self.quant_config is not None:
if hasattr(self.quant_config, "weight_block_size"):
return self.quant_config.weight_block_size is not None
elif hasattr(self.quant_config, "has_blocked_weights"):
return self.quant_config.has_blocked_weights()
return False
# Enable quant_fp8 CUDA ops (TODO disable in follow up)
# On H100 the CUDA kernel is faster than
# native implementation
# https://github.com/vllm-project/vllm/issues/25094
if has_blocked_weights():
custom_ops = self.compilation_config.custom_ops
if "none" not in custom_ops and "-quant_fp8" not in custom_ops:
custom_ops.append("+quant_fp8")
def update_sizes_for_sequence_parallelism(self, def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list: possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when # remove the sizes that not multiple of tp_size when

View File

@ -228,15 +228,14 @@ class CompilationConfig:
The mode of the cudagraph: The mode of the cudagraph:
- NONE, no cudagraph capture. - NONE, no cudagraph capture.
- PIECEWISE. (v1 default) - PIECEWISE.
- FULL. - FULL.
- FULL_DECODE_ONLY. - FULL_DECODE_ONLY.
- FULL_AND_PIECEWISE. - FULL_AND_PIECEWISE. (v1 default)
PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph
incompatible ops (i.e. some attention ops) outside the cudagraph incompatible ops (i.e. some attention ops) outside the cudagraph
for general flexibility. for general flexibility.
This is the default mode.
FULL mode: Capture full cudagraph for all batches. Can be good for small FULL mode: Capture full cudagraph for all batches. Can be good for small
models or workloads with small prompts; not supported by many backends. models or workloads with small prompts; not supported by many backends.
@ -249,7 +248,7 @@ class CompilationConfig:
FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
piecewise cudagraph for prefill and mixed prefill-decode batches. piecewise cudagraph for prefill and mixed prefill-decode batches.
This is like the most performant mode for most models. This is the most performant mode for most models and is the default.
Currently, the cudagraph mode is only used for the v1 engine. Currently, the cudagraph mode is only used for the v1 engine.
Note that the cudagraph logic is generally orthogonal to the Note that the cudagraph logic is generally orthogonal to the

View File

@ -1003,6 +1003,7 @@ class ModelConfig:
self.quantization = quantization_override self.quantization = quantization_override
break break
quant_method = quant_method if quant_method != "" else None
# Verify quantization configurations. # Verify quantization configurations.
if self.quantization is None: if self.quantization is None:
self.quantization = quant_method self.quantization = quant_method

View File

@ -139,12 +139,18 @@ class ParallelConfig:
"""Disable the custom all-reduce kernel and fall back to NCCL.""" """Disable the custom all-reduce kernel and fall back to NCCL."""
enable_dbo: bool = False enable_dbo: bool = False
"""Enable microbatching for the model executor.""" """Enable dual batch overlap for the model executor."""
dbo_decode_token_threshold: int = 32 dbo_decode_token_threshold: int = 32
"""The threshold for microbatching. If the number of tokens in the """The threshold for dual batch overlap for batches only containing decodes.
request is greater than this threshold, microbatching will be used. If the number of tokens in the request is greater than this threshold,
Otherwise, the request will be processed in a single batch.""" microbatching will be used. Otherwise, the request will be processed in a
single batch."""
dbo_prefill_token_threshold: int = 512 # TODO(lucas): tune
"""The threshold for dual batch overlap for batches that contain one or more
prefills. If the number of tokens in the request is greater than this
threshold, microbatching will be used. Otherwise, the request will be
processed in a single batch."""
ray_workers_use_nsight: bool = False ray_workers_use_nsight: bool = False
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

View File

@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any from typing import Any, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed import get_dp_group from vllm.distributed import get_dp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
@ -200,12 +201,12 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
def _make_all2all_kwargs(self) -> dict[Any, Any]: def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests. # Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = 1024 * 1024 * 1024 num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_rdma_bytes = None num_rdma_bytes = None
num_qps_per_rank = None num_qps_per_rank = None
if self.internode: if self.internode:
num_rdma_bytes = 1024 * 1024 * 1024 num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_qps_per_rank = self.num_sms // 2 num_qps_per_rank = self.num_sms // 2
else: else:
num_rdma_bytes = 0 num_rdma_bytes = 0
@ -230,13 +231,18 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
logger.debug("DeepEP all2all args %s", buffer_kwargs) logger.debug("DeepEP all2all args %s", buffer_kwargs)
handle: deep_ep.Buffer = self.handle_cache.get_or_create( handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.Buffer) buffer_kwargs, deep_ep.Buffer)
# It is dangerous to set num sms outside this function. num_sms is not
# a part of the hash-key that identifies this object. If we are in a
# situation where we make objects with different num_sms, the hash key
# in get_or_create must be updated.
handle.set_num_sms(self.num_sms)
return handle return handle
def set_num_sms(self, num_sms: int):
    """Set the SM count on ``deep_ep.Buffer``, capped at ``self.num_sms``.

    Right now the buffers are sized for only what the kernels were created
    with, so the number of SMs used can be reduced but not increased.
    """
    import deep_ep

    deep_ep.Buffer.set_num_sms(min(num_sms, self.num_sms))
class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
""" """
@ -265,7 +271,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
import deep_ep import deep_ep
# Defaults for internode and intranode are taken from DeepEP tests. # Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = 1024 * 1024 * 1024 num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_qps_per_rank = num_local_experts num_qps_per_rank = num_local_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank, num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
@ -291,3 +297,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
handle: deep_ep.Buffer = self.handle_cache.get_or_create( handle: deep_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, deep_ep.Buffer) buffer_kwargs, deep_ep.Buffer)
return handle return handle
def max_sms_used(self) -> Optional[int]:
    """Return 0: DeepEP LL uses RDMA so no SMs are used for communication."""
    return 0

View File

@ -10,8 +10,9 @@ import sys
import tempfile import tempfile
from collections.abc import Sequence from collections.abc import Sequence
from itertools import product from itertools import product
from typing import Optional from typing import Any, Optional
import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
@ -56,6 +57,30 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
} }
} }
# Heuristics read by should_nccl_symm_mem_allreduce():
#   - "min_world_size": world sizes below this never use NCCL symmetric
#     memory all-reduce.
#   - "thresholds": per world size, the minimum input size in bytes
#     (input_tensor.nbytes >= threshold) at which it is used.
#   - "always_use_above_world_size": world sizes strictly greater than this
#     use it regardless of input size.
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
"min_world_size": 4,
"thresholds": {
4: 2 * MiB, # 2 MiB
8: 1 * MiB, # 1 MiB
},
"always_use_above_world_size": 8 # Always use symm mem for world_size > 8
}
def should_nccl_symm_mem_allreduce(world_size: int,
                                   input_tensor: torch.Tensor) -> bool:
    """Decide whether this all-reduce should use NCCL symmetric memory.

    Args:
        world_size: Number of ranks taking part in the all-reduce.
        input_tensor: Tensor to be reduced; only its ``nbytes`` is inspected.

    Returns:
        False when symmetric memory is disabled or ``world_size`` is below
        the configured minimum. Otherwise True if the tensor reaches the
        byte threshold configured for ``world_size``, or if ``world_size``
        exceeds the configured "always use" bound.
    """
    from vllm.distributed.device_communicators.pynccl_allocator import (
        is_symmetric_memory_enabled)

    cfg = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG
    if (not is_symmetric_memory_enabled()
            or world_size < cfg["min_world_size"]):
        return False

    # Sizes without an explicit threshold fall through to the bound check.
    min_bytes = cfg["thresholds"].get(world_size)
    large_enough = (min_bytes is not None
                    and input_tensor.nbytes >= min_bytes)
    if large_enough:
        return True
    return world_size > cfg["always_use_above_world_size"]
def producer(batch_src: Sequence[int], def producer(batch_src: Sequence[int],
producer_queue, producer_queue,

View File

@ -60,6 +60,12 @@ class All2AllManagerBase:
# and reuse it for the same config. # and reuse it for the same config.
raise NotImplementedError raise NotImplementedError
def set_num_sms(self, num_sms: int):
    """Hook for setting an SM budget; the base implementation does nothing.

    Args:
        num_sms: presumably the number of SMs the manager may use for
            communication -- TODO confirm against the overriding managers.
    """
    pass
def max_sms_used(self) -> Optional[int]:
    """Report the SM usage bound of this manager.

    Returns:
        None in the base implementation, meaning no bound is known.
    """
    return None  # None means it could use the whole GPU
def dispatch(self, hidden_states: torch.Tensor, def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor):
raise NotImplementedError raise NotImplementedError

View File

@ -7,6 +7,12 @@ import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.device_communicators.all_reduce_utils import (
should_nccl_symm_mem_allreduce)
from vllm.distributed.device_communicators.pynccl import (
register_nccl_symmetric_ops)
from vllm.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_enabled)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -24,18 +30,21 @@ class CudaCommunicator(DeviceCommunicatorBase):
unique_name: str = ""): unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name) super().__init__(cpu_group, device, device_group, unique_name)
if "tp" not in unique_name: if "tp" not in unique_name:
# only tp uses custom allreduce # custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False use_custom_allreduce = False
use_torch_symm_mem = False
else: else:
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
_ENABLE_CUSTOM_ALL_REDUCE) _ENABLE_CUSTOM_ALL_REDUCE)
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
# ep does not use pynccl # ep does not use pynccl
use_pynccl = "ep" not in unique_name use_pynccl = "ep" not in unique_name
self.use_pynccl = use_pynccl self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem
# lazy import to avoid documentation build error # lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import ( from vllm.distributed.device_communicators.custom_all_reduce import (
@ -53,11 +62,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
group=self.cpu_group, group=self.cpu_group,
device=self.device, device=self.device,
) )
if is_symmetric_memory_enabled():
register_nccl_symmetric_ops(self.pynccl_comm)
self.ca_comm: Optional[CustomAllreduce] = None self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): if use_torch_symm_mem and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator( self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group, group=self.cpu_group,
device=self.device, device=self.device,
@ -107,6 +118,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
raise ValueError(f"Unknown all2all backend: {all2all_backend}") raise ValueError(f"Unknown all2all backend: {all2all_backend}")
def all_reduce(self, input_): def all_reduce(self, input_):
# since currently we perform copy input -> symm_input -> out-of-place AR
# return symm_output, we don't need to check if input is symmetric
if self.pynccl_comm is not None and \
should_nccl_symm_mem_allreduce(self.pynccl_comm.world_size,input_):
out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
if out is not None:
return out
# always try quick reduce first, then custom allreduce, # always try quick reduce first, then custom allreduce,
# and then pynccl. (quick reduce just for ROCM MI3*) # and then pynccl. (quick reduce just for ROCM MI3*)
qr_comm = self.qr_comm qr_comm = self.qr_comm

View File

@ -17,6 +17,39 @@ from vllm.utils import current_stream
logger = init_logger(__name__) logger = init_logger(__name__)
_NCCL_SYMM_OPS_REGISTERED = False
def register_nccl_symmetric_ops(pynccl_comm):
    """Register the ``all_reduce_symmetric_with_copy`` custom op once.

    The op copies its input into a buffer allocated under
    ``nccl_symm_mem_context`` and runs an out-of-place all-reduce on it.

    Only the first call registers anything; later calls return immediately.
    The registered op therefore keeps using the ``pynccl_comm`` passed to
    that first call.

    Args:
        pynccl_comm: Communicator captured by the registered op.
    """
    global _NCCL_SYMM_OPS_REGISTERED

    from vllm.distributed.device_communicators.pynccl_allocator import (
        nccl_symm_mem_context)
    from vllm.utils import direct_register_custom_op

    if _NCCL_SYMM_OPS_REGISTERED:
        return
    # Flag is raised before registration, so registration is attempted once.
    _NCCL_SYMM_OPS_REGISTERED = True

    def _symm_all_reduce(input_tensor: torch.Tensor) -> torch.Tensor:
        # Both buffers are allocated inside the symmetric-memory context.
        with nccl_symm_mem_context(pynccl_comm):
            staged_input = torch.empty_like(input_tensor)
            reduced = torch.empty_like(input_tensor)
            staged_input.copy_(input_tensor)
            reduced = pynccl_comm.all_reduce(staged_input, reduced)
        return reduced

    def _symm_all_reduce_fake(input_tensor: torch.Tensor) -> torch.Tensor:
        # Shape/dtype-only stand-in used as the op's fake implementation.
        return torch.empty_like(input_tensor)

    direct_register_custom_op(
        op_name="all_reduce_symmetric_with_copy",
        op_func=_symm_all_reduce,
        mutates_args=[],
        fake_impl=_symm_all_reduce_fake,
    )
class PyNcclCommunicator: class PyNcclCommunicator:
@ -67,6 +100,7 @@ class PyNcclCommunicator:
self.available = True self.available = True
self.disabled = False self.disabled = False
self.nccl_version = self.nccl.ncclGetRawVersion()
logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
if self.rank == 0: if self.rank == 0:
@ -109,6 +143,7 @@ class PyNcclCommunicator:
def all_reduce(self, def all_reduce(self,
in_tensor: torch.Tensor, in_tensor: torch.Tensor,
out_tensor: torch.Tensor = None,
op: ReduceOp = ReduceOp.SUM, op: ReduceOp = ReduceOp.SUM,
stream=None) -> torch.Tensor: stream=None) -> torch.Tensor:
if self.disabled: if self.disabled:
@ -120,7 +155,8 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {in_tensor.device}") f"but the input tensor is on {in_tensor.device}")
out_tensor = torch.empty_like(in_tensor) if out_tensor is None:
out_tensor = torch.empty_like(in_tensor)
if stream is None: if stream is None:
stream = current_stream() stream = current_stream()
@ -288,3 +324,18 @@ class PyNcclCommunicator:
def group_end(self): def group_end(self):
self.nccl.ncclGroupEnd() self.nccl.ncclGroupEnd()
def register_comm_window(self, tensor: torch.Tensor):
    """Register the memory of ``tensor`` as an NCCL communication window.

    Args:
        tensor: Tensor whose ``data_ptr()`` and byte size describe the
            region to register.

    Returns:
        Whatever ``ncclCommWindowRegister`` returns (the window handle).
    """
    base_ptr = buffer_type(tensor.data_ptr())
    num_bytes = tensor.numel() * tensor.element_size()
    # Window flag 1: presumably NCCL_WIN_COLL_SYMMETRIC -- TODO confirm
    # against nccl.h.
    win_flags = 1
    return self.nccl.ncclCommWindowRegister(self.comm, base_ptr, num_bytes,
                                            win_flags)
def register_comm_window_raw(self, ptr: int, size: int):
    """Register a raw address range as an NCCL communication window.

    Args:
        ptr: Base address of the region, as an integer.
        size: Size of the region in bytes.

    Returns:
        Whatever ``ncclCommWindowRegister`` returns (the window handle).
    """
    base_ptr = buffer_type(ptr)
    # Window flag 1: presumably NCCL_WIN_COLL_SYMMETRIC -- TODO confirm
    # against nccl.h.
    win_flags = 1
    return self.nccl.ncclCommWindowRegister(self.comm, base_ptr, size,
                                            win_flags)
def deregister_comm_window(self, window):
    """Deregister a window previously registered on this communicator.

    Args:
        window: Handle returned by one of the ``register_comm_window*``
            methods.

    Returns:
        Whatever ``ncclCommWindowDeregister`` returns.
    """
    nccl_lib = self.nccl
    return nccl_lib.ncclCommWindowDeregister(self.comm, window)

View File

@ -0,0 +1,186 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import atexit
import contextlib
import tempfile
from typing import Any, Optional
import torch
from packaging import version
from torch.cuda.memory import CUDAPluggableAllocator
from torch.utils.cpp_extension import load_inline
from vllm import envs
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import find_nccl_include_paths
logger = init_logger(__name__)
# C/C++ source compiled by compile_nccl_allocator() and loaded through
# CUDAPluggableAllocator. It exposes two entry points:
#   nccl_alloc_plug -> ncclMemAlloc
#   nccl_free_plug  -> ncclMemFree
# The alloc hook initializes the pointer and returns a null pointer when
# ncclMemAlloc fails, so the caller never receives an indeterminate pointer.
nccl_allocator_source = """
#include <nccl.h>
extern "C" {

void* nccl_alloc_plug(size_t size, int device, void* stream) {
  void* ptr = nullptr;
  if (ncclMemAlloc(&ptr, size) != ncclSuccess) {
    return nullptr;
  }
  return ptr;
}

void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
  (void)ncclMemFree(ptr);
}

}
"""
# Module-level state shared by the helpers below.
# Allocator object obtained from _allocator_wrapper.allocator().
_allocator = None
# CUDAPluggableAllocator built from the compiled shared object.
_allocator_wrapper = None
# torch.cuda.MemPool backed by _allocator; created lazily.
_mem_pool = None
# Segment base addresses already passed to register_comm_window_raw().
_registered_base_addrs = set()
# Graph pool id stored by set_graph_pool_id(); read under graph capture.
_graph_pool_id = None
# Set when the allocator cannot be built (or the platform is not CUDA).
_nccl_allocator_failed_to_compile = False
# Last result of _mem_pool.snapshot(), refreshed on context exit.
_cached_pool_snapshot = None
def is_symmetric_memory_enabled():
    """Tell whether NCCL symmetric memory may be used.

    It is enabled when ``VLLM_USE_NCCL_SYMM_MEM`` is set and the NCCL
    allocator has not been marked as failed to compile.
    """
    env_flag = envs.VLLM_USE_NCCL_SYMM_MEM
    if not env_flag:
        return env_flag
    return not _nccl_allocator_failed_to_compile
def is_symmetric_memory_tensor(tensor: torch.Tensor):
    """Check whether ``tensor`` is backed by a block of the NCCL memory pool.

    The check compares the tensor's storage address with the block addresses
    recorded in the cached pool snapshot.

    Args:
        tensor: Tensor to look up.

    Returns:
        True if some block in the cached snapshot starts at the tensor's
        storage address. False otherwise, including when symmetric memory is
        disabled or no snapshot has been taken yet.
    """
    if not is_symmetric_memory_enabled() or _cached_pool_snapshot is None:
        return False
    # The storage address does not change during the scan; compute it once
    # instead of once per block.
    storage_ptr = tensor.untyped_storage().data_ptr()
    return any(block["address"] == storage_ptr
               for segment in _cached_pool_snapshot
               for block in segment["blocks"])
def set_graph_pool_id(graph_pool_id):
    """Store the graph pool id used by ``nccl_symm_mem_context``.

    The context reads this value when it is entered or exited while the
    current stream is capturing.
    """
    global _graph_pool_id
    _graph_pool_id = graph_pool_id
def compile_nccl_allocator():
    """Build the NCCL pluggable allocator and publish it in module state.

    On success ``_allocator_wrapper`` and ``_allocator`` are set. On a
    non-CUDA platform, or when anything in the build raises,
    ``_nccl_allocator_failed_to_compile`` is set instead. A build failure is
    logged as a warning and not re-raised.
    """
    global _allocator, _allocator_wrapper, _nccl_allocator_failed_to_compile

    # Non-CUDA platforms are recorded the same way as a failed build.
    if not current_platform.is_cuda():
        _nccl_allocator_failed_to_compile = True
        return

    lib_name = "nccl_allocator"
    try:
        build_dir = tempfile.gettempdir()
        include_dirs = find_nccl_include_paths()
        # is_python_module=False: only the shared object is needed.
        load_inline(
            name=lib_name,
            cpp_sources=nccl_allocator_source,
            with_cuda=True,
            extra_ldflags=["-lnccl"],
            verbose=envs.VLLM_LOGGING_LEVEL == "DEBUG",
            is_python_module=False,
            build_directory=build_dir,
            extra_include_paths=include_dirs,
        )
        shared_object = f"{build_dir}/{lib_name}.so"
        _allocator_wrapper = CUDAPluggableAllocator(
            shared_object,
            "nccl_alloc_plug",
            "nccl_free_plug",
        )
        _allocator = _allocator_wrapper.allocator()
    except Exception as e:
        # Best effort: fall back to running without symmetric memory.
        _nccl_allocator_failed_to_compile = True
        logger.warning(
            "Failed to compile NCCL memory allocator. "
            "Symmetric memory will be disabled. "
            "This is expected if NCCL headers are not available. "
            "optionally set VLLM_NCCL_INCLUDE_PATH to point to a directory "
            "containing the NCCL header. "
            "Error: %s", str(e))
def get_nccl_mem_pool():
    """Return the NCCL-backed memory pool, creating it on first use.

    Returns:
        The cached ``torch.cuda.MemPool``, or None when the allocator could
        not be compiled.
    """
    global _mem_pool

    # Already built, or a previous build attempt failed: nothing to do.
    if _mem_pool is not None or _nccl_allocator_failed_to_compile:
        return _mem_pool

    compile_nccl_allocator()
    if _allocator is None:
        return _mem_pool
    _mem_pool = torch.cuda.MemPool(_allocator)
    return _mem_pool
def _cleanup_nccl_mem_pool():
    """Drop the module-level reference to the NCCL memory pool."""
    global _mem_pool
    _mem_pool = None
def _cleanup_nccl_allocator_wrapper():
    """Drop the module-level reference to the pluggable allocator wrapper."""
    global _allocator_wrapper
    _allocator_wrapper = None
# Release module-level references at interpreter exit.
# NOTE(review): atexit runs handlers in reverse registration order, so the
# allocator-wrapper cleanup runs before the mem-pool cleanup -- confirm this
# is the intended teardown order.
atexit.register(_cleanup_nccl_mem_pool)
atexit.register(_cleanup_nccl_allocator_wrapper)
class nccl_symm_mem_context:
    """Context manager that routes allocations to the NCCL memory pool.

    While active, the body runs under ``torch.cuda.use_mem_pool`` with the
    pool from ``get_nccl_mem_pool()``. On exit, every pool segment not seen
    before is registered with the communicator as a communication window.

    The context is a no-op (``disabled``) when any precondition checked in
    ``__init__`` does not hold.
    """

    def __init__(
        self,
        pynccl_comm: PyNcclCommunicator,
        disabled: bool = False,
    ):
        """Evaluate the preconditions and prepare the inner pool context.

        Args:
            pynccl_comm: Communicator used to register memory windows on exit.
            disabled: Force the context to be a no-op.
        """
        # The conditions short-circuit left to right, so get_nccl_mem_pool()
        # (which may trigger compilation) only runs if the earlier checks
        # did not already disable the context.
        self.disabled = (disabled or not is_symmetric_memory_enabled()
                         or pynccl_comm.world_size == 1
                         or not current_platform.is_cuda()
                         or get_nccl_mem_pool() is None or version.parse(
                             torch.__version__) < version.parse("2.8.0.a0"))
        if self.disabled:
            self.pynccl_comm: Optional[PyNcclCommunicator] = None
            self._mem_pool_ctx: contextlib.AbstractContextManager[
                Any] = contextlib.nullcontext()
            self.is_graph_capture = None
            self.device = None
        else:
            self.pynccl_comm = pynccl_comm
            self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
            # Capture state and device are sampled at construction time,
            # not when the context is entered.
            self.is_graph_capture = torch.cuda.is_current_stream_capturing()
            self.device = torch.cuda.current_device()

    def __enter__(self):
        """Enter the NCCL memory pool; returns ``self``."""
        if self.disabled:
            return self
        assert (
            self.pynccl_comm
            is not None), "Symmetric memory requires pynccl to be initalized"
        assert (
            self.pynccl_comm.nccl_version >= 22703
        ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
        if self.is_graph_capture:
            assert (
                _graph_pool_id
                is not None), "graph_pool_id is not set under graph capture"
            # Pause graph memory pool to use symmetric memory with cuda graph
            torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
        self._mem_pool_ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave the pool, refresh the snapshot and register new segments."""
        if self.disabled:
            return
        global _cached_pool_snapshot
        global _registered_base_addrs
        self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
        _pool = get_nccl_mem_pool()
        assert _pool is not None
        # The snapshot is cached for is_symmetric_memory_tensor().
        _cached_pool_snapshot = _pool.snapshot()
        assert self.pynccl_comm is not None
        for segment in _cached_pool_snapshot:
            # Each segment base address is registered at most once.
            if segment["address"] not in _registered_base_addrs:
                # NOTE(review): the value returned by the registration call
                # is not kept here -- confirm nothing needs it later.
                self.pynccl_comm.register_comm_window_raw(
                    segment["address"], segment["total_size"])
                _registered_base_addrs.add(segment["address"])
        if self.is_graph_capture:
            # Counterpart of the _cuda_endAllocateToPool call in __enter__.
            torch._C._cuda_beginAllocateCurrentThreadToPool(
                self.device, _graph_pool_id)

View File

@ -41,6 +41,7 @@ logger = init_logger(__name__)
ncclResult_t = ctypes.c_int ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p ncclComm_t = ctypes.c_void_p
ncclWindow_t = ctypes.c_void_p
class ncclUniqueId(ctypes.Structure): class ncclUniqueId(ctypes.Structure):
@ -222,6 +223,24 @@ class NCCLLibrary:
Function("ncclGroupStart", ncclResult_t, []), Function("ncclGroupStart", ncclResult_t, []),
# ncclResult_t ncclGroupEnd(); # ncclResult_t ncclGroupEnd();
Function("ncclGroupEnd", ncclResult_t, []), Function("ncclGroupEnd", ncclResult_t, []),
# ncclResult_t ncclCommWindowRegister(
# ncclComm_t comm, void* buff, size_t size,
# ncclWindow_t* win, int winFlags);
Function(
"ncclCommWindowRegister",
ncclResult_t,
[
ncclComm_t,
buffer_type,
ctypes.c_size_t,
ctypes.POINTER(ncclWindow_t),
ctypes.c_int,
],
),
# ncclResult_t ncclCommWindowDeregister(
# ncclComm_t comm, ncclWindow_t win);
Function("ncclCommWindowDeregister", ncclResult_t,
[ncclComm_t, ncclWindow_t]),
] ]
# class attribute to store the mapping from the path to the library # class attribute to store the mapping from the path to the library
@ -271,10 +290,14 @@ class NCCLLibrary:
error_str = self.ncclGetErrorString(result) error_str = self.ncclGetErrorString(result)
raise RuntimeError(f"NCCL error: {error_str}") raise RuntimeError(f"NCCL error: {error_str}")
def ncclGetVersion(self) -> str: def ncclGetRawVersion(self) -> int:
version = ctypes.c_int() version = ctypes.c_int()
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
version_str = str(version.value) # something like 21903
return version.value
def ncclGetVersion(self) -> str:
version_str = str(self.ncclGetRawVersion())
# something like 21903 --> "2.19.3" # something like 21903 --> "2.19.3"
major = version_str[0].lstrip("0") major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0") minor = version_str[1:3].lstrip("0")
@ -375,6 +398,17 @@ class NCCLLibrary:
def ncclGroupEnd(self) -> None: def ncclGroupEnd(self) -> None:
self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
def ncclCommWindowRegister(self, comm: ncclComm_t, buff: buffer_type,
                           size: int, win_flags: int) -> ncclWindow_t:
    """Call ``ncclCommWindowRegister`` and return the new window handle.

    Args:
        comm: NCCL communicator.
        buff: Base pointer of the buffer to register.
        size: Size of the buffer in bytes.
        win_flags: Flags forwarded as the C function's ``winFlags``.

    Returns:
        The ``ncclWindow_t`` filled in by NCCL.

    Raises:
        RuntimeError: via ``NCCL_CHECK`` when NCCL reports an error.
    """
    handle = ncclWindow_t()
    register_fn = self._funcs["ncclCommWindowRegister"]
    # The handle is an out-parameter, hence passed by reference.
    status = register_fn(comm, buff, size, ctypes.byref(handle), win_flags)
    self.NCCL_CHECK(status)
    return handle
def ncclCommWindowDeregister(self, comm: ncclComm_t,
                             window: ncclWindow_t) -> None:
    """Call ``ncclCommWindowDeregister`` for ``window`` on ``comm``.

    Raises:
        RuntimeError: via ``NCCL_CHECK`` when NCCL reports an error.
    """
    deregister_fn = self._funcs["ncclCommWindowDeregister"]
    status = deregister_fn(comm, window)
    self.NCCL_CHECK(status)
__all__ = [ __all__ = [
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",

View File

@ -330,6 +330,8 @@ class EngineArgs:
enable_dbo: bool = ParallelConfig.enable_dbo enable_dbo: bool = ParallelConfig.enable_dbo
dbo_decode_token_threshold: int = \ dbo_decode_token_threshold: int = \
ParallelConfig.dbo_decode_token_threshold ParallelConfig.dbo_decode_token_threshold
dbo_prefill_token_threshold: int = \
ParallelConfig.dbo_prefill_token_threshold
eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
enable_eplb: bool = ParallelConfig.enable_eplb enable_eplb: bool = ParallelConfig.enable_eplb
expert_placement_strategy: ExpertPlacementStrategy = \ expert_placement_strategy: ExpertPlacementStrategy = \
@ -698,6 +700,9 @@ class EngineArgs:
parallel_group.add_argument( parallel_group.add_argument(
"--dbo-decode-token-threshold", "--dbo-decode-token-threshold",
**parallel_kwargs["dbo_decode_token_threshold"]) **parallel_kwargs["dbo_decode_token_threshold"])
parallel_group.add_argument(
"--dbo-prefill-token-threshold",
**parallel_kwargs["dbo_prefill_token_threshold"])
parallel_group.add_argument("--enable-eplb", parallel_group.add_argument("--enable-eplb",
**parallel_kwargs["enable_eplb"]) **parallel_kwargs["enable_eplb"])
parallel_group.add_argument("--eplb-config", parallel_group.add_argument("--eplb-config",
@ -1316,6 +1321,7 @@ class EngineArgs:
enable_expert_parallel=self.enable_expert_parallel, enable_expert_parallel=self.enable_expert_parallel,
enable_dbo=self.enable_dbo, enable_dbo=self.enable_dbo,
dbo_decode_token_threshold=self.dbo_decode_token_threshold, dbo_decode_token_threshold=self.dbo_decode_token_threshold,
dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
eplb_config=self.eplb_config, eplb_config=self.eplb_config,
expert_placement_strategy=self.expert_placement_strategy, expert_placement_strategy=self.expert_placement_strategy,

View File

@ -317,7 +317,8 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]:
) )
output_items.append(response_item) output_items.append(response_item)
elif recipient is not None and (recipient.startswith("python") elif recipient is not None and (recipient.startswith("python")
or recipient.startswith("browser")): or recipient.startswith("browser")
or recipient.startswith("container")):
for content in message.content: for content in message.content:
reasoning_item = ResponseReasoningItem( reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}", id=f"rs_{random_uuid()}",

View File

@ -182,15 +182,19 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER" VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024
VLLM_DBO_COMM_SMS: int = 20
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = [] GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
VLLM_USE_NCCL_SYMM_MEM: bool = False
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
def get_default_cache_root(): def get_default_cache_root():
@ -1366,7 +1370,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Whether to use pytorch symmetric memory for allreduce # Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM": "VLLM_ALLREDUCE_USE_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))), lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))),
# Allows vllm to find tuned config under customized folder # Allows vllm to find tuned config under customized folder
"VLLM_TUNED_CONFIG_FOLDER": "VLLM_TUNED_CONFIG_FOLDER":
@ -1392,6 +1396,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
"VLLM_OBJECT_STORAGE_SHM_BUFFER"), "VLLM_OBJECT_STORAGE_SHM_BUFFER"),
# The size in MB of the buffers (NVL and RDMA) used by DeepEP
"VLLM_DEEPEP_BUFFER_SIZE_MB":
lambda: int(os.getenv("VLLM_DEEPEP_BUFFER_SIZE_MB", "1024")),
# The number of SMs to allocate for communication kernels when running DBO
# the rest of the SMs on the device will be allocated to compute
"VLLM_DBO_COMM_SMS":
lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")),
# Valid values are container,code_interpreter,web_search_preview # Valid values are container,code_interpreter,web_search_preview
# ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter # ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
"GPT_OSS_SYSTEM_TOOL_MCP_LABELS": "GPT_OSS_SYSTEM_TOOL_MCP_LABELS":
@ -1399,6 +1412,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
["container", ["container",
"code_interpreter", "code_interpreter",
"web_search_preview"]), "web_search_preview"]),
# Flag to enable NCCL symmetric memory allocation and registration
"VLLM_USE_NCCL_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))),
# NCCL header path
"VLLM_NCCL_INCLUDE_PATH":
lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]

View File

@ -24,11 +24,12 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
super().__init__() super().__init__()
self.base_layer = base_layer self.base_layer = base_layer
self.input_size = self.base_layer.input_size self.input_size = self.base_layer.input_size
# Ensure tp_size and tp_rank consistency with the base_layer.
self.tp_size = self.base_layer.tp_size
self.tp_rank = self.base_layer.tp_rank
self.device = _get_lora_device(self.base_layer) self.device = _get_lora_device(self.base_layer)
self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
self.output_slices: tuple[int, ...] self.output_slices: tuple[int, ...]
self.tp_size: int
self.output_size: int self.output_size: int
self.n_slices: int self.n_slices: int

View File

@ -8,9 +8,7 @@ import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import tensor_model_parallel_all_gather
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.distributed.utils import divide from vllm.distributed.utils import divide
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
@ -85,7 +83,6 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
# inconsistent when TP is greater than 1. # inconsistent when TP is greater than 1.
self.is_merged_col_linear = type( self.is_merged_col_linear = type(
base_layer) is MergedColumnParallelLinear base_layer) is MergedColumnParallelLinear
self.tp_size = get_tensor_model_parallel_world_size()
self.output_size = self.base_layer.output_size_per_partition self.output_size = self.base_layer.output_size_per_partition
# There is only one LoRA layer # There is only one LoRA layer
self.n_slices = 1 self.n_slices = 1
@ -97,22 +94,20 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
# Applicable to cases where the base_layer is # Applicable to cases where the base_layer is
# MergedColumnParallelLinear. # MergedColumnParallelLinear.
if self.is_merged_col_linear: if self.is_merged_col_linear:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.output_size // 2 shard_size = self.output_size // 2
offset = lora_b.shape[0] // 2 offset = lora_b.shape[0] // 2
left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) * left_weight = lora_b[self.tp_rank * shard_size:(self.tp_rank + 1) *
shard_size, :] shard_size, :]
right_weight = lora_b[offset + tp_rank * shard_size:offset + right_weight = lora_b[offset + self.tp_rank * shard_size:offset +
(tp_rank + 1) * shard_size, :] (self.tp_rank + 1) * shard_size, :]
lora_b = torch.cat([left_weight, right_weight], dim=0) lora_b = torch.cat([left_weight, right_weight], dim=0)
# Applicable to cases where the base_layer is # Applicable to cases where the base_layer is
# ColumnParallelLinear. # ColumnParallelLinear.
else: else:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_size shard_size = self.output_size
start_idx = tensor_model_parallel_rank * shard_size start_idx = self.tp_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[start_idx:end_idx, :] lora_b = lora_b[start_idx:end_idx, :]
return lora_b return lora_b
@ -120,10 +115,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
# TODO: Fix the slicing logic of bias. # TODO: Fix the slicing logic of bias.
if bias is None: if bias is None:
return bias return bias
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_size shard_size = self.output_size
start_idx = tensor_model_parallel_rank * shard_size start_idx = self.tp_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size end_idx = (self.tp_rank + 1) * shard_size
bias = bias[start_idx:end_idx] bias = bias[start_idx:end_idx]
return bias return bias
@ -144,7 +138,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
# Matrix multiply. # Matrix multiply.
output_parallel = self.apply(input_, bias) output_parallel = self.apply(input_, bias)
if self.base_layer.gather_output: if self.base_layer.gather_output and self.tp_size > 1:
# All-gather across the partitions. # All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel) output = tensor_model_parallel_all_gather(output_parallel)
else: else:
@ -185,8 +179,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
QKVParallelLinear]) -> None: QKVParallelLinear]) -> None:
super().__init__(base_layer) super().__init__(base_layer)
# There are two LoRA layers # There are two LoRA layers
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
# the output_sizes in MergedColumnParallelLinear is not sharded by tp # the output_sizes in MergedColumnParallelLinear is not sharded by tp
# we need to divide it by the tp_size to get correct slices size # we need to divide it by the tp_size to get correct slices size
output_sizes = self.base_layer.output_sizes output_sizes = self.base_layer.output_sizes
@ -341,9 +333,9 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.n_slices = 1 self.n_slices = 1
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank self.q_shard_id = self.tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[self.q_proj_shard_size * lora_b_q = lora_b[self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size * self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1), :] (self.q_shard_id + 1), :]
@ -397,8 +389,6 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
super().__init__(base_layer) super().__init__(base_layer)
# There are three LoRA layer. # There are three LoRA layer.
self.n_slices = len(self.base_layer.output_sizes) self.n_slices = len(self.base_layer.output_sizes)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.q_proj_shard_size = (self.base_layer.num_heads * self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size) self.base_layer.head_size)
@ -461,9 +451,8 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
# Therefore, the sharding of `lora_a` only needs to correspond with the # Therefore, the sharding of `lora_a` only needs to correspond with the
# gather operation. # gather operation.
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2] shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size start_idx = self.tp_rank * shard_size
lora_a = lora_a[start_idx:start_idx + shard_size, :] lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a return lora_a
@ -547,9 +536,8 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
""" """
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2] shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size start_idx = self.tp_rank * shard_size
lora_a = lora_a[start_idx:start_idx + shard_size, :] lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a return lora_a

View File

@ -18,7 +18,6 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: ReplicatedLinear) -> None: def __init__(self, base_layer: ReplicatedLinear) -> None:
super().__init__(base_layer, ) super().__init__(base_layer, )
# To ensure interface compatibility, set to 1 always. # To ensure interface compatibility, set to 1 always.
self.tp_size = 1
self.output_size = self.base_layer.output_size self.output_size = self.base_layer.output_size
self.n_slices = 1 self.n_slices = 1

View File

@ -8,9 +8,7 @@ import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (split_tensor_along_last_dim,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
# yapf: disable # yapf: disable
from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.linear import RowParallelLinear
@ -25,12 +23,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None: def __init__(self, base_layer: RowParallelLinear) -> None:
super().__init__(base_layer) super().__init__(base_layer)
self.tp_size = get_tensor_model_parallel_world_size()
# reset input_size # reset input_size
self.input_size = self.base_layer.input_size_per_partition self.input_size = self.base_layer.input_size_per_partition
self.output_size = self.base_layer.output_size self.output_size = self.base_layer.output_size
self.tp_rank = get_tensor_model_parallel_rank()
# There is only one LoRA layer. # There is only one LoRA layer.
self.n_slices = 1 self.n_slices = 1
@ -68,12 +63,12 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
else: else:
# TODO: simplify code below # TODO: simplify code below
splitted_input = split_tensor_along_last_dim( splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size) input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous() input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply. # Matrix multiply.
output_parallel = self.apply(input_parallel) output_parallel = self.apply(input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1: if self.base_layer.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel) output_ = tensor_model_parallel_all_reduce(output_parallel)
else: else:
output_ = output_parallel output_ = output_parallel
@ -154,8 +149,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
buffer, x, self.lora_a_stacked, 1.0) buffer, x, self.lora_a_stacked, 1.0)
if not current_platform.can_update_inplace(): if not current_platform.can_update_inplace():
buffer = shrunk_buffer buffer = shrunk_buffer
if self.tp_size>1:
buffer = tensor_model_parallel_all_reduce(buffer) buffer = tensor_model_parallel_all_reduce(buffer)
# following S-LoRA, allows the fusing of all_gather and all_reduce # following S-LoRA, allows the fusing of all_gather and all_reduce
# by adding the column partitioned lora output to a slice of output # by adding the column partitioned lora output to a slice of output

View File

@ -48,11 +48,11 @@ class LoRALayerWeights:
@property @property
def input_dim(self) -> int: def input_dim(self) -> int:
return self.lora_a.shape[0] return self.lora_a.shape[1]
@property @property
def output_dim(self) -> int: def output_dim(self) -> int:
return self.lora_b.shape[1] return self.lora_b.shape[0]
@property @property
def is_packed(self) -> bool: def is_packed(self) -> bool:

View File

@ -12,6 +12,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input) moe_kernel_quantize_input)
from vllm.utils import round_up from vllm.utils import round_up
from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id, dbo_enabled, dbo_switch_to_comm,
dbo_switch_to_compute, dbo_switch_to_compute_sync,
dbo_yield_and_switch_from_comm_to_compute,
dbo_yield_and_switch_from_compute_to_comm)
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
@ -46,9 +51,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.async_prepare = True self.async_prepare = True
# The dispatch function returns a handle that the combine function # The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the # requires. Under DBO microbatching we must track one handle per
# combine function. # micro-batch to avoid races between threads.
self.handle = None self.handles = [None, None]
# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164 # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160] self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]
@ -89,6 +94,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
has_scales = token_scales is not None has_scales = token_scales is not None
# We yield before launching the dispatch kernel since the dispatch
# kernel will block the CPU so we want to queue up all the compute
# for the other ubatch before the dispatch kernel starts.
dbo_yield_and_switch_from_compute_to_comm()
(num_tokens_per_rank, num_tokens_per_rdma_rank, (num_tokens_per_rank, num_tokens_per_rdma_rank,
dispatch_expert_num_tokens, is_token_in_rank, dispatch_expert_num_tokens, is_token_in_rank,
event) = self.buffer.get_dispatch_layout( event) = self.buffer.get_dispatch_layout(
@ -104,7 +114,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
( (
token_data, expert_topk_ids, expert_topk_weights, token_data, expert_topk_ids, expert_topk_weights,
expert_num_tokens_per_expert_list, self.handle, event expert_num_tokens_per_expert_list, handle, event
) = self.buffer.dispatch( ) = self.buffer.dispatch(
x=token_data, x=token_data,
handle=None, handle=None,
@ -119,9 +129,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_alignment=1, expert_alignment=1,
config=self._get_dispatch_config(), config=self._get_dispatch_config(),
previous_event=None, previous_event=None,
async_finish=self.async_prepare, async_finish=self.async_prepare and not dbo_enabled(),
allocate_on_comm_stream=False) allocate_on_comm_stream=False)
# record the handle for this ubatch
a2a_idx = dbo_current_ubatch_id()
self.handles[a2a_idx] = handle
dbo_switch_to_compute_sync()
return lambda: self._receiver( return lambda: self._receiver(
event, event,
has_scales, has_scales,
@ -146,7 +162,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor],
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
if self.async_prepare: if event.event is not None:
event.current_stream_wait() event.current_stream_wait()
if has_scales: if has_scales:
@ -207,7 +223,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[Callable, mk.ReceiverType]: ) -> mk.ReceiverType:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
@ -233,14 +249,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1q_scale = None a1q_scale = None
a1_post_scale = quant_config.a1_scale a1_post_scale = quant_config.a1_scale
return (lambda *args: None, return self._do_dispatch(tokens=a1q,
self._do_dispatch(tokens=a1q, token_scales=a1q_scale,
token_scales=a1q_scale, rank_topk_ids=topk_ids,
rank_topk_ids=topk_ids, rank_topk_weights=topk_weights,
rank_topk_weights=topk_weights, num_experts=num_experts,
num_experts=num_experts, a1_scale=a1_post_scale,
a1_scale=a1_post_scale, quant_config=quant_config)
quant_config=quant_config))
def prepare( def prepare(
self, self,
@ -252,10 +267,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
(_, receiver) = self.prepare_async(a1, topk_weights, topk_ids, receiver = self.prepare_async(a1, topk_weights, topk_ids, num_experts,
num_experts, expert_map, expert_map, apply_router_weight_on_input,
apply_router_weight_on_input, quant_config)
quant_config)
return receiver() return receiver()
def _finalize( def _finalize(
@ -269,7 +283,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_async: bool, do_async: bool,
) -> Optional[Callable]: ) -> Optional[Callable]:
assert self.handle is not None a2a_idx = dbo_current_ubatch_id()
handle = self.handles[a2a_idx]
assert handle is not None
# fused_expert_output can have 0 tokens - This happens when none of the # fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank. # tokens from the all2all reach this EP rank.
@ -283,25 +299,35 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids=topk_ids, topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
dbo_yield_and_switch_from_compute_to_comm()
combined_x, _, event = self.buffer.combine( combined_x, _, event = self.buffer.combine(
x=fused_expert_output, x=fused_expert_output,
handle=self.handle, handle=handle,
topk_weights=None, topk_weights=None,
config=self._get_combine_config(), config=self._get_combine_config(),
previous_event=None, previous_event=None,
async_finish=do_async, async_finish=do_async and not dbo_enabled(),
allocate_on_comm_stream=False) allocate_on_comm_stream=False)
dbo_switch_to_compute()
if do_async: if do_async:
def _receiver(): def _receiver():
event.current_stream_wait() if event.event is not None:
event.current_stream_wait()
dbo_switch_to_comm()
# Respect inplace outputs. # Respect inplace outputs.
output.copy_(combined_x, non_blocking=True) output.copy_(combined_x, non_blocking=True)
return lambda: _receiver() # TODO(lucas): refactor the modular kernel so this will be
# handled there
dbo_yield_and_switch_from_comm_to_compute()
return _receiver
else: else:
# TODO(lucas): support this case with the refactored modular kernel
assert not dbo_enabled()
# Respect inplace outputs. # Respect inplace outputs.
output.copy_(combined_x, non_blocking=True) output.copy_(combined_x, non_blocking=True)
return None return None

View File

@ -206,7 +206,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce, weight_and_reduce_impl: mk.TopKWeightAndReduce,
do_async: bool, do_async: bool,
) -> Optional[Callable]: ) -> tuple[Callable, Callable]:
assert isinstance( assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.") ), ("Weight application and reduction happens in the combine kernel.")
@ -233,7 +233,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return_recv_hook=do_recv_hook, return_recv_hook=do_recv_hook,
out=output) out=output)
return recv_hook return recv_hook, lambda: None
def finalize_async( def finalize_async(
self, self,
@ -243,8 +243,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce, weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> Callable: ) -> tuple[Callable, Callable]:
recv_hook = self._finalize( return self._finalize(
output, output,
fused_expert_output, fused_expert_output,
topk_weights, topk_weights,
@ -253,8 +253,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
weight_and_reduce_impl, weight_and_reduce_impl,
do_async=True, do_async=True,
) )
assert recv_hook is not None
return recv_hook
def finalize( def finalize(
self, self,

View File

@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
_resize_cache, count_expert_num_tokens) _resize_cache, count_expert_num_tokens)
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook, from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
dbo_maybe_run_recv_hook,
dbo_register_recv_hook, dbo_yield) dbo_register_recv_hook, dbo_yield)
# #
@ -223,7 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[Callable, ReceiverType]: ) -> Union[tuple[Callable, ReceiverType], ReceiverType]:
""" """
Perform any quantization (and/or) dispatching needed for this kernel Perform any quantization (and/or) dispatching needed for this kernel
but do not wait for results from other workers. but do not wait for results from other workers.
@ -239,10 +240,21 @@ class FusedMoEPrepareAndFinalize(ABC):
- apply_router_weight_on_input: When True, apply the weights to the - apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching. activations, before quantization + dispatching.
Returns a callback that when invoked waits for results from other Returns a callback or a hook callback pair that when invoked waits for
workers and has the same return signature as `prepare`, e.g. results from other workers and has the same return signature as
`prepare`. If a hook is returned, it is a more lightweight check that
the recv is complete without doing extra work (used by DBO, will be
refactored in the very near future)
e.g.
receiver = obj.prepare_async(...) ret = obj.prepare_async(...)
if isinstance(ret, tuple):
hook, receiver = ret
hook()
if hook is not None:
a, a_scales, expert_meta, topk_ids, topk_weights = receiver() a, a_scales, expert_meta, topk_ids, topk_weights = receiver()
is equivalent to: is equivalent to:
@ -284,7 +296,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce, weight_and_reduce_impl: TopKWeightAndReduce,
) -> Callable: ) -> Union[tuple[Callable, Callable], Callable]:
""" """
Perform any combine plus apply weights and perform a reduction on the Perform any combine plus apply weights and perform a reduction on the
fused experts output but do not wait for results from other workers. fused experts output but do not wait for results from other workers.
@ -298,11 +310,17 @@ class FusedMoEPrepareAndFinalize(ABC):
- weight_and_reduce_impl: An optional TopKWeightAndReduce - weight_and_reduce_impl: An optional TopKWeightAndReduce
implementation. implementation.
Returns a callback that when invoked waits for results from other Returns a callback or a hook callback pair that when invoked waits for
workers and has the same return signature as `finalize`, e.g. results from other workers and has the same return signature as
`finalize`. If a hook is returned, it is a more lightweight check that
the recv is complete without doing extra work (used by DBO, will be
refactored in the very near future)
receiver = obj.finalize_async(output, ...) ret = obj.finalize_async(output, ...)
... output not valid yet ... ... output not valid yet ...
if isinstance(ret, tuple):
hook, receiver = ret
hook()
receiver() receiver()
... output valid here ... ... output valid here ...
@ -600,9 +618,23 @@ class FusedMoEModularKernel(torch.nn.Module):
layer due to any layer specific state that may be used by the component layer due to any layer specific state that may be used by the component
objects. objects.
""" """
fused_out_buffer = SharedResizableBuffer()
workspace13_buffer = SharedResizableBuffer() class SharedBuffers:
workspace2_buffer = SharedResizableBuffer()
def __init__(self) -> None:
self.fused_out = SharedResizableBuffer()
self.workspace13 = SharedResizableBuffer()
self.workspace2 = SharedResizableBuffer()
# Persistent buffers that are shared across `FusedMoEModularKernel`
# instances (layers), to save memory and allocations.
#
# We have two sets of buffers to support dual batch overlap (DBO) where each
# microbatch (ubatch) should use its own set of buffers to avoid
# cross-ubatch contamination.
# NOTE that memory is lazily allocated for these buffers, meaning that if
# DBO isn't being used, the second SharedBuffers will be empty.
shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()]
def __init__( def __init__(
self, self,
@ -647,14 +679,18 @@ class FusedMoEModularKernel(torch.nn.Module):
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta) expert_tokens_meta)
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
ubatch_idx = dbo_current_ubatch_id()
buffers = self.shared_buffers[ubatch_idx]
# We can reuse the memory between cache1 and cache3 because by the # We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1. # time we need cache3, we're done with cache1.
workspace13 = self.workspace13_buffer.get(workspace13_shape, workspace13 = buffers.workspace13.get(workspace13_shape,
device=a1.device, device=a1.device,
dtype=workspace_dtype) dtype=workspace_dtype)
workspace2 = self.workspace2_buffer.get(workspace2_shape, workspace2 = buffers.workspace2.get(workspace2_shape,
device=a1.device, device=a1.device,
dtype=workspace_dtype) dtype=workspace_dtype)
assert fused_out is None or fused_out.shape == fused_out_shape, ( assert fused_out is None or fused_out.shape == fused_out_shape, (
f"fused_out {fused_out.shape} but expected {fused_out_shape}") f"fused_out {fused_out.shape} but expected {fused_out_shape}")
@ -733,9 +769,11 @@ class FusedMoEModularKernel(torch.nn.Module):
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta) expert_tokens_meta)
fused_out = self.fused_out_buffer.get(fused_out_shape, ubatch_idx = dbo_current_ubatch_id()
device=a1q.device, buffers = self.shared_buffers[ubatch_idx]
dtype=a1.dtype) fused_out = buffers.fused_out.get(fused_out_shape,
device=a1q.device,
dtype=a1.dtype)
def slice_input_tensors( def slice_input_tensors(
chunk_idx: int chunk_idx: int
@ -868,6 +906,7 @@ class FusedMoEModularKernel(torch.nn.Module):
if not self.prepare_finalize.supports_async(): if not self.prepare_finalize.supports_async():
# We shouldn't be running an a2a kernel that doesn't # We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize # support async prepare/finalize
# TODO(lucas): enable in follow-up
assert not dbo_enabled() assert not dbo_enabled()
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
@ -883,7 +922,7 @@ class FusedMoEModularKernel(torch.nn.Module):
else: else:
# Overlap shared expert compute with all2all dispatch. # Overlap shared expert compute with all2all dispatch.
dbo_maybe_run_recv_hook() dbo_maybe_run_recv_hook()
hook, receiver = self.prepare_finalize.prepare_async( prepare_ret = self.prepare_finalize.prepare_async(
a1, a1,
topk_weights, topk_weights,
topk_ids, topk_ids,
@ -893,13 +932,21 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.quant_config, self.fused_experts.quant_config,
) )
# If DBO is being used, register the hook with the ubatch context # TODO(lucas): refactor this in the alternative schedules followup
# and call it in dbo_maybe_run_recv_hook instead of passing it to # currently unpack if we have hook + receiver pair or just
# the receiver. # receiver (see finalize_async docstring)
dbo_register_recv_hook(hook) hook, receiver = prepare_ret \
dbo_yield() if isinstance(prepare_ret, tuple) else (None, prepare_ret)
if not dbo_enabled():
hook() if hook is not None:
if dbo_enabled():
# If DBO is being used, register the hook with the ubatch
# context and call it in dbo_maybe_run_recv_hook instead of
# passing it to the receiver.
dbo_register_recv_hook(hook)
dbo_yield()
else:
hook()
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = receiver() _expert_topk_weights) = receiver()
@ -952,7 +999,7 @@ class FusedMoEModularKernel(torch.nn.Module):
if self.shared_experts is not None: if self.shared_experts is not None:
shared_output = self.shared_experts(a1) shared_output = self.shared_experts(a1)
else: else:
recv_hook = self.prepare_finalize.finalize_async( finalize_ret = self.prepare_finalize.finalize_async(
output, output,
fused_out, fused_out,
topk_weights, topk_weights,
@ -964,11 +1011,23 @@ class FusedMoEModularKernel(torch.nn.Module):
if self.shared_experts is not None: if self.shared_experts is not None:
shared_output = self.shared_experts(a1) shared_output = self.shared_experts(a1)
assert recv_hook is not None # TODO(lucas): refactor this in the alternative schedules followup
dbo_register_recv_hook(recv_hook) # currently unpack if we have hook + receiver pair or just
dbo_yield() # receiver (see finalize_async docstring)
if not dbo_enabled(): hook, receiver = finalize_ret \
recv_hook() if isinstance(finalize_ret, tuple) else (None, finalize_ret)
if hook is not None:
if dbo_enabled():
# If DBO is being used, register the hook with the ubatch
# context and call it in dbo_maybe_run_recv_hook instead of
# passing it to the receiver.
dbo_register_recv_hook(hook)
dbo_yield()
else:
hook()
receiver()
if self.shared_experts is None: if self.shared_experts is None:
return output return output

View File

@ -644,6 +644,14 @@ class CompressedTensorsConfig(QuantizationConfig):
# If no matches, return None # If no matches, return None
return None return None
# Return True if any scheme in `self.target_scheme_map` has a "weights"
# entry whose strategy is QuantizationStrategy.BLOCK; otherwise False.
def has_blocked_weights(self) -> bool:
for scheme in self.target_scheme_map.values():
# The "weights" lookup may yield None; such schemes are skipped.
weight_quant = scheme.get("weights")
if (weight_quant is not None
and weight_quant.strategy == QuantizationStrategy.BLOCK):
# One block-quantized scheme is enough; stop at the first match.
return True
return False
@staticmethod @staticmethod
def supports_cutlass_24( def supports_cutlass_24(
weight_quant: Optional[QuantizationArgs], weight_quant: Optional[QuantizationArgs],

View File

@ -13,6 +13,7 @@ from compressed_tensors.quantization import (ActivationOrdering,
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
@ -31,6 +32,9 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl) select_nvfp4_gemm_impl)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
expert_weight_is_col_major, get_col_major_tma_aligned_tensor,
requant_weight_ue8m0_inplace)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, marlin_make_workspace_new, check_moe_marlin_supports_layer, marlin_make_workspace_new,
marlin_moe_permute_scales) marlin_moe_permute_scales)
@ -45,6 +49,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
logger = init_logger(__name__) logger = init_logger(__name__)
@ -505,10 +510,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.weight_quant.strategy == QuantizationStrategy.CHANNEL self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN) and self.input_quant.strategy == QuantizationStrategy.TOKEN)
if not (per_tensor or per_channel): if not (per_tensor or per_channel):
raise ValueError( assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
"For FP8 Fused MoE layers, we require per tensor " self.weight_block_size = self.weight_quant.block_structure
"or channelwise, dynamic per token quantization. Found " assert self.weight_quant.dynamic is not None
f"{self.weight_quant}, {self.input_quant}") else:
self.weight_block_size = None
self.block_quant = self.weight_block_size is not None
self.static_input_scales = not self.input_quant.dynamic self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales and per_channel: if self.static_input_scales and per_channel:
@ -519,7 +526,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin # For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization # kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89) self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN) or envs.VLLM_TEST_FORCE_FP8_MARLIN
and not self.block_quant)
# Disable marlin for rocm # Disable marlin for rocm
if current_platform.is_rocm(): if current_platform.is_rocm():
self.use_marlin = False self.use_marlin = False
@ -531,8 +539,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# cutlass path # cutlass path
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100( self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
self.weight_quant, self.input_quant) self.weight_quant, self.input_quant)
self.use_cutlass = (quant_config._is_fp8_w8a8_sm90( self.use_cutlass = not self.block_quant and (
self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100) quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
or self.is_fp8_w8a8_sm100)
self.disable_expert_map = False self.disable_expert_map = False
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
@ -547,6 +556,31 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
params_dtype = torch.float8_e4m3fn params_dtype = torch.float8_e4m3fn
if self.block_quant:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
tp_size = get_tensor_model_parallel_world_size()
block_n, block_k = (
self.weight_block_size[0],
self.weight_block_size[1],
)
# NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size_per_partition % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}.")
if (tp_size > 1
and intermediate_size_per_partition % block_k != 0):
# Required by row parallel
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty( w13_weight = torch.nn.Parameter(torch.empty(
num_experts, num_experts,
@ -602,6 +636,27 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs)
elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 *
((intermediate_size_per_partition + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, (hidden_size + block_n - 1) // block_n,
(intermediate_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add BLOCK quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES # INPUT_SCALES
if self.static_input_scales: if self.static_input_scales:
w13_input_scale = torch.nn.Parameter(torch.ones( w13_input_scale = torch.nn.Parameter(torch.ones(
@ -706,6 +761,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
del layer.w2_input_scale del layer.w2_input_scale
if self.use_cutlass: if self.use_cutlass:
assert self.weight_quant.strategy != QuantizationStrategy.BLOCK
device = layer.w13_weight.device device = layer.w13_weight.device
# ab_strides1 and c_strides2 are the same # ab_strides1 and c_strides2 are the same
self.ab_strides1_c_strides2 = torch.full( self.ab_strides1_c_strides2 = torch.full(
@ -724,6 +780,29 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
device=device, device=device,
dtype=torch.int64) dtype=torch.int64)
if is_deep_gemm_e8m0_used() and self.block_quant:
assert layer.weight_block_size is not None
# Re-quantise the expert weights so their scales are UE8M0.
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
layer.w13_weight.data,
layer.w13_weight_scale.data,
block_sz,
)
requant_weight_ue8m0_inplace(
layer.w2_weight.data,
layer.w2_weight_scale.data,
block_sz,
)
# Ensure column-major TMA alignment expected by DeepGEMM.
if expert_weight_is_col_major(layer.w13_weight_scale):
layer.w13_weight_scale = get_col_major_tma_aligned_tensor(
layer.w13_weight_scale)
if expert_weight_is_col_major(layer.w2_weight_scale):
layer.w2_weight_scale = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale)
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self) -> Optional[mk.FusedMoEPrepareAndFinalize]: self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if self.use_marlin or self.rocm_aiter_moe_enabled: if self.use_marlin or self.rocm_aiter_moe_enabled:
@ -777,9 +856,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
return experts return experts
# triton path # triton path
from vllm.model_executor.layers.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonOrDeepGemmExperts)
BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
assert not self.rocm_aiter_moe_enabled and not self.use_marlin assert not self.rocm_aiter_moe_enabled and not self.use_marlin
@ -790,14 +870,16 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
assert max_num_tokens_per_rank is not None assert max_num_tokens_per_rank is not None
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
return BatchedTritonExperts( return BatchedTritonOrDeepGemmExperts(
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(), num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
else: else:
logger.debug("TritonExperts(%s)", self.__class__.__name__) logger.debug("TritonOrDeepGemmExperts(%s)",
return TritonExperts(self.moe_quant_config) self.__class__.__name__)
return TritonOrDeepGemmExperts(self.moe_quant_config,
allow_deep_gemm=True)
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
@ -816,6 +898,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
per_act_token_quant=per_act_token, per_act_token_quant=per_act_token,
per_out_ch_quant=per_channel_quant, per_out_ch_quant=per_channel_quant,
block_shape=layer.weight_block_size,
) )
def apply( def apply(

View File

@ -11,7 +11,7 @@ from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_block_linear, check_aiter_fp8_linear_support, W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, maybe_post_process_fp8_weight_block, create_fp8_weight_parameter, maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy, process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy,
@ -41,16 +41,30 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.strategy = weight_quant.strategy self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.act_q_group_shape = GroupShape.PER_TENSOR \
if is_static_input_scheme else GroupShape.PER_TOKEN
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape)
self.weight_block_size = self.weight_quant.block_structure self.weight_block_size = self.weight_quant.block_structure
if self.weight_block_size is not None:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
self.act_q_group_shape = GroupShape.PER_TENSOR \
if is_static_input_scheme else GroupShape.PER_TOKEN
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
if self.weight_block_size is not None:
assert not self.is_static_input_scheme
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape)
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
# lovelace and up # lovelace and up
@ -141,13 +155,14 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if layer.weight_block_size is not None: if self.weight_block_size is not None:
return apply_fp8_block_linear( return self.w8a8_block_fp8_linear.apply(
layer,
input=x, input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias, bias=bias,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, )
use_aiter_and_is_supported=self.use_aiter_and_is_supported)
return self.fp8_linear.apply(input=x, return self.fp8_linear.apply(input=x,
weight=layer.weight, weight=layer.weight,

View File

@ -43,7 +43,7 @@ def prepare_block_fp8_matmul_inputs(
return M, N, K, C return M, N, K, C
def w8a8_block_fp8_matmul_deepgemm( def w8a8_deepgemm_block_scaled_mm(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, As: torch.Tensor,
@ -59,7 +59,7 @@ def w8a8_block_fp8_matmul_deepgemm(
return C return C
def w8a8_block_fp8_matmul_deepgemm_fake( def w8a8_deepgemm_block_scaled_mm_fake(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, As: torch.Tensor,
@ -73,9 +73,9 @@ def w8a8_block_fp8_matmul_deepgemm_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="w8a8_block_fp8_matmul_deepgemm", op_name="w8a8_deepgemm_block_scaled_mm",
op_func=w8a8_block_fp8_matmul_deepgemm, op_func=w8a8_deepgemm_block_scaled_mm,
mutates_args=[], mutates_args=[],
fake_impl=w8a8_block_fp8_matmul_deepgemm_fake, fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )

View File

@ -31,12 +31,12 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl, swap_w13_to_w31) select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_block_linear, check_aiter_fp8_linear_support, W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, get_col_major_tma_aligned_tensor, create_fp8_weight_parameter, expert_weight_is_col_major,
maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, get_col_major_tma_aligned_tensor, maybe_post_process_fp8_weight_block,
process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace, process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy,
validate_fp8_block_shape) requant_weight_ue8m0_inplace, validate_fp8_block_shape)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
prepare_moe_fp8_layer_for_marlin) prepare_moe_fp8_layer_for_marlin)
@ -64,12 +64,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__) logger = init_logger(__name__)
def _is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
b, m, n = x.shape
return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
class Fp8Config(QuantizationConfig): class Fp8Config(QuantizationConfig):
"""Config class for FP8.""" """Config class for FP8."""
@ -240,15 +234,28 @@ class Fp8LinearMethod(LinearMethodBase):
self.weight_block_size = self.quant_config.weight_block_size self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.weight_block_size is not None self.block_quant = self.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static" self.act_q_static = self.quant_config.activation_scheme == "static"
# Use per-token quantization for better perf if dynamic and cutlass if self.weight_block_size:
if not self.act_q_static and cutlass_fp8_supported(): self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
self.act_q_group_shape = GroupShape.PER_TOKEN
else: else:
self.act_q_group_shape = GroupShape.PER_TENSOR # Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
self.fp8_linear = Fp8LinearOp( if self.block_quant:
act_quant_static=self.act_q_static, assert not self.act_q_static
act_quant_group_shape=self.act_q_group_shape) assert self.weight_block_size is not None
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape)
def create_weights( def create_weights(
self, self,
@ -397,12 +404,15 @@ class Fp8LinearMethod(LinearMethodBase):
bias=bias) bias=bias)
if self.block_quant: if self.block_quant:
return apply_fp8_block_linear( assert self.weight_block_size is not None
layer,
return self.w8a8_block_fp8_linear.apply(
input=x, input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias, bias=bias,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, )
use_aiter_and_is_supported=self.use_aiter_and_is_supported)
return self.fp8_linear.apply(input=x, return self.fp8_linear.apply(input=x,
weight=layer.weight, weight=layer.weight,
@ -660,10 +670,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# DeepGemm scales need to be transposed and aligned. We try to do # DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons. # it ahead of time for performance reasons.
if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
if _is_col_major(layer.w13_weight_scale_inv): if expert_weight_is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = \ layer.w13_weight_scale_inv = \
get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv) get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
if _is_col_major(layer.w2_weight_scale_inv): if expert_weight_is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = \ layer.w2_weight_scale_inv = \
get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv) get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)
@ -811,10 +821,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
) )
# Ensure column-major TMA alignment expected by DeepGEMM. # Ensure column-major TMA alignment expected by DeepGEMM.
if _is_col_major(layer.w13_weight_scale_inv): if expert_weight_is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w13_weight_scale_inv) layer.w13_weight_scale_inv)
if _is_col_major(layer.w2_weight_scale_inv): if expert_weight_is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale_inv) layer.w2_weight_scale_inv)

View File

@ -27,11 +27,14 @@ class QuantFP8(CustomOp):
This CustomOp supports both static and dynamic quantization. This CustomOp supports both static and dynamic quantization.
""" """
def __init__(self, def __init__(
static: bool, self,
group_shape: GroupShape, static: bool,
num_token_padding: Optional[int] = None, group_shape: GroupShape,
column_major_scales: bool = False): num_token_padding: Optional[int] = None,
column_major_scales: bool = False,
use_ue8m0: Optional[bool] = None, # for Torch compile
):
""" """
:param static: static or dynamic quantization :param static: static or dynamic quantization
:param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR, :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
@ -46,6 +49,7 @@ class QuantFP8(CustomOp):
self.group_shape = group_shape self.group_shape = group_shape
self.num_token_padding = num_token_padding self.num_token_padding = num_token_padding
self.column_major_scales = column_major_scales self.column_major_scales = column_major_scales
self.use_ue8m0 = use_ue8m0
self.is_group_quant = group_shape.is_per_group() self.is_group_quant = group_shape.is_per_group()
if self.is_group_quant: if self.is_group_quant:
@ -70,7 +74,8 @@ class QuantFP8(CustomOp):
x, x,
group_size=self.group_size, group_size=self.group_size,
column_major_scales=self.column_major_scales, column_major_scales=self.column_major_scales,
dtype=_FP8_DTYPE) dtype=_FP8_DTYPE,
use_ue8m0=self.use_ue8m0)
assert (scale is not None) == self.static assert (scale is not None) == self.static
assert scale_ub is None or (not self.static and self.group_shape assert scale_ub is None or (not self.static and self.group_shape
@ -137,7 +142,10 @@ class QuantFP8(CustomOp):
x_grouped = x.view(-1, num_groups, self.group_size) x_grouped = x.view(-1, num_groups, self.group_size)
absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float() absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) scales_raw = absmax / _FP8_MAX
if self.use_ue8m0:
scales_raw = torch.exp2(torch.ceil(torch.log2(scales_raw)))
scales = (scales_raw).clamp(min=_FP8_MIN_SCALING_FACTOR)
x_scaled = x_grouped / scales x_scaled = x_grouped / scales
x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
@ -151,6 +159,6 @@ class QuantFP8(CustomOp):
scales = scales.reshape(orig_shape[:-1] + (num_groups, )) scales = scales.reshape(orig_shape[:-1] + (num_groups, ))
if self.column_major_scales: if self.column_major_scales:
scales = scales.transpose(-2, -1).contiguous() scales = scales.transpose(-2, -1).contiguous().transpose(-1, -2)
return x_quant, scales return x_quant, scales

View File

@ -13,8 +13,9 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast) GroupShape, group_broadcast)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED) CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.model_executor.parameter import (BlockQuantScaleParameter, from vllm.model_executor.parameter import (BlockQuantScaleParameter,
@ -24,6 +25,7 @@ from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op from vllm.utils import cdiv, direct_register_custom_op
from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear) should_use_deepgemm_for_fp8_linear)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -35,6 +37,8 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
# We need to pass in the is_hopper flag as argument because the function
# current_platform.is_device_capability() is not supported by Torch compiler.
def cutlass_scaled_mm( def cutlass_scaled_mm(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
@ -42,15 +46,17 @@ def cutlass_scaled_mm(
Bs: torch.Tensor, Bs: torch.Tensor,
block_size: list[int], block_size: list[int],
output_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float16,
is_hopper: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if is_hopper is None:
is_hopper = current_platform.is_device_capability(90)
return ops.cutlass_scaled_mm( return ops.cutlass_scaled_mm(
A, A,
B.T, B.T,
out_dtype=output_dtype, out_dtype=output_dtype,
scale_a=As, scale_a=As,
# SM90 block FP8 requires row-major scale_b, which we do ahead of time # SM90 block FP8 requires row-major scale_b, which we do ahead of time
scale_b=Bs if block_size is not None scale_b=Bs if block_size is not None and is_hopper else Bs.T)
and current_platform.is_device_capability(90) else Bs.T)
def rocm_aiter_gemm_w8a8_blockscale_impl( def rocm_aiter_gemm_w8a8_blockscale_impl(
@ -98,122 +104,189 @@ if current_platform.is_rocm():
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128) aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
def dispatch_w8a8_blockscale_func( # TODO we should be able to change the type of block_size to GroupShape
use_cutlass: bool, use_aiter_and_is_supported: bool # after we resolve GroupShape compilation issue
) -> Callable[[ # https://github.com/vllm-project/vllm/issues/25270
torch.Tensor, def _w8a8_triton_block_scaled_mm_func(
torch.Tensor, qx: torch.Tensor,
torch.Tensor, weight: torch.Tensor,
torch.Tensor, x_scale: torch.Tensor,
list[int], weight_scale: torch.Tensor,
torch.dtype, block_size: list[int],
], torch.Tensor]: output_dtype: torch.dtype,
if use_cutlass: ) -> torch.Tensor:
return cutlass_scaled_mm return w8a8_triton_block_scaled_mm(qx, weight, x_scale, weight_scale,
if (use_aiter_and_is_supported): block_size, output_dtype)
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
return w8a8_block_fp8_matmul
def _w8a8_triton_block_scaled_mm_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty((qx.size(0), weight.size(0)),
dtype=output_dtype,
device=qx.device)
direct_register_custom_op(
"w8a8_triton_block_scaled_mm_func",
_w8a8_triton_block_scaled_mm_func,
mutates_args=[],
fake_impl=_w8a8_triton_block_scaled_mm_fake,
dispatch_key="CUDA",
)
# TODO fix ROCm->Triton custom path: # TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397 # https://github.com/vllm-project/vllm/issues/14397
def apply_w8a8_block_fp8_linear( class W8A8BlockFp8LinearOp:
input: torch.Tensor, """
weight: torch.Tensor, This class executes a Blocked FP8 linear layer using cutlass if supported
block_size: list[int], and torch.scaled_mm otherwise.
weight_scale: torch.Tensor, """
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype
if should_use_deepgemm_for_fp8_linear(output_dtype, weight): def __init__(
self,
weight_group_shape: GroupShape,
act_quant_group_shape: GroupShape,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
):
self.weight_group_shape = weight_group_shape
self.act_quant_group_shape = act_quant_group_shape
self.is_deep_gemm_supported = is_deep_gemm_supported()
self.is_hopper = current_platform.is_device_capability(90)
# Get the correct blockscale mul and input quant operations.
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
# to use deepgemm because we don't know the shape of weights (and
# whether deepgemm supports it) at the init time.
self.w8a8_blockscale_op, self.input_quant_op = \
self._dispatch_w8a8_blockscale_op(
cutlass_block_fp8_supported, use_aiter_and_is_supported)
self.deepgemm_input_quant_op = (QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=True,
use_ue8m0=is_deep_gemm_e8m0_used()) if self.is_deep_gemm_supported
else None)
def apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1]) input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]] output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype
q_input, x_scale = per_token_group_quant_fp8( if should_use_deepgemm_for_fp8_linear(output_dtype, weight,
input_2d, self.is_deep_gemm_supported):
block_size[1], output = self._run_deepgemm(input, weight, weight_scale)
column_major_scales=True, if bias is not None:
) output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
output = self.w8a8_blockscale_op(input_2d, weight, weight_scale)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def _run_deepgemm(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
# ensure DeepGEMM-backed custom op is registered before use # ensure DeepGEMM-backed custom op is registered before use
import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 import vllm.model_executor.layers.quantization.deepgemm # noqa: F401
output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( assert self.deepgemm_input_quant_op is not None
q_input, x_scale = self.deepgemm_input_quant_op(input_2d)
return torch.ops.vllm.w8a8_deepgemm_block_scaled_mm(
q_input, q_input,
weight, weight,
x_scale, x_scale,
weight_scale, weight_scale,
block_size, self.weight_group_shape,
output_dtype=output_dtype) output_dtype=input_2d.dtype)
if bias is not None:
output += bias
return output.to(dtype=output_dtype).view(*output_shape)
w8a8_blockscale_func = dispatch_w8a8_blockscale_func( def _run_cutlass(
cutlass_block_fp8_supported, use_aiter_and_is_supported) self,
if cutlass_block_fp8_supported: input_2d: torch.Tensor,
num_pad = 0 weight: torch.Tensor,
if current_platform.is_device_capability(90): weight_scale: torch.Tensor,
# pad first dimension to be divisible by 4 due to ) -> torch.Tensor:
# cutlass blockwise gemm limitation for hopper assert self.input_quant_op is not None
num_pad = 4 - (input_2d.shape[0] % 4) if self.is_hopper:
if num_pad > 0: # We pad unconditionally (even if shape is already divisible by 4)
input_2d = torch.nn.functional.pad(input_2d, # to support dynamic shape for input_2d.shape[0] in torch.compile
(0, 0, 0, num_pad), x = torch.nn.functional.pad(input_2d,
"constant", 0) (0, 0, 0, -input_2d.shape[0] % 4))
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
column_major_scales=True)
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)
if num_pad > 0:
output = output[:-num_pad]
else:
if use_aiter_and_is_supported:
q_input, x_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
else: else:
q_input, x_scale = per_token_group_quant_fp8( x = input_2d
input_2d, block_size[1], column_major_scales=False)
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, q_input, x_scale = self.input_quant_op(x)
block_size, input.dtype) output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale,
list(self.weight_group_shape),
input_2d.dtype, self.is_hopper)
output = output[0:input_2d.shape[0], ...]
return output
if bias is not None: def _run_aiter(
output = output + bias self,
return output.to(dtype=input.dtype).view(*output_shape) input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.act_quant_group_shape == GroupShape(1, 128)
q_input, x_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
q_input, weight, x_scale, weight_scale, self.weight_group_shape,
input_2d.dtype)
def _run_triton(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.input_quant_op is not None
q_input, x_scale = self.input_quant_op(input_2d)
return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
q_input, weight, x_scale, weight_scale, self.weight_group_shape,
input_2d.dtype)
def apply_w8a8_block_fp8_linear_fake( def _dispatch_w8a8_blockscale_op(
input: torch.Tensor, self,
weight: torch.Tensor, use_cutlass: bool,
block_size: list[int], use_aiter_and_is_supported: bool,
weight_scale: torch.Tensor, ) -> tuple[Callable[[
input_scale: Optional[torch.Tensor] = None, torch.Tensor,
bias: Optional[torch.Tensor] = None, torch.Tensor,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, torch.Tensor,
use_aiter_and_is_supported: bool = False, ], torch.Tensor], Optional[QuantFP8]]:
) -> torch.Tensor: if use_cutlass:
output_shape = [*input.shape[:-1], weight.shape[0]] return self._run_cutlass, (QuantFP8(False,
return torch.empty(output_shape, dtype=input.dtype, device=input.device) self.act_quant_group_shape,
column_major_scales=True,
use_ue8m0=False))
if not current_platform.is_cpu(): if use_aiter_and_is_supported:
direct_register_custom_op( return self._run_aiter, None
op_name="apply_w8a8_block_fp8_linear", return self._run_triton, (QuantFP8(False,
op_func=apply_w8a8_block_fp8_linear, self.act_quant_group_shape,
mutates_args=[], column_major_scales=False,
fake_impl=apply_w8a8_block_fp8_linear_fake, use_ue8m0=False))
)
def input_to_float8( def input_to_float8(
@ -465,7 +538,7 @@ def per_token_group_quant_fp8(
@triton.jit @triton.jit
def _w8a8_block_fp8_matmul( def _w8a8_triton_block_scaled_mm(
# Pointers to inputs and output # Pointers to inputs and output
A, A,
B, B,
@ -590,7 +663,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
return None return None
def w8a8_block_fp8_matmul( def w8a8_triton_block_scaled_mm(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, As: torch.Tensor,
@ -650,7 +723,7 @@ def w8a8_block_fp8_matmul(
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]), ) triton.cdiv(N, META["BLOCK_SIZE_N"]), )
_w8a8_block_fp8_matmul[grid]( _w8a8_triton_block_scaled_mm[grid](
A, A,
B, B,
C, C,
@ -997,20 +1070,7 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
layer.weight_scale.data.T.contiguous(), requires_grad=False) layer.weight_scale.data.T.contiguous(), requires_grad=False)
def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor, def expert_weight_is_col_major(x: torch.Tensor) -> bool:
bias: Optional[torch.Tensor], assert x.dim() == 3
cutlass_block_fp8_supported: bool, b, m, n = x.shape
use_aiter_and_is_supported: bool) -> torch.Tensor: return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
"""Apply block-wise FP8 linear operation."""
assert layer.weight_block_size is not None
return torch.ops.vllm.apply_w8a8_block_fp8_linear(
input=input,
weight=layer.weight,
block_size=layer.weight_block_size,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
use_aiter_and_is_supported=use_aiter_and_is_supported,
)

View File

@ -178,10 +178,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
scale_b: torch.Tensor, scale_b: torch.Tensor,
bias: torch.Tensor) -> torch.Tensor: bias: torch.Tensor) -> torch.Tensor:
from vllm.platforms.rocm import on_mi3xx from vllm.platforms.rocm import on_mi3xx
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx() and \
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0 and bias is None: qinput.shape[0] == 1 and \
qinput.shape[1] % 16 == 0 and \
((bias is None) or (bias.dtype == out_dtype)) :
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
current_platform.get_cu_count()) current_platform.get_cu_count(), bias)
else: else:
output = torch._scaled_mm(qinput, output = torch._scaled_mm(qinput,
weight, weight,

View File

@ -100,7 +100,7 @@ def rocm_unquantized_gemm_impl(
k = weight.shape[1] k = weight.shape[1]
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \ use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
x.dtype in [torch.float16, torch.bfloat16] \ x.dtype in [torch.float16, torch.bfloat16] \
and k % 8 == 0 and bias is None) and k % 8 == 0)
if use_skinny is not True: if use_skinny is not True:
return torch.nn.functional.linear(x, weight, bias) return torch.nn.functional.linear(x, weight, bias)
@ -111,9 +111,9 @@ def rocm_unquantized_gemm_impl(
cu_count = current_platform.get_cu_count() cu_count = current_platform.get_cu_count()
if m > 8 and 0 < n <= 4: if m > 8 and 0 < n <= 4:
out = ops.wvSplitK(weight, x_view, cu_count) out = ops.wvSplitK(weight, x_view, cu_count, bias)
return out.view(*x.shape[:-1], weight.shape[0]) return out.view(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192: elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
out = ops.LLMM1(weight, x_view, 4) out = ops.LLMM1(weight, x_view, 4)
return out.view(*x.shape[:-1], weight.shape[0]) return out.view(*x.shape[:-1], weight.shape[0])
return torch.nn.functional.linear(x, weight, bias) return torch.nn.functional.linear(x, weight, bias)

View File

@ -266,24 +266,24 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
if structured_outputs_config.reasoning_parser == "": if structured_outputs_config.reasoning_parser == "":
structured_outputs_config.reasoning_parser = "openai_gptoss" structured_outputs_config.reasoning_parser = "openai_gptoss"
# Increase the max capture size from 512 to 1024 for performance. # Increase the max capture size from 512 to 992 for performance.
# NOTE(woosuk): This will increase the number of CUDA graphs # NOTE(woosuk): This will increase the number of CUDA graphs
# from 67 to 83. # from 67 to 81.
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
if len(scheduler_config.cuda_graph_sizes) == 1: if len(scheduler_config.cuda_graph_sizes) == 1:
max_capture_size = scheduler_config.cuda_graph_sizes[0] max_capture_size = scheduler_config.cuda_graph_sizes[0]
# FIXME(woosuk): When using full cuda graph with FA3, the max # FIXME(woosuk): When using full cuda graph with FA3, the max
# supported size is 992. # supported size is 992.
if max_capture_size < 1024: if max_capture_size < 992:
cuda_graph_sizes = [1, 2, 4] cuda_graph_sizes = [1, 2, 4]
# Step size 8 for small batch sizes # Step size 8 for small batch sizes
cuda_graph_sizes += [i for i in range(8, 256, 8)] cuda_graph_sizes += [i for i in range(8, 256, 8)]
# Step size 16 for larger batch sizes # Step size 16 for larger batch sizes
cuda_graph_sizes += [i for i in range(256, 1025, 16)] cuda_graph_sizes += [i for i in range(256, 993, 16)]
scheduler_config.cuda_graph_sizes = cuda_graph_sizes scheduler_config.cuda_graph_sizes = cuda_graph_sizes
logger.info( logger.info(
"Overriding max cuda graph capture size to " "Overriding max cuda graph capture size to "
"%d for performance.", 1024) "%d for performance.", 992)
class MambaModelConfig(VerifyAndUpdateConfig): class MambaModelConfig(VerifyAndUpdateConfig):

View File

@ -134,6 +134,11 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
nn.Module.__init__(self) nn.Module.__init__(self)
self.config = vllm_config. \ self.config = vllm_config. \
speculative_config.draft_model_config.hf_config speculative_config.draft_model_config.hf_config
# Ensure draft_vocab_size is set
# default to the base vocab size when absent
if getattr(self.config, "draft_vocab_size", None) is None:
base_vocab_size = getattr(self.config, "vocab_size", None)
self.config.draft_vocab_size = base_vocab_size
target_layer_num = vllm_config.model_config.get_num_layers( target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config) vllm_config.parallel_config)
self.model = LlamaModel(vllm_config=vllm_config, self.model = LlamaModel(vllm_config=vllm_config,

View File

@ -203,6 +203,11 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
nn.Module.__init__(self) nn.Module.__init__(self)
self.config = vllm_config. \ self.config = vllm_config. \
speculative_config.draft_model_config.hf_config speculative_config.draft_model_config.hf_config
# Ensure draft_vocab_size is set
# default to the base vocab size when absent
if getattr(self.config, "draft_vocab_size", None) is None:
base_vocab_size = getattr(self.config, "vocab_size", None)
self.config.draft_vocab_size = base_vocab_size
target_layer_num = vllm_config.model_config.get_num_layers( target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config) vllm_config.parallel_config)

View File

@ -53,9 +53,9 @@ def _extract_data_from_fused_moe_module(
""" """
assert isinstance(m, FusedMoE) assert isinstance(m, FusedMoE)
w13 = m.w13_weight w13 = m.w13_weight
w13_s = m.w13_weight_scale_inv w13_s = getattr(m, "w13_weight_scale_inv", m.w13_weight_scale)
w2 = m.w2_weight w2 = m.w2_weight
w2_s = m.w2_weight_scale_inv w2_s = getattr(m, "w2_weight_scale_inv", m.w2_weight_scale)
num_topk = m.top_k num_topk = m.top_k
assert isinstance(w13, torch.Tensor) assert isinstance(w13, torch.Tensor)

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sampling parameters for text generation.""" """Sampling parameters for text generation."""
import copy import copy
import warnings
from dataclasses import field from dataclasses import field
from enum import Enum, IntEnum from enum import Enum, IntEnum
from functools import cached_property from functools import cached_property
@ -59,6 +60,19 @@ class StructuredOutputsParams:
f"but multiple are specified: {self.__dict__}") f"but multiple are specified: {self.__dict__}")
@dataclass
class GuidedDecodingParams(StructuredOutputsParams):
    """Deprecated spelling of ``StructuredOutputsParams``.

    The only thing this subclass overrides is ``__post_init__``, which emits
    a ``DeprecationWarning`` before handing control to the parent hook.
    """

    def __post_init__(self):
        """Emit the deprecation warning, then run the parent's hook."""
        # NOTE(review): stacklevel=2 attributes the warning to whatever
        # frame invokes __post_init__ (presumably the generated __init__
        # rather than user code) -- confirm it points somewhere useful.
        message = ("GuidedDecodingParams is deprecated. This will be removed in "
                   "v0.12.0 or v1.0.0, which ever is soonest. Please use "
                   "StructuredOutputsParams instead.")
        warnings.warn(message, category=DeprecationWarning, stacklevel=2)
        return super().__post_init__()
class RequestOutputKind(Enum): class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput # Return entire output so far in every RequestOutput
CUMULATIVE = 0 CUMULATIVE = 0
@ -179,6 +193,8 @@ class SamplingParams(
# Fields used to construct logits processors # Fields used to construct logits processors
structured_outputs: Optional[StructuredOutputsParams] = None structured_outputs: Optional[StructuredOutputsParams] = None
"""Parameters for configuring structured outputs.""" """Parameters for configuring structured outputs."""
guided_decoding: Optional[GuidedDecodingParams] = None
"""Deprecated alias for structured_outputs."""
logit_bias: Optional[dict[int, float]] = None logit_bias: Optional[dict[int, float]] = None
"""If provided, the engine will construct a logits processor that applies """If provided, the engine will construct a logits processor that applies
these logit biases.""" these logit biases."""
@ -227,6 +243,7 @@ class SamplingParams(
ge=-1)]] = None, ge=-1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
structured_outputs: Optional[StructuredOutputsParams] = None, structured_outputs: Optional[StructuredOutputsParams] = None,
guided_decoding: Optional[GuidedDecodingParams] = None,
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
allowed_token_ids: Optional[list[int]] = None, allowed_token_ids: Optional[list[int]] = None,
extra_args: Optional[dict[str, Any]] = None, extra_args: Optional[dict[str, Any]] = None,
@ -238,6 +255,15 @@ class SamplingParams(
int(token): min(100.0, max(-100.0, bias)) int(token): min(100.0, max(-100.0, bias))
for token, bias in logit_bias.items() for token, bias in logit_bias.items()
} }
if guided_decoding is not None:
warnings.warn(
"guided_decoding is deprecated. This will be removed in "
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
"structured_outputs instead.",
DeprecationWarning,
stacklevel=2)
structured_outputs = guided_decoding
guided_decoding = None
return SamplingParams( return SamplingParams(
n=1 if n is None else n, n=1 if n is None else n,
@ -334,6 +360,16 @@ class SamplingParams(
# eos_token_id is added to this by the engine # eos_token_id is added to this by the engine
self._all_stop_token_ids.update(self.stop_token_ids) self._all_stop_token_ids.update(self.stop_token_ids)
if self.guided_decoding is not None:
warnings.warn(
"guided_decoding is deprecated. This will be removed in "
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
"structured_outputs instead.",
DeprecationWarning,
stacklevel=2)
self.structured_outputs = self.guided_decoding
self.guided_decoding = None
def _verify_args(self) -> None: def _verify_args(self) -> None:
if not isinstance(self.n, int): if not isinstance(self.n, int):
raise ValueError(f"n must be an int, but is of " raise ValueError(f"n must be an int, but is of "

View File

@ -1383,6 +1383,38 @@ def find_nccl_library() -> str:
return so_file return so_file
def find_nccl_include_paths() -> Optional[list[str]]:
    """
    Collect candidate include directories for `nccl.h`.

    Candidates are gathered in this order:
      1. the directory named by the `VLLM_NCCL_INCLUDE_PATH` environment
         setting, when it is set and is an existing directory;
      2. the `include` directory under each search location of the
         `nvidia.nccl` package (shipped by nvidia-nccl-cuXX), when that
         directory contains `nccl.h`.

    Per the original author's note, load_inline by default uses
    torch.utils.cpp_extension.include_paths.

    Returns the candidates with duplicates dropped (first occurrence wins),
    or None when no candidate was found.
    """
    candidates: list[str] = []

    env_dir = envs.VLLM_NCCL_INCLUDE_PATH
    if env_dir and os.path.isdir(env_dir):
        candidates.append(env_dir)

    try:
        import importlib.util
        nccl_spec = importlib.util.find_spec("nvidia.nccl")
        search_dirs = (getattr(nccl_spec, "submodule_search_locations", None)
                       if nccl_spec else None)
        for root in search_dirs or ():
            header_dir = os.path.join(root, "include")
            if os.path.exists(os.path.join(header_dir, "nccl.h")):
                candidates.append(header_dir)
    except Exception:
        # Best-effort probe: any failure while locating the package simply
        # means it contributes no candidates.
        pass

    # dict.fromkeys keeps insertion order, so this de-duplicates while
    # preserving the priority of the candidates.
    return list(dict.fromkeys(candidates)) or None
prev_set_stream = torch.cuda.set_stream prev_set_stream = torch.cuda.set_stream
_current_stream_tls = threading.local() _current_stream_tls = threading.local()

View File

@ -9,7 +9,7 @@ from __future__ import annotations
import functools import functools
import importlib import importlib
import os import os
from typing import Any, Callable, NoReturn from typing import Any, Callable, NoReturn, Optional
import torch import torch
@ -172,9 +172,13 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
return 1 - sim return 1 - sim
def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, def should_use_deepgemm_for_fp8_linear(
weight: torch.Tensor): output_dtype: torch.dtype,
return (is_deep_gemm_supported() and output_dtype == torch.bfloat16 weight: torch.Tensor,
supports_deep_gemm: Optional[bool] = None):
if supports_deep_gemm is None:
supports_deep_gemm = is_deep_gemm_supported()
return (supports_deep_gemm and output_dtype == torch.bfloat16
and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)

View File

@ -8,6 +8,8 @@ import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType) AttentionMetadata, AttentionType)
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash)
from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
@ -291,7 +293,13 @@ class TritonAttentionImpl(AttentionImpl):
if self.kv_sharing_target_layer_name is None: if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer. # Skip this if sharing KV cache with an earlier attention layer.
ops.reshape_and_cache_flash( if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
# triton kernel does not support uint8 kv_cache
# (because some explicit casts (e.g. float8_e4m3fnuz)
# are not supported)
triton_reshape_and_cache_flash(
key, key,
value, value,
key_cache, key_cache,
@ -303,8 +311,9 @@ class TritonAttentionImpl(AttentionImpl):
) )
if self.kv_cache_dtype.startswith("fp8"): if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype) if key_cache.dtype != self.fp8_dtype:
value_cache = value_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
num_tokens, num_heads, head_size = query.shape num_tokens, num_heads, head_size = query.shape
assert layer._q_scale_float == 1.0, \ assert layer._q_scale_float == 1.0, \
"A non 1.0 q_scale is not currently supported." "A non 1.0 q_scale is not currently supported."

View File

@ -107,19 +107,57 @@ def _make_metadata_with_slice(
the requests included in ubatch_slice the requests included in ubatch_slice
""" """
assert not ubatch_slice.is_empty(), (
f"Ubatch slice {ubatch_slice} is empty")
request_slice = ubatch_slice.request_slice request_slice = ubatch_slice.request_slice
token_slice = ubatch_slice.token_slice token_slice = ubatch_slice.token_slice
start_locs = attn_metadata.query_start_loc_cpu
first_req = request_slice.start
first_tok = token_slice.start
last_req = request_slice.stop - 1
last_tok = token_slice.stop - 1
assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], \
"Token slice start outside of first request"
assert start_locs[last_req] <= last_tok < start_locs[last_req+1], \
"Token slice end outside of last request"
# If the "middle" request has tokens in both ubatches, we have to split it.
# If ubatch_slice is the first ubatch then we will be splitting the last
# request. If it's the second microbatch, then we will be splitting the
# first request
splits_first_request = first_tok > start_locs[first_req]
splits_last_request = last_tok < start_locs[last_req + 1] - 1
query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc, query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc,
request_slice) request_slice)
assert len(query_start_loc) >= 2, ( assert len(query_start_loc) >= 2, (
f"query_start_loc must have at least 2 elements, " f"query_start_loc must have at least 2 elements, "
f"got {len(query_start_loc)}") f"got {len(query_start_loc)}")
query_start_loc_cpu = slice_query_start_locs(
attn_metadata.query_start_loc_cpu, request_slice)
if splits_first_request:
tokens_skipped = first_tok - start_locs[first_req]
query_start_loc[1:] -= tokens_skipped
query_start_loc_cpu[1:] -= tokens_skipped
seq_lens = attn_metadata.seq_lens[request_slice] seq_lens = attn_metadata.seq_lens[request_slice]
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
if splits_last_request:
tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop
query_start_loc[-1] -= tokens_skipped
query_start_loc_cpu[-1] -= tokens_skipped
# Make sure we don't modify the seq_lens tensors
# (not cudagraph compatible)
seq_lens = seq_lens.clone()
seq_lens_cpu = seq_lens_cpu.clone()
seq_lens[-1] -= tokens_skipped
seq_lens_cpu[-1] -= tokens_skipped
max_seq_len = int(seq_lens_cpu.max()) max_seq_len = int(seq_lens_cpu.max())
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[ num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
request_slice] request_slice]
@ -167,6 +205,7 @@ def split_attn_metadata(
for ubatch_slice in ubatch_slices: for ubatch_slice in ubatch_slices:
results.append( results.append(
_make_metadata_with_slice(ubatch_slice, common_attn_metadata)) _make_metadata_with_slice(ubatch_slice, common_attn_metadata))
return results return results
@ -696,7 +735,6 @@ def split_decodes_and_prefills(
return num_reqs, 0, num_tokens, 0 return num_reqs, 0, num_tokens, 0
first_prefill = is_prefill.int().argmax(dim=-1).item() first_prefill = is_prefill.int().argmax(dim=-1).item()
assert torch.all(query_lens[first_prefill:] > decode_threshold)
assert torch.all(query_lens[:first_prefill] <= decode_threshold) assert torch.all(query_lens[:first_prefill] <= decode_threshold)
num_decodes = first_prefill num_decodes = first_prefill
num_prefills = num_reqs - num_decodes num_prefills = num_reqs - num_decodes

View File

@ -9,6 +9,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadataBuilder
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig, from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config)
@ -30,7 +31,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__) logger = init_logger(__name__)
@ -78,6 +78,8 @@ class EagleProposer:
self.is_multimodal_model = vllm_config.model_config \ self.is_multimodal_model = vllm_config.model_config \
.is_multimodal_model .is_multimodal_model
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
self.use_cuda_graph = (self.vllm_config.compilation_config.level self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and == CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager) not self.vllm_config.model_config.enforce_eager)
@ -118,7 +120,7 @@ class EagleProposer:
with_numpy=True) with_numpy=True)
# Determine allowed attention backends once during initialization. # Determine allowed attention backends once during initialization.
self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] self.allowed_attn_types: tuple[type, ...]
if current_platform.is_rocm(): if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
@ -191,11 +193,12 @@ class EagleProposer:
assert self.runner is not None assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups # Select the correct attention metadata builders for EAGLE layers.
ubatch_id = dbo_current_ubatch_id() # Get the attention metadata builders once and reuse for later.
attn_metadata_builder = \ builder = (self._get_attention_metadata_builder()
self.runner.attn_groups[0][0].metadata_builders[ubatch_id] if self.attn_metadata_builder is None else
attn_metadata = attn_metadata_builder.build_for_drafting( self.attn_metadata_builder)
attn_metadata = builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0) common_attn_metadata=common_attn_metadata, draft_index=0)
# At this moment, we assume all eagle layers belong to the same KV # At this moment, we assume all eagle layers belong to the same KV
@ -329,11 +332,9 @@ class EagleProposer:
exceeds_max_model_len, PADDING_SLOT_ID) exceeds_max_model_len, PADDING_SLOT_ID)
# Rebuild attention metadata # Rebuild attention metadata
attn_metadata_builder = \ attn_metadata = builder.build_for_drafting(
self.runner.attn_groups[0][0].metadata_builders[ubatch_id] common_attn_metadata=common_attn_metadata,
attn_metadata = attn_metadata_builder\ draft_index=token_index + 1)
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1)
for layer_name in self.attn_layer_names: for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata per_layer_attn_metadata[layer_name] = attn_metadata
@ -538,9 +539,8 @@ class EagleProposer:
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
ubatch_id = dbo_current_ubatch_id()
tree_attn_metadata_builder = \ tree_attn_metadata_builder = \
self.runner.attn_groups[0][0].metadata_builders[ubatch_id] self.runner.attn_groups[0][0].get_metadata_builder()
assert isinstance(tree_attn_metadata_builder, assert isinstance(tree_attn_metadata_builder,
TreeAttentionMetadataBuilder) TreeAttentionMetadataBuilder)
@ -854,10 +854,24 @@ class EagleProposer:
# share lm_head with the target model if needed # share lm_head with the target model if needed
# some model definition do not define lm_head explicitly # some model definition do not define lm_head explicitly
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
if self.vllm_config.speculative_config.method != "eagle3" and \ if self.vllm_config.speculative_config.method != "eagle3":
hasattr(target_language_model, "lm_head"): if hasattr(target_language_model, "lm_head"):
logger.info("Loading EAGLE LM head weights from the target model.") logger.info(
self.model.lm_head = target_language_model.lm_head "Loading EAGLE LM head weights from the target model.")
self.model.lm_head = target_language_model.lm_head
else:
if (hasattr(self.model, "lm_head")
and hasattr(target_language_model, "lm_head")
and self.model.lm_head.weight.shape
== target_language_model.lm_head.weight.shape):
logger.info("Assuming the EAGLE head shares the same lm_head"
" with the target model.")
del self.model.lm_head
self.model.lm_head = target_language_model.lm_head
else:
logger.info(
"The EAGLE head's lm_head will be loaded separately"
" from the target model.")
@torch.inference_mode() @torch.inference_mode()
def dummy_run( def dummy_run(
@ -880,6 +894,31 @@ class EagleProposer:
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
    """Find and return the attention metadata builder for EAGLE layers.

    The first entry of ``self.attn_layer_names`` is used as the probe
    layer. ``self.runner.attn_groups`` is scanned in order, and the builder
    of the first attention group whose ``layer_names`` contains the probe
    layer is returned.

    Returns:
        The single metadata builder obtained from the matching attention
        group's ``get_metadata_builder()`` (not a list of builders).

    Raises:
        AssertionError: If no metadata builder is found for EAGLE layers.
    """
    builder = None
    chosen_layer = self.attn_layer_names[0]

    for kv_cache_group in self.runner.attn_groups:
        for attn_group in kv_cache_group:
            if chosen_layer in attn_group.layer_names:
                builder = attn_group.get_metadata_builder()
                break
        # Stop scanning further KV-cache groups once a builder was found.
        if builder is not None:
            break

    assert builder is not None, (
        "Failed to find attention metadata builder for EAGLE layers.")
    return builder
def validate_same_kv_cache_group(self, def validate_same_kv_cache_group(self,
kv_cache_config: KVCacheConfig) -> None: kv_cache_config: KVCacheConfig) -> None:
""" """

View File

@ -96,7 +96,8 @@ from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import ( from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin) KVConnectorModelRunnerMixin)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds,
ubatch_split)
from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.utils import is_residual_scattered_for_sp
@ -1032,7 +1033,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens_padded = num_tokens_unpadded + self.get_local_padding( num_tokens_padded = num_tokens_unpadded + self.get_local_padding(
num_tokens_unpadded) num_tokens_unpadded)
ubatch_slices, num_tokens_after_padding = \ ubatch_slices, num_tokens_after_padding = \
ubatch_split(max_num_scheduled_tokens, ubatch_split(num_scheduled_tokens,
num_tokens_unpadded, num_tokens_unpadded,
num_tokens_padded, num_tokens_padded,
self.vllm_config) self.vllm_config)
@ -1176,9 +1177,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
encoder_seq_lens=encoder_seq_lens, encoder_seq_lens=encoder_seq_lens,
) )
if self.speculative_config and \ if (self.speculative_config
spec_decode_common_attn_metadata is None: and spec_decode_common_attn_metadata is None):
spec_decode_common_attn_metadata = common_attn_metadata if isinstance(self.drafter, EagleProposer):
if (self.drafter.attn_layer_names[0]
in kv_cache_group_spec.layer_names):
spec_decode_common_attn_metadata = common_attn_metadata
else:
spec_decode_common_attn_metadata = common_attn_metadata
for attn_group in self.attn_groups[kv_cache_group_id]: for attn_group in self.attn_groups[kv_cache_group_id]:
# Prepare for cascade attention if enabled & beneficial. # Prepare for cascade attention if enabled & beneficial.
@ -1206,7 +1212,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
ubatch_slices, common_attn_metadata) ubatch_slices, common_attn_metadata)
for ubid, common_attn_metadata in enumerate( for ubid, common_attn_metadata in enumerate(
common_attn_metadata_list): common_attn_metadata_list):
assert common_attn_metadata.max_query_len == 1
attn_metadata_i = (attn_group.get_metadata_builder( attn_metadata_i = (attn_group.get_metadata_builder(
ubatch_id=ubid).build( ubatch_id=ubid).build(
common_prefix_len=common_prefix_len, common_prefix_len=common_prefix_len,
@ -2182,9 +2187,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) = self._preprocess(scheduler_output, intermediate_tensors, ) = self._preprocess(scheduler_output, intermediate_tensors,
ubatch_slices, num_tokens_after_padding) ubatch_slices, num_tokens_after_padding)
if ubatch_slices is not None:
num_input_tokens = num_input_tokens // 2
uniform_decode = (max_query_len uniform_decode = (max_query_len
== self.uniform_decode_query_len) and ( == self.uniform_decode_query_len) and (
num_scheduled_tokens num_scheduled_tokens
@ -2194,6 +2196,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cudagraph_runtime_mode, batch_descriptor = \ cudagraph_runtime_mode, batch_descriptor = \
self.cudagraph_dispatcher.dispatch(batch_descriptor) self.cudagraph_dispatcher.dispatch(batch_descriptor)
# This is currently to get around the assert in the DPMetadata
# where it wants `num_tokens_across_dp` to align with `num_tokens`
if ubatch_slices is not None:
num_input_tokens = ubatch_slices[0].num_tokens
# Run the model. # Run the model.
# Use persistent buffers for CUDA graphs. # Use persistent buffers for CUDA graphs.
with (set_forward_context( with (set_forward_context(
@ -2360,7 +2367,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor,
aux_hidden_states: Optional[torch.Tensor], aux_hidden_states: Optional[list[torch.Tensor]],
spec_decode_metadata: Optional[SpecDecodeMetadata], spec_decode_metadata: Optional[SpecDecodeMetadata],
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
) -> Union[list[list[int]], torch.Tensor]: ) -> Union[list[list[int]], torch.Tensor]:
@ -2380,6 +2387,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else: else:
indices = [] indices = []
offset = 0 offset = 0
assert spec_decode_metadata is not None
for num_draft, tokens in zip( for num_draft, tokens in zip(
spec_decode_metadata.num_draft_tokens, spec_decode_metadata.num_draft_tokens,
sampled_token_ids): sampled_token_ids):
@ -2430,6 +2438,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# TODO(woosuk): Support M-RoPE. # TODO(woosuk): Support M-RoPE.
target_positions = self.positions.gpu[:num_scheduled_tokens] target_positions = self.positions.gpu[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
assert aux_hidden_states is not None
target_hidden_states = torch.cat( target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states], [h[:num_scheduled_tokens] for h in aux_hidden_states],
dim=-1) dim=-1)
@ -2455,6 +2464,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# TODO(woosuk): Support M-RoPE. # TODO(woosuk): Support M-RoPE.
target_positions = self.positions.gpu[token_indices] target_positions = self.positions.gpu[token_indices]
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
assert aux_hidden_states is not None
target_hidden_states = torch.cat( target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1) [h[token_indices] for h in aux_hidden_states], dim=-1)
else: else:
@ -2821,7 +2831,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
force_attention: bool = False, force_attention: bool = False,
uniform_decode: bool = False, uniform_decode: bool = False,
allow_microbatching: bool = False, allow_microbatching: bool = True,
skip_eplb: bool = False, skip_eplb: bool = False,
is_profile: bool = False, is_profile: bool = False,
create_mixed_batch: bool = False, create_mixed_batch: bool = False,
@ -2847,32 +2857,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
(1 token) and prefill (multiple tokens) requests. (1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run remove_lora: If False, dummy LoRAs are not destroyed after the run
""" """
ubatch_enabled = self.parallel_config.enable_dbo
num_tokens_across_dp = None
num_pad = 0
should_ubatch = False
if ubatch_enabled:
should_ubatch = num_tokens >= \
self.parallel_config.dbo_decode_token_threshold and \
allow_microbatching
(should_ubatch, num_tokens_across_dp) = get_dp_padding_ubatch(
num_tokens, num_tokens, should_ubatch, self.vllm_config)
# Currently the dummy run should only be ubatching during
# cuda graph capture, meaning all DP ranks should already
# have the same batch size
if num_tokens_across_dp is not None:
assert int(num_tokens_across_dp[0]) == num_tokens // 2
assert cudagraph_runtime_mode in { assert cudagraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
} }
if not should_ubatch:
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
# If cudagraph_mode.decode_mode() == FULL and # If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using # cudagraph_mode.separate_routine(). This means that we are using
# different graphs and/or modes for mixed prefill-decode batches vs. # different graphs and/or modes for mixed prefill-decode batches vs.
@ -2888,10 +2876,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# for GQA/MQA. # for GQA/MQA.
max_query_len = self.uniform_decode_query_len if uniform_decode else \ max_query_len = self.uniform_decode_query_len if uniform_decode else \
num_tokens num_tokens
if allow_microbatching:
assert self.uniform_decode_query_len == 1
assert uniform_decode is True
assert max_query_len == 1
# Set num_scheduled_tokens based on num_tokens and max_num_seqs # Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively # for dummy run with LoRA so that the num_reqs collectively
@ -2916,7 +2900,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert not create_mixed_batch assert not create_mixed_batch
num_reqs = cdiv(num_tokens, max_query_len) num_reqs = cdiv(num_tokens, max_query_len)
assert num_reqs <= max_num_reqs, \ assert num_reqs <= max_num_reqs, \
"Do not capture num_reqs > max_num_reqs for uniform batch" f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
f"{max_num_reqs} for uniform batch. Num tokens: " \
f"{num_tokens}, max_query_len: {max_query_len}"
num_scheduled_tokens_list = [max_query_len] * num_reqs num_scheduled_tokens_list = [max_query_len] * num_reqs
if num_tokens % max_query_len != 0: if num_tokens % max_query_len != 0:
num_scheduled_tokens_list[-1] = num_tokens % max_query_len num_scheduled_tokens_list[-1] = num_tokens % max_query_len
@ -2930,20 +2916,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert len(num_scheduled_tokens_list) == num_reqs assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list, num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32) dtype=np.int32)
total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
ubatch_slices = None ubatch_slices = None
num_tokens_after_padding = None
# We currently only microbatch if the number of tokens is # We currently only microbatch if the number of tokens is
# over a certain threshold. # over a certain threshold.
if should_ubatch: if self.parallel_config.enable_dbo and allow_microbatching:
# We only support decode-only cudagraphs ubatch_slices, num_tokens_after_padding = ubatch_split(
assert num_reqs == num_tokens num_scheduled_tokens,
assert num_tokens % 2 == 0 total_num_scheduled_tokens,
ubatch_slices = [ total_num_scheduled_tokens,
UBatchSlice(slice(0, num_reqs // 2), slice(0, self.vllm_config,
num_tokens // 2)), )
UBatchSlice(slice(num_reqs // 2, num_reqs),
slice(num_tokens // 2, num_tokens)) # If we failed to microbatch, currently need to resynchronize
] # TODO(lucas,sage): we should be able to avoid this second sync by
# refactoring `get_dp_padding_ubatch` and `get_dp_padding` into
# a single `coordinate_batch_across_dp` function.
if num_tokens_after_padding is None:
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens_after_padding = num_tokens + num_pad
else:
num_tokens_across_dp = num_tokens_after_padding
num_tokens_after_padding = int(num_tokens_after_padding[0].item())
attn_metadata: Optional[PerLayerAttnMetadata] = None attn_metadata: Optional[PerLayerAttnMetadata] = None
@ -2960,12 +2957,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# TODO(luka) better system for describing dummy batches # TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
else: else:
# Make sure max_model_len is used at the graph capture time. seq_lens = max_query_len
seq_lens = self.max_model_len
self.seq_lens.np[:num_reqs] = seq_lens self.seq_lens.np[:num_reqs] = seq_lens
self.seq_lens.np[num_reqs:] = 0 self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu() self.seq_lens.copy_to_gpu()
cum_num_tokens, _ = self._get_cumsum_and_arange(
num_scheduled_tokens)
self.query_start_loc.np[1:num_reqs + 1] = cum_num_tokens
self.query_start_loc.copy_to_gpu()
for kv_cache_group_id, kv_cache_group_spec in enumerate( for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups): self.kv_cache_config.kv_cache_groups):
common_attn_metadata = CommonAttentionMetadata( common_attn_metadata = CommonAttentionMetadata(
@ -3060,7 +3061,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with self.maybe_randomize_inputs(input_ids), set_forward_context( with self.maybe_randomize_inputs(input_ids), set_forward_context(
attn_metadata, attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_tokens, num_tokens=num_tokens_after_padding,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
@ -3395,56 +3396,51 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
desc="Capturing CUDA graphs ({}, {})".format( desc="Capturing CUDA graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode", "decode" if uniform_decode else "mixed prefill-decode",
cudagraph_runtime_mode.name)) cudagraph_runtime_mode.name))
enable_dbo = self.parallel_config.enable_dbo
# DBO Only supports running Full cudagraphs with uniform
# decode lengths
if enable_dbo and uniform_decode:
for num_tokens in compilation_cases:
# If the number of tokens is greater than the microbatching
# threshold, don't generate a microbatched cudagraph
if (num_tokens
< self.parallel_config.dbo_decode_token_threshold):
continue
# Warmup # We skip EPLB here since we don't want to record dummy metrics
for num_tokens in compilation_cases:
# We currently only capture ubatched graphs when it's a FULL
# cudagraph and for uniform decode batches.
capture_ubatched_graph = self.parallel_config.enable_dbo \
and cudagraph_runtime_mode == CUDAGraphMode.FULL \
and uniform_decode \
and check_ubatch_thresholds(
config=self.vllm_config.parallel_config,
num_tokens=num_tokens,
uniform_decode=uniform_decode,
)
# Currently we capture both microbatched and non-microbatched
# graphs when capture_ubatched_graph is True, this is because
# occasionally we will be forced out of microbatching due to other
# DP ranks not microbatching (usually caused by an empty second
# microbatch; once we resolve this, we can remove the
# non-microbatched graph capture).
allow_microbatching_options = [True, False] if \
capture_ubatched_graph else [False]
for allow_microbatching in allow_microbatching_options:
for _ in range( for _ in range(
self.compilation_config.cudagraph_num_of_warmups): self.compilation_config.cudagraph_num_of_warmups):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE`is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
force_attention = ( force_attention = (
cudagraph_runtime_mode == CUDAGraphMode.FULL) cudagraph_runtime_mode == CUDAGraphMode.FULL)
self._dummy_run(num_tokens, self._dummy_run(num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention, force_attention=force_attention,
uniform_decode=True, uniform_decode=uniform_decode,
allow_microbatching=True, allow_microbatching=allow_microbatching,
skip_eplb=True) skip_eplb=True,
remove_lora=False)
# Graph Capture
self._dummy_run(num_tokens, self._dummy_run(num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.FULL, cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=True,
allow_microbatching=True,
skip_eplb=True)
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens in compilation_cases:
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE`is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
force_attention = (
cudagraph_runtime_mode == CUDAGraphMode.FULL)
self._dummy_run(num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode, uniform_decode=uniform_decode,
allow_microbatching=allow_microbatching,
skip_eplb=True, skip_eplb=True,
remove_lora=False) remove_lora=False)
self._dummy_run(num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=uniform_decode,
skip_eplb=True,
remove_lora=False)
self.maybe_remove_all_loras(self.lora_config) self.maybe_remove_all_loras(self.lora_config)
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
@ -3500,24 +3496,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_groups: list[AttentionGroup] = [] attn_groups: list[AttentionGroup] = []
for (attn_backend, for (attn_backend,
kv_cache_spec), layer_names in attn_backends_map.items(): kv_cache_spec), layer_names in attn_backends_map.items():
attn_metadata_builders = [] attn_group = AttentionGroup.create_with_metadata_builders(
attn_metadata_builders.append(attn_backend.get_builder_cls()( attn_backend,
kv_cache_spec,
layer_names, layer_names,
kv_cache_spec,
self.vllm_config, self.vllm_config,
self.device, self.device,
)) num_metadata_builders=1
if self.parallel_config.enable_dbo: if not self.parallel_config.enable_dbo else 2,
attn_metadata_builders.append( )
attn_backend.get_builder_cls()(
kv_cache_spec,
layer_names,
self.vllm_config,
self.device,
))
attn_group = AttentionGroup(attn_backend,
attn_metadata_builders,
layer_names, kv_cache_spec)
attn_groups.append(attn_group) attn_groups.append(attn_group)
return attn_groups return attn_groups
@ -3562,6 +3550,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
CUDAGraphMode.FULL_DECODE_ONLY CUDAGraphMode.FULL_DECODE_ONLY
logger.warning(msg) logger.warning(msg)
# check that if we are doing decode full-cudagraphs it is supported
if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and min_cg_support == AttentionCGSupport.NEVER):
msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_builder_name} backend (support: "
f"{min_cg_support})")
if (self.compilation_config.level == CompilationLevel.PIECEWISE and
(self.compilation_config.splitting_ops_contain_attention()
or self.compilation_config.use_inductor_graph_partition)):
msg += "; setting cudagraph_mode=PIECEWISE because "\
"attention is compiled piecewise"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
msg += "; setting cudagraph_mode=NONE because "\
"attention is not compiled piecewise"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.NONE
logger.warning(msg)
# check that if we are doing spec-decode + decode full-cudagraphs it is # check that if we are doing spec-decode + decode full-cudagraphs it is
# supported # supported
if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL

View File

@ -1,25 +1,30 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import threading import threading
from dataclasses import dataclass
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
import torch import torch
import vllm.envs as envs
from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.pynccl_allocator import (
set_graph_pool_id)
from vllm.forward_context import (create_forward_context, get_forward_context, from vllm.forward_context import (create_forward_context, get_forward_context,
override_forward_context) override_forward_context)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import has_deep_gemm
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclasses.dataclass @dataclass
class UbatchMetadata: class UbatchMetadata:
context: UBatchContext context: UBatchContext
input_ids: torch.Tensor input_ids: torch.Tensor
@ -29,13 +34,55 @@ class UbatchMetadata:
num_tokens: int num_tokens: int
@dataclasses.dataclass @dataclass
class CUDAGraphMetaData: class CUDAGraphMetaData:
cudagraph: torch.cuda.CUDAGraph cudagraph: torch.cuda.CUDAGraph
ubatch_metadata: UbatchMetadata ubatch_metadata: UbatchMetadata
outputs: Optional[Any] = None outputs: Optional[Any] = None
class SMControlContextManager:
    """Context manager that splits SMs between communication and compute.

    On entry, ``comm_sms`` SMs are assigned to communication and the
    remaining ``total_sms - comm_sms`` to computation. On exit, both
    setters are called with ``total_sms`` so that every SM is available
    again.
    """

    def __init__(self, comm_sms: int, set_comm_sms: Callable[[int], None],
                 set_compute_sms: Callable[[int], None]):
        """
        Args:
            comm_sms (int): The number of SMs to allocate for communication.
                (The remainder will be used for computation.)
            set_comm_sms (Callable[[int], None]):
                A function that sets the number of SMs for communication.
            set_compute_sms (Callable[[int], None]):
                A function that sets the number of SMs for computation.
        """
        assert current_platform.is_cuda(), \
            "SM control is currently only supported on CUDA"

        # The SM count is read from the currently selected CUDA device.
        device_index = torch.cuda.current_device()
        sm_count = torch.cuda.get_device_properties(
            device_index).multi_processor_count

        # At least one SM must be left over for computation.
        assert comm_sms < sm_count

        self.set_comm_sms = set_comm_sms
        self.set_compute_sms = set_compute_sms
        self.comm_sms = comm_sms
        self.compute_sms = sm_count - comm_sms
        self.total_sms = sm_count

    def __enter__(self):
        # Apply the comm/compute partition (communication first).
        for apply_limit, sm_budget in (
            (self.set_comm_sms, self.comm_sms),
            (self.set_compute_sms, self.compute_sms),
        ):
            apply_limit(sm_budget)

    def __exit__(self, exc_type, exc_value, traceback):
        # Hand every SM back to both sides, whether or not an exception
        # occurred inside the context.
        for apply_limit in (self.set_comm_sms, self.set_compute_sms):
            apply_limit(self.total_sms)
class UBatchWrapper: class UBatchWrapper:
def __init__(self, runnable: Callable, vllm_config: VllmConfig, def __init__(self, runnable: Callable, vllm_config: VllmConfig,
@ -56,6 +103,35 @@ class UBatchWrapper:
runnable, vllm_config, runtime_mode=runtime_mode) runnable, vllm_config, runtime_mode=runtime_mode)
self.graph_pool = current_platform.get_global_graph_pool() self.graph_pool = current_platform.get_global_graph_pool()
self.sm_control = self._create_sm_control_context(vllm_config)
@staticmethod
def _create_sm_control_context(vllm_config: VllmConfig):
    """Build the SMControlContextManager for this wrapper.

    The communication SM budget starts at ``envs.VLLM_DBO_COMM_SMS``. With
    expert parallelism enabled it is additionally capped by the all2all
    manager's ``max_sms_used()`` (when that is not None). A setter stays a
    no-op unless the budget is positive and a backend that can honour it
    is available.
    """

    def _ignore(sms):
        # Placeholder setter used when no backend can be SM-limited.
        return None

    comm_sm_setter = _ignore
    compute_sm_setter = _ignore
    comm_sms = envs.VLLM_DBO_COMM_SMS

    if vllm_config.parallel_config.enable_expert_parallel:
        # Currently only DeepEP high-throughput supports SM control, so
        # this only affects that case.
        all2all_manager = get_ep_group(
        ).device_communicator.all2all_manager

        if all2all_manager.max_sms_used() is not None:
            comm_sms = min(comm_sms, all2all_manager.max_sms_used())

        if comm_sms > 0:

            def _limit_comm(sms):
                return all2all_manager.set_num_sms(sms)

            comm_sm_setter = _limit_comm

    # TODO(lucas): support other kernels besides DeepGEMM
    if has_deep_gemm() and comm_sms > 0:
        import deep_gemm as dg

        def _limit_compute(sms):
            return dg.set_num_sms(sms)

        compute_sm_setter = _limit_compute

    return SMControlContextManager(comm_sms=comm_sms,
                                   set_comm_sms=comm_sm_setter,
                                   set_compute_sms=compute_sm_setter)
def __getattr__(self, key: str): def __getattr__(self, key: str):
# allow accessing the attributes of the runnable. # allow accessing the attributes of the runnable.
if hasattr(self.runnable, key): if hasattr(self.runnable, key):
@ -132,6 +208,10 @@ class UBatchWrapper:
cudagraph=torch.cuda.CUDAGraph(), cudagraph=torch.cuda.CUDAGraph(),
ubatch_metadata=ubatch_metadata, ubatch_metadata=ubatch_metadata,
) )
if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool)
else:
set_graph_pool_id(current_platform.graph_pool_handle())
with torch.cuda.graph(cudagraph_metadata.cudagraph, with torch.cuda.graph(cudagraph_metadata.cudagraph,
stream=compute_stream, stream=compute_stream,
pool=self.graph_pool): pool=self.graph_pool):
@ -282,8 +362,8 @@ class UBatchWrapper:
dp_metadata=dp_metadata, dp_metadata=dp_metadata,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE) cudagraph_runtime_mode=CUDAGraphMode.NONE)
with self.sm_control:
return self._capture_ubatches(ubatch_metadata, self.model) return self._capture_ubatches(ubatch_metadata, self.model)
elif num_tokens in self.cudagraphs: elif num_tokens in self.cudagraphs:
cudagraph_metadata = self.cudagraphs[num_tokens] cudagraph_metadata = self.cudagraphs[num_tokens]
cudagraph_metadata.cudagraph.replay() cudagraph_metadata.cudagraph.replay()
@ -300,4 +380,5 @@ class UBatchWrapper:
dp_metadata=dp_metadata, dp_metadata=dp_metadata,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE) cudagraph_runtime_mode=CUDAGraphMode.NONE)
return self._run_ubatches(ubatch_metadata, self.model) with self.sm_control:
return self._run_ubatches(ubatch_metadata, self.model)

View File

@ -386,11 +386,13 @@ class Worker(WorkerBase):
f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} " f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
f"GiB for CUDAGraph memory. Replace gpu_memory_utilization " f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
f"config with `--kv-cache-memory=" f"config with `--kv-cache-memory="
f"{kv_cache_memory_bytes_to_requested_limit}` to fit into " f"{kv_cache_memory_bytes_to_requested_limit}` "
f"requested memory, or `--kv-cache-memory=" f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
f"{kv_cache_memory_bytes_to_gpu_limit}` to fully " f"into requested memory, or `--kv-cache-memory="
f"{kv_cache_memory_bytes_to_gpu_limit}` "
f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
f"utilize gpu memory. Current kv cache memory in use is " f"utilize gpu memory. Current kv cache memory in use is "
f"{int(self.available_kv_cache_memory_bytes)} bytes.") f"{GiB(self.available_kv_cache_memory_bytes)} GiB.")
logger.debug(msg) logger.debug(msg)

View File

@ -3,9 +3,10 @@
from typing import Optional from typing import Optional
import numpy as np
import torch import torch
from vllm.config import VllmConfig from vllm.config import ParallelConfig, VllmConfig
from vllm.forward_context import DPMetadata from vllm.forward_context import DPMetadata
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import round_up from vllm.utils import round_up
@ -29,6 +30,16 @@ def should_ubatch_with_num_tokens(
dp_size, dp_rank) dp_size, dp_rank)
def check_ubatch_thresholds(config: ParallelConfig, num_tokens: int,
                            uniform_decode: bool) -> bool:
    """Return whether a batch of ``num_tokens`` meets the DBO threshold.

    Always False when ``config.enable_dbo`` is off. Otherwise the token
    count is compared against ``config.dbo_decode_token_threshold`` for
    uniform-decode batches and ``config.dbo_prefill_token_threshold`` for
    all other batches; reaching the threshold exactly counts as a pass.
    """
    if not config.enable_dbo:
        return False

    # Only the threshold that applies to this batch kind is read.
    threshold = (config.dbo_decode_token_threshold
                 if uniform_decode else config.dbo_prefill_token_threshold)
    return num_tokens >= threshold
def get_dp_padding_ubatch( def get_dp_padding_ubatch(
num_tokens_unpadded: int, num_tokens_padded: int, num_tokens_unpadded: int, num_tokens_padded: int,
should_attempt_ubatching: bool, should_attempt_ubatching: bool,
@ -95,9 +106,37 @@ def get_dp_padding_ubatch(
dtype=torch.int32) dtype=torch.int32)
return should_ubatch, num_tokens_after_padding return should_ubatch, num_tokens_after_padding
def create_ubatch_slices(num_scheduled_tokens: np.ndarray, split_point: int) \
        -> UBatchSlices:
    """Split a batch into two microbatches at a token offset.

    Args:
        num_scheduled_tokens: Number of scheduled tokens per request; the
            cumulative sum of this array gives each request's token offsets.
        split_point: Token index at which the second microbatch starts.

    Returns:
        Two ``UBatchSlice`` entries. The first covers tokens
        ``[0, split_point)`` and the second covers
        ``[split_point, total_tokens)``. A request whose tokens straddle
        ``split_point`` is included in both request slices.
    """
    # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
    # in cu_num_tokens directly (i.e. query_start_loc)
    cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
    np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])

    # Cast to a Python int so the slice bounds (and anything derived from
    # them, e.g. UBatchSlice.num_tokens) are plain ints rather than NumPy
    # scalars, matching the request-slice bounds below.
    total_num_tokens = int(cu_num_tokens[-1])
    num_reqs = len(cu_num_tokens) - 1

    first_ubatch_token_slice = slice(0, split_point)
    second_ubatch_token_slice = slice(split_point, total_num_tokens)

    # Determine request slices using exclusive stop semantics
    # First ubatch includes requests whose tokens overlap [0, split_point)
    first_ubatch_req_stop = int(
        np.searchsorted(cu_num_tokens, split_point, side="left"))
    first_ubatch_req_slice = slice(0, first_ubatch_req_stop)

    # Second ubatch starts at the request that contains the split_point
    # or the request starting exactly at split_point (if on boundary)
    second_ubatch_req_start = int(
        np.searchsorted(cu_num_tokens, split_point, side="right") - 1)
    second_ubatch_req_slice = slice(second_ubatch_req_start, num_reqs)

    return [
        UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
        UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice)
    ]
def ubatch_split( def ubatch_split(
max_num_scheduled_tokens: int, num_scheduled_tokens_per_request: np.ndarray,
num_tokens_unpadded: int, num_tokens_unpadded: int,
num_tokens_padded: int, num_tokens_padded: int,
vllm_config: VllmConfig, vllm_config: VllmConfig,
@ -122,17 +161,20 @@ def ubatch_split(
return (None, None) return (None, None)
# Check preconditions for microbatching # Check preconditions for microbatching
should_attempt_ubatching = \ should_attempt_ubatching = check_ubatch_thresholds(
parallel_config.enable_dbo and \ parallel_config,
num_tokens_unpadded >= \ num_tokens_unpadded,
parallel_config.dbo_decode_token_threshold \ vllm_config,
and max_num_scheduled_tokens == 1 )
# Don't microbatch unless every other DP worker is also microbatching # Don't microbatch unless every other DP worker is also microbatching
num_tokens_after_padding = None should_ubatch, num_tokens_after_padding = get_dp_padding_ubatch(
(should_ubatch, num_tokens_after_padding) = get_dp_padding_ubatch( num_tokens_unpadded,
num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching, num_tokens_padded,
vllm_config) should_attempt_ubatching,
vllm_config,
)
if not should_ubatch: if not should_ubatch:
return (None, None) return (None, None)
@ -141,15 +183,9 @@ def ubatch_split(
# to the second ubatch in pad_out_ubatch_slice after attention # to the second ubatch in pad_out_ubatch_slice after attention
# metadata creation # metadata creation
assert num_tokens_after_padding is not None assert num_tokens_after_padding is not None
total_num_tokens_per_ubatch = int(num_tokens_after_padding[0].item()) token_split_point = int(num_tokens_after_padding[0].item())
padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
num_tokens_unpadded)
# Note there's an assumption here that there's 1 token per request ubatch_slices = create_ubatch_slices(num_scheduled_tokens_per_request,
ubatch_slices = [ token_split_point)
UBatchSlice(padded_first_ubatch_slice, padded_first_ubatch_slice),
UBatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice)
]
return (ubatch_slices, num_tokens_after_padding) return (ubatch_slices, num_tokens_after_padding)

View File

@ -10,6 +10,14 @@ class UBatchSlice:
request_slice: slice request_slice: slice
token_slice: slice token_slice: slice
def is_empty(self) -> bool:
    """Return True when either slice starts and stops at the same index."""
    requests, tokens = self.request_slice, self.token_slice
    return requests.start == requests.stop or tokens.start == tokens.stop
@property
def num_tokens(self) -> int:
    """Width of the token slice, i.e. ``stop - start``."""
    tokens = self.token_slice
    return tokens.stop - tokens.start
UBatchSlices: TypeAlias = list[UBatchSlice] UBatchSlices: TypeAlias = list[UBatchSlice]

View File

@ -51,8 +51,8 @@ class UBatchContext:
self.cpu_wait_event.wait() self.cpu_wait_event.wait()
self.cpu_wait_event.clear() self.cpu_wait_event.clear()
self._restore_context() self._restore_context()
# Assume we start on the compute stream # Assume we want to start on the compute stream
assert current_stream() == self.compute_stream self.update_stream(self.compute_stream)
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
@ -62,17 +62,15 @@ class UBatchContext:
self.maybe_run_recv_hook() self.maybe_run_recv_hook()
self.cpu_signal_event.set() self.cpu_signal_event.set()
self.cpu_wait_event.clear() self.cpu_wait_event.clear()
self.current_stream = self.compute_stream
torch.cuda.set_stream(self.current_stream)
return False return False
def _restore_context(self): def _restore_context(self):
forward_context._forward_context = self.forward_context forward_context._forward_context = self.forward_context
torch.cuda.set_stream(self.current_stream)
def update_stream(self, stream): def update_stream(self, stream):
self.current_stream = stream self.current_stream = stream
torch.cuda.set_stream(self.current_stream) if current_stream() != self.current_stream:
torch.cuda.set_stream(self.current_stream)
def _signal_comm_done(self): def _signal_comm_done(self):
self.gpu_comm_done_event.record(self.comm_stream) self.gpu_comm_done_event.record(self.comm_stream)
@ -99,9 +97,20 @@ class UBatchContext:
self.cpu_wait_event.clear() self.cpu_wait_event.clear()
self._restore_context() self._restore_context()
def switch_to_comm(self):
    """Switch this context to its communication stream.

    Only delegates to ``update_stream``; no events are signalled or
    waited on here.
    """
    target_stream = self.comm_stream
    self.update_stream(target_stream)
def switch_to_compute(self):
    """Switch this context to its compute stream.

    Only delegates to ``update_stream``; no events are signalled or
    waited on here.
    """
    target_stream = self.compute_stream
    self.update_stream(target_stream)
def switch_to_comm_sync(self): def switch_to_comm_sync(self):
self._signal_compute_done() self._signal_compute_done()
self.update_stream(self.comm_stream) self.update_stream(self.comm_stream)
self._wait_compute_done()
def switch_to_compute_sync(self):
self._signal_comm_done()
self.update_stream(self.compute_stream)
self._wait_comm_done() self._wait_comm_done()
def maybe_run_recv_hook(self): def maybe_run_recv_hook(self):
@ -112,8 +121,7 @@ class UBatchContext:
def yield_(self): def yield_(self):
self.current_stream = current_stream() self.current_stream = current_stream()
self._cpu_yield() self._cpu_yield()
if self.current_stream != current_stream(): self.update_stream(self.current_stream)
self.update_stream(self.current_stream)
def yield_and_switch_from_compute_to_comm(self): def yield_and_switch_from_compute_to_comm(self):
assert current_stream() == self.compute_stream assert current_stream() == self.compute_stream
@ -153,15 +161,20 @@ def _register_ubatch_function(func):
return wrapper return wrapper
dbo_maybe_run_recv_hook = _register_ubatch_function(
UBatchContext.maybe_run_recv_hook)
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function( dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(
UBatchContext.yield_and_switch_from_compute_to_comm) UBatchContext.yield_and_switch_from_compute_to_comm)
dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function( dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(
UBatchContext.yield_and_switch_from_comm_to_compute) UBatchContext.yield_and_switch_from_comm_to_compute)
dbo_yield = _register_ubatch_function(UBatchContext.yield_) dbo_switch_to_comm = _register_ubatch_function(UBatchContext.switch_to_comm)
dbo_maybe_run_recv_hook = _register_ubatch_function( dbo_switch_to_compute = _register_ubatch_function(
UBatchContext.maybe_run_recv_hook) UBatchContext.switch_to_compute)
dbo_switch_to_comm_sync = _register_ubatch_function( dbo_switch_to_comm_sync = _register_ubatch_function(
UBatchContext.switch_to_comm_sync) UBatchContext.switch_to_comm_sync)
dbo_switch_to_compute_sync = _register_ubatch_function(
UBatchContext.switch_to_compute_sync)
def dbo_register_recv_hook(recv_hook): def dbo_register_recv_hook(recv_hook):

View File

@ -130,15 +130,32 @@ class MultiModalBudget:
@dataclass @dataclass
class AttentionGroup: class AttentionGroup:
backend: type[AttentionBackend] backend: type[AttentionBackend]
# When ubatching is enabled we will have a metadata builder for each ubatch,
# so that if the builders use internal persistent buffers for cudagraphs they
# won't have to worry about conflicting with the other ubatches.
metadata_builders: list[AttentionMetadataBuilder] metadata_builders: list[AttentionMetadataBuilder]
layer_names: list[str] layer_names: list[str]
kv_cache_spec: KVCacheSpec kv_cache_spec: KVCacheSpec
@staticmethod
def create_with_metadata_builders(
    backend: type[AttentionBackend],
    layer_names: list[str],
    kv_cache_spec: KVCacheSpec,
    vllm_config: VllmConfig,
    device: torch.device,
    num_metadata_builders: int = 1,
) -> 'AttentionGroup':
    """Create an AttentionGroup together with its metadata builders.

    ``num_metadata_builders`` separate builder instances are constructed
    from the class returned by ``backend.get_builder_cls()``, each with
    the same ``(kv_cache_spec, layer_names, vllm_config, device)``
    arguments.
    """
    builders = []
    for _ in range(num_metadata_builders):
        # Look the builder class up on every iteration, as each builder
        # is constructed independently.
        builder_cls = backend.get_builder_cls()
        builders.append(
            builder_cls(kv_cache_spec, layer_names, vllm_config, device))
    return AttentionGroup(backend, builders, layer_names, kv_cache_spec)
def get_metadata_builder(self, def get_metadata_builder(self,
ubatch_id: Optional[int] = None ubatch_id: int = 0) -> AttentionMetadataBuilder:
) -> AttentionMetadataBuilder:
if ubatch_id is None:
return self.metadata_builders[0]
assert len(self.metadata_builders) > ubatch_id assert len(self.metadata_builders) > ubatch_id
return self.metadata_builders[ubatch_id] return self.metadata_builders[ubatch_id]