mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 21:37:11 +08:00
Merge branch 'main' into woosuk/model-runner-v2
This commit is contained in:
commit
704def253c
@ -164,6 +164,7 @@ steps:
|
||||
- tests/v1/test_internal_lb_dp.py
|
||||
- tests/v1/test_hybrid_lb_dp.py
|
||||
- tests/v1/engine/test_engine_core_client.py
|
||||
- tests/distributed/test_symm_mem_allreduce.py
|
||||
commands:
|
||||
# test with torchrun tp=2 and external_dp=2
|
||||
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
@ -188,6 +189,7 @@ steps:
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
- pytest -v -s distributed/test_pynccl.py
|
||||
- pytest -v -s distributed/test_events.py
|
||||
- pytest -v -s distributed/test_symm_mem_allreduce.py
|
||||
# TODO: create a dedicated test section for multi-GPU example tests
|
||||
# when we have multiple distributed example tests
|
||||
- pushd ../examples/offline_inference
|
||||
@ -329,6 +331,8 @@ steps:
|
||||
- python3 offline_inference/basic/classify.py
|
||||
- python3 offline_inference/basic/embed.py
|
||||
- python3 offline_inference/basic/score.py
|
||||
- python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048
|
||||
- python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048
|
||||
|
||||
- label: Platform Tests (CUDA) # 4min
|
||||
timeout_in_minutes: 15
|
||||
@ -1037,3 +1041,4 @@ steps:
|
||||
num_gpus: 2
|
||||
commands:
|
||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||
- pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py
|
||||
|
||||
@ -103,10 +103,15 @@ start_server() {
|
||||
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 \
|
||||
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
|
||||
fi
|
||||
local server_pid=$!
|
||||
|
||||
# wait for 10 minutes...
|
||||
server_started=0
|
||||
for i in {1..60}; do
|
||||
# This line checks whether the server is still alive or not,
|
||||
# since that we should always have permission to send signal to the server process.
|
||||
kill -0 $server_pid 2> /dev/null || break
|
||||
|
||||
RESPONSE=$(curl -s -X GET "http://0.0.0.0:8004/health" -w "%{http_code}" -o /dev/stdout)
|
||||
STATUS_CODE=$(echo "$RESPONSE" | tail -n 1)
|
||||
if [[ "$STATUS_CODE" -eq 200 ]]; then
|
||||
@ -118,7 +123,7 @@ start_server() {
|
||||
done
|
||||
|
||||
if (( ! server_started )); then
|
||||
echo "server did not start within 10 minutes. Please check server log at $vllm_log".
|
||||
echo "server did not start within 10 minutes or crashed. Please check server log at $vllm_log".
|
||||
return 1
|
||||
else
|
||||
return 0
|
||||
|
||||
@ -17,7 +17,7 @@ from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
w8a8_block_fp8_matmul,
|
||||
w8a8_triton_block_scaled_mm,
|
||||
)
|
||||
from vllm.utils import FlexibleArgumentParser, cdiv
|
||||
|
||||
@ -158,7 +158,7 @@ def bench_fp8(
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
|
||||
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
|
||||
),
|
||||
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
|
||||
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm(
|
||||
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
|
||||
),
|
||||
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
|
||||
|
||||
@ -7,6 +7,10 @@ Benchmark script for device communicators:
|
||||
CustomAllreduce (oneshot, twoshot), PyNcclCommunicator,
|
||||
and SymmMemCommunicator (multimem, two-shot).
|
||||
|
||||
for NCCL symmetric memory you need to set the environment variables
|
||||
NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1, otherwise NCCL does
|
||||
not use fast NVLS implementation for all reduce.
|
||||
|
||||
Usage:
|
||||
torchrun --nproc_per_node=<N> benchmark_device_communicators.py [options]
|
||||
|
||||
@ -26,7 +30,13 @@ import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.device_communicators.pynccl import (
|
||||
PyNcclCommunicator,
|
||||
register_nccl_symmetric_ops,
|
||||
)
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
set_graph_pool_id,
|
||||
)
|
||||
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
@ -98,6 +108,7 @@ class CommunicatorBenchmark:
|
||||
)
|
||||
if not self.pynccl_comm.disabled:
|
||||
logger.info("Rank %s: PyNcclCommunicator initialized", self.rank)
|
||||
register_nccl_symmetric_ops(self.pynccl_comm)
|
||||
else:
|
||||
logger.info("Rank %s: PyNcclCommunicator disabled", self.rank)
|
||||
self.pynccl_comm = None
|
||||
@ -194,6 +205,15 @@ class CommunicatorBenchmark:
|
||||
None, # no env variable needed
|
||||
)
|
||||
)
|
||||
communicators.append(
|
||||
(
|
||||
"pynccl-symm",
|
||||
lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
|
||||
lambda t: True, # Always available if initialized
|
||||
nullcontext(),
|
||||
None, # no env variable needed
|
||||
)
|
||||
)
|
||||
|
||||
if self.symm_mem_comm_multimem is not None:
|
||||
comm = self.symm_mem_comm_multimem
|
||||
@ -271,7 +291,9 @@ class CommunicatorBenchmark:
|
||||
# Capture the graph using context manager
|
||||
with context:
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
graph_pool = torch.cuda.graph_pool_handle()
|
||||
set_graph_pool_id(graph_pool)
|
||||
with torch.cuda.graph(graph, pool=graph_pool):
|
||||
for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
|
||||
allreduce_fn(graph_input)
|
||||
|
||||
|
||||
@ -9,6 +9,9 @@ import torch
|
||||
from tabulate import tabulate
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import (
|
||||
@ -31,6 +34,8 @@ def run_benchmark(
|
||||
kv_cache_dtype: str,
|
||||
kv_cache_layout: str,
|
||||
num_iters: int,
|
||||
implementation: str,
|
||||
benchmark_mode: str,
|
||||
device: str = "cuda",
|
||||
) -> float:
|
||||
"""Return latency (seconds) for given num_tokens."""
|
||||
@ -38,6 +43,14 @@ def run_benchmark(
|
||||
if kv_cache_dtype == "fp8" and head_size % 16:
|
||||
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
|
||||
|
||||
if implementation not in ("cuda", "triton"):
|
||||
raise ValueError(
|
||||
f"Unsupported implementation: {implementation}. "
|
||||
"Only 'cuda' and 'triton' are supported."
|
||||
)
|
||||
if implementation == "triton" and kv_cache_layout == "HND":
|
||||
return float("nan") # Triton does not support HND layout yet.
|
||||
|
||||
current_platform.seed_everything(42)
|
||||
torch.set_default_device(device)
|
||||
|
||||
@ -65,27 +78,49 @@ def run_benchmark(
|
||||
cache_layout=kv_cache_layout,
|
||||
)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
# to free unused memory
|
||||
del key_caches, value_caches
|
||||
|
||||
# compute per-kernel scaling factors for fp8 conversion (if used).
|
||||
k_scale = (key.amax() / 64.0).to(torch.float32)
|
||||
v_scale = (value.amax() / 64.0).to(torch.float32)
|
||||
|
||||
if implementation == "cuda":
|
||||
function_under_test = lambda: ops.reshape_and_cache_flash(
|
||||
key, # noqa: F821
|
||||
value, # noqa: F821
|
||||
key_cache, # noqa: F821
|
||||
value_cache, # noqa: F821
|
||||
slot_mapping, # noqa: F821
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
else:
|
||||
function_under_test = lambda: triton_reshape_and_cache_flash(
|
||||
key, # noqa: F821
|
||||
value, # noqa: F821
|
||||
key_cache, # noqa: F821
|
||||
value_cache, # noqa: F821
|
||||
slot_mapping, # noqa: F821
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
if benchmark_mode == "cudagraph":
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
function_under_test()
|
||||
torch.cuda.synchronize()
|
||||
function_under_test = lambda: g.replay()
|
||||
|
||||
def run_cuda_benchmark(n_iters: int) -> float:
|
||||
nonlocal key, value, key_cache, value_cache, slot_mapping
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
for _ in range(n_iters):
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
function_under_test()
|
||||
torch.cuda.synchronize()
|
||||
end = time.perf_counter()
|
||||
return (end - start) / n_iters
|
||||
|
||||
@ -116,10 +151,16 @@ def main(args):
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
kv_cache_layout=layout,
|
||||
num_iters=args.iters,
|
||||
implementation=args.implementation,
|
||||
benchmark_mode=args.mode,
|
||||
device="cuda",
|
||||
)
|
||||
rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])
|
||||
|
||||
print(
|
||||
f"Benchmark results for implementation {args.implementation}"
|
||||
f" (measuring with {args.mode}):"
|
||||
)
|
||||
print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))
|
||||
|
||||
|
||||
@ -151,6 +192,21 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
parser.add_argument("--iters", type=int, default=100)
|
||||
|
||||
parser.add_argument(
|
||||
"--implementation",
|
||||
type=str,
|
||||
choices=["cuda", "triton"],
|
||||
default="cuda",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
choices=["cudagraph", "no_graph"],
|
||||
default="cudagraph",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
||||
@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
get_col_major_tma_aligned_tensor,
|
||||
per_token_group_quant_fp8,
|
||||
w8a8_block_fp8_matmul,
|
||||
w8a8_triton_block_scaled_mm,
|
||||
)
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8
|
||||
@ -59,7 +59,7 @@ def benchmark_shape(m: int,
|
||||
|
||||
# === vLLM Triton Implementation ===
|
||||
def vllm_triton_gemm():
|
||||
return w8a8_block_fp8_matmul(A_vllm,
|
||||
return w8a8_triton_block_scaled_mm(A_vllm,
|
||||
B_vllm,
|
||||
A_scale_vllm,
|
||||
B_scale_vllm,
|
||||
|
||||
@ -418,6 +418,15 @@ __device__ inline T neg_inf() {
|
||||
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline bool is_finite(const T val) {
|
||||
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
|
||||
return cuda::std::isfinite(val);
|
||||
#else
|
||||
return isfinite(cuda_cast<float, T>(val));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void topk_with_k2(T* output, T const* input,
|
||||
cg::thread_block_tile<32> const& tile,
|
||||
@ -533,7 +542,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
// calculate group_idx
|
||||
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
|
||||
// The check is necessary to avoid abnormal input
|
||||
if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) {
|
||||
if (lane_id < n_group && is_finite(group_scores[lane_id])) {
|
||||
value = group_scores[lane_id];
|
||||
}
|
||||
|
||||
@ -568,11 +577,10 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
int32_t offset = i_group * num_experts_per_group;
|
||||
for (int32_t i = lane_id; i < align_num_experts_per_group;
|
||||
i += WARP_SIZE) {
|
||||
T candidates =
|
||||
(i < num_experts_per_group) &&
|
||||
cuda::std::isfinite(scores_with_bias[offset + i])
|
||||
? scores_with_bias[offset + i]
|
||||
: neg_inf<T>();
|
||||
T candidates = (i < num_experts_per_group) &&
|
||||
is_finite(scores_with_bias[offset + i])
|
||||
? scores_with_bias[offset + i]
|
||||
: neg_inf<T>();
|
||||
queue.add(candidates, offset + i);
|
||||
}
|
||||
if (group_scores[i_group] == topk_group_value) {
|
||||
|
||||
@ -12,8 +12,8 @@
|
||||
#include "../vectorization_utils.cuh"
|
||||
#include "../../dispatch_utils.h"
|
||||
|
||||
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
|
||||
unsigned mask = 0xffff;
|
||||
__device__ __forceinline__ float GroupReduceMax(float val) {
|
||||
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
|
||||
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
|
||||
@ -86,7 +86,7 @@ __global__ void per_token_group_quant_8bit_kernel(
|
||||
threads_per_group, // stride in group
|
||||
scalar_op_cache); // scalar handler
|
||||
|
||||
local_absmax = GroupReduceMax(local_absmax, lane_id);
|
||||
local_absmax = GroupReduceMax(local_absmax);
|
||||
|
||||
float y_s = local_absmax / max_8bit;
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
|
||||
@ -5,11 +5,14 @@
|
||||
torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
|
||||
const int64_t rows_per_block);
|
||||
|
||||
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
|
||||
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const c10::optional<at::Tensor>& in_bias,
|
||||
const int64_t CuCount);
|
||||
|
||||
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
|
||||
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
|
||||
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const c10::optional<at::Tensor>& in_bias, at::Tensor& out_c,
|
||||
const at::Tensor& scale_a, const at::Tensor& scale_b,
|
||||
const int64_t CuCount);
|
||||
|
||||
void paged_attention(
|
||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||
|
||||
@ -292,8 +292,9 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
wvSplitK_hf_sml_(const int K, const int M, const int Bx, const int By,
|
||||
const scalar_t* B, const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
constexpr int max_lds_len = LDS_SIZE / 2;
|
||||
#if defined(__HIP__MI3XX__)
|
||||
@ -484,7 +485,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (threadIdx.x == 63) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int i = 0; i < YTILE; i++) {
|
||||
// if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
if (BIAS)
|
||||
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS)
|
||||
sum[n][i] +=
|
||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
}
|
||||
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
|
||||
}
|
||||
}
|
||||
@ -529,7 +537,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (threadIdx.x == 63) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int i = 0; i < YTILE; i++) {
|
||||
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
|
||||
if (BIAS)
|
||||
sum4[n][i][0] +=
|
||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
|
||||
}
|
||||
}
|
||||
@ -541,8 +551,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
__global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
|
||||
const int By, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
@ -553,8 +565,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitK_hf_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
wvSplitK_hf_(const int K, const int M, const int Bx, const int By,
|
||||
const scalar_t* B, const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
constexpr int max_lds_len = LDS_SIZE / 2;
|
||||
#if defined(__HIP__MI3XX__)
|
||||
@ -772,8 +785,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (threadIdx.x == 63) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int i = 0; i < YTILE; i++) {
|
||||
if (commitColumn[i])
|
||||
if (commitColumn[i]) {
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
if (BIAS)
|
||||
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS)
|
||||
sum[n][i] +=
|
||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
}
|
||||
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -818,8 +840,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (threadIdx.x == 63) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int i = 0; i < YTILE; i++) {
|
||||
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
|
||||
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
|
||||
if (commitColumn[i]) {
|
||||
if (BIAS)
|
||||
sum4[n][i][0] +=
|
||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -842,8 +868,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
__global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
|
||||
const int By, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
@ -854,8 +882,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
wvSplitK_hf_big_(const int K, const int M, const int Bx, const int By,
|
||||
const scalar_t* B, const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
constexpr int max_lds_len = LDS_SIZE / 2;
|
||||
#if defined(__HIP__MI3XX__)
|
||||
@ -1124,8 +1153,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (threadIdx.x == 63) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int i = 0; i < YTILE; i++) {
|
||||
if (commitColumn[i])
|
||||
if (commitColumn[i]) {
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
if (BIAS)
|
||||
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS)
|
||||
sum[n][i] +=
|
||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
}
|
||||
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1166,8 +1204,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (threadIdx.x == 63) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int i = 0; i < YTILE; i++) {
|
||||
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
|
||||
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
|
||||
if (commitColumn[i]) {
|
||||
if (BIAS)
|
||||
sum4[n][i][0] +=
|
||||
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
|
||||
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1190,8 +1232,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N>
|
||||
__global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A, scalar_t* C,
|
||||
__global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
|
||||
const int By, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
@ -1226,11 +1270,20 @@ int mindiv(int N, int div1, int div2) {
|
||||
return rtn;
|
||||
}
|
||||
|
||||
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
|
||||
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const c10::optional<at::Tensor>& in_bias,
|
||||
const int64_t CuCount) {
|
||||
auto M_in = in_a.size(0);
|
||||
auto K_in = in_a.size(1);
|
||||
auto N_in = in_b.size(0);
|
||||
auto Bx_in =
|
||||
(in_bias.has_value() && in_bias->numel() > 0)
|
||||
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
|
||||
: 1;
|
||||
auto By_in = (in_bias.has_value() && in_bias->numel() > 0 &&
|
||||
in_bias->sizes().size() == 2)
|
||||
? in_bias->size(0)
|
||||
: 1;
|
||||
|
||||
TORCH_CHECK(in_a.dtype() == in_b.dtype());
|
||||
TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0");
|
||||
@ -1254,18 +1307,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
|
||||
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
|
||||
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
|
||||
CuCount); \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, c, __wvPrGrp, CuCount); \
|
||||
} else if (K_in * N_in <= max_lds_len * 1.2) { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
|
||||
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
|
||||
CuCount); \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, c, __wvPrGrp, CuCount); \
|
||||
} else { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
|
||||
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
|
||||
CuCount); \
|
||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, c, __wvPrGrp, CuCount); \
|
||||
} \
|
||||
}
|
||||
|
||||
@ -1273,6 +1326,10 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
|
||||
using fptype = typename scalar<scalar_t>::type;
|
||||
fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr());
|
||||
const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
|
||||
const fptype* biasf4 =
|
||||
(in_bias.has_value() && in_bias->numel() > 0)
|
||||
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
|
||||
: nullptr;
|
||||
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
|
||||
switch (N_in) {
|
||||
case 1:
|
||||
@ -1300,8 +1357,9 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A, scalar_t* C,
|
||||
wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int Bx,
|
||||
const int By, const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const float* __restrict__ s_A,
|
||||
const float* __restrict__ s_B, const int _WvPrGrp,
|
||||
const int CuCount) {
|
||||
@ -1453,7 +1511,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (threadIdx.x == 0) {
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0] * sA * sB);
|
||||
if (y + m >= M) break; // To avoid mem access fault.
|
||||
sum[n][y][0] *= sA * sB;
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] +=
|
||||
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
}
|
||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]); // * sA * sB);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1465,7 +1533,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
|
||||
const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
const int Bx, const int By, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS,
|
||||
scalar_t* C, const float* __restrict__ s_A,
|
||||
const float* __restrict__ s_B,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
@ -1477,8 +1547,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitKQ_hf_(const int K, const int Kp, const int M, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A, scalar_t* C,
|
||||
wvSplitKQ_hf_(const int K, const int Kp, const int M, const int Bx,
|
||||
const int By, const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const float* __restrict__ s_A, const float* __restrict__ s_B,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
constexpr int max_lds_len = LDS_SIZE;
|
||||
@ -1626,7 +1697,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
if (y + m >= M) break; // To avoid mem access fault.
|
||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0] * sA * sB);
|
||||
sum[n][y][0] *= sA * sB;
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] +=
|
||||
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
}
|
||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1638,16 +1718,19 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
|
||||
const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
scalar_t* C, const float* __restrict__ s_A,
|
||||
const int Bx, const int By, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const float* __restrict__ s_A,
|
||||
const float* __restrict__ s_B, const int _WvPrGrp,
|
||||
const int CuCount) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
|
||||
|
||||
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
|
||||
at::Tensor& scale_a, at::Tensor& scale_b,
|
||||
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const c10::optional<at::Tensor>& in_bias, at::Tensor& out_c,
|
||||
const at::Tensor& scale_a, const at::Tensor& scale_b,
|
||||
const int64_t CuCount) {
|
||||
static c10::ScalarType kFp8Type = is_fp8_ocp()
|
||||
? c10::ScalarType::Float8_e4m3fn
|
||||
@ -1656,6 +1739,15 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
|
||||
auto K_in = in_a.size(1);
|
||||
auto N_in = in_b.size(0);
|
||||
auto Kp_in = in_a.stride(0);
|
||||
auto Bx_in =
|
||||
(in_bias.has_value() && in_bias->numel() > 0)
|
||||
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
|
||||
: 1;
|
||||
auto By_in = (in_bias.has_value() && in_bias->numel() > 0 &&
|
||||
in_bias->sizes().size() == 2)
|
||||
? in_bias->size(0)
|
||||
: 1;
|
||||
|
||||
TORCH_CHECK(K_in % 16 == 0, "k % 16 == 0");
|
||||
TORCH_CHECK(in_a.dtype() == in_b.dtype() && in_a.dtype() == kFp8Type);
|
||||
TORCH_CHECK(out_c.dtype() == torch::kFloat16 ||
|
||||
@ -1673,13 +1765,15 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
|
||||
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
|
||||
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
|
||||
s_a, s_b, __wvPrGrp, CuCount); \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
|
||||
b_ptr, bias_ptr, c_ptr, s_a, s_b, \
|
||||
__wvPrGrp, CuCount); \
|
||||
} else { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
|
||||
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
|
||||
s_a, s_b, __wvPrGrp, CuCount); \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
|
||||
b_ptr, bias_ptr, c_ptr, s_a, s_b, \
|
||||
__wvPrGrp, CuCount); \
|
||||
} \
|
||||
}
|
||||
|
||||
@ -1691,6 +1785,9 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
|
||||
VLLM_DISPATCH_FP8_TYPES(in_a.scalar_type(), "wvSplitKQ", [&] {
|
||||
auto a_ptr = in_a.data_ptr<fp8_t>();
|
||||
auto b_ptr = in_b.data_ptr<fp8_t>();
|
||||
auto bias_ptr = (in_bias.has_value() && in_bias->numel() > 0)
|
||||
? reinterpret_cast<fptype*>(in_bias->data_ptr())
|
||||
: nullptr;
|
||||
switch (N_in) {
|
||||
case 1:
|
||||
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1)
|
||||
|
||||
@ -22,13 +22,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
|
||||
|
||||
// Custom gemm op for skinny matrix-matrix multiplication
|
||||
rocm_ops.def(
|
||||
"wvSplitK(Tensor in_a, Tensor in_b, int CuCount) -> "
|
||||
"wvSplitK(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount) -> "
|
||||
"Tensor");
|
||||
rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK);
|
||||
|
||||
// wvSplitK for fp8
|
||||
rocm_ops.def(
|
||||
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor! out_c, Tensor scale_a, "
|
||||
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, "
|
||||
"Tensor scale_a, "
|
||||
" Tensor scale_b, int CuCount) -> ()");
|
||||
rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ);
|
||||
|
||||
|
||||
@ -49,6 +49,7 @@ def get_custom_mm_prompts(num_prompts):
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
add_dataset_parser(parser)
|
||||
parser.add_argument("--test", action="store_true")
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
@ -60,6 +61,7 @@ def parse_args():
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--enforce-eager", action="store_true")
|
||||
parser.add_argument("--enable-chunked-prefill", action="store_true")
|
||||
parser.add_argument("--max-model-len", type=int, default=16384)
|
||||
parser.add_argument("--temp", type=float, default=0)
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--top-k", type=int, default=-1)
|
||||
@ -71,8 +73,7 @@ def parse_args():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
def main(args):
|
||||
args.endpoint_type = "openai-chat"
|
||||
|
||||
model_dir = args.model_dir
|
||||
@ -134,7 +135,7 @@ def main():
|
||||
gpu_memory_utilization=0.8,
|
||||
speculative_config=speculative_config,
|
||||
disable_log_stats=False,
|
||||
max_model_len=16384,
|
||||
max_model_len=args.max_model_len,
|
||||
limit_mm_per_prompt={"image": 5},
|
||||
disable_chunked_mm_input=True,
|
||||
)
|
||||
@ -198,6 +199,39 @@ def main():
|
||||
acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0
|
||||
print(f"acceptance at token {i}: {acceptance_rate:.2f}")
|
||||
|
||||
return acceptance_length
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
args = parse_args()
|
||||
acceptance_length = main(args)
|
||||
|
||||
if args.test:
|
||||
# takes ~30s to run on 1xH100
|
||||
assert args.method in ["eagle", "eagle3"]
|
||||
assert args.tp == 1
|
||||
assert args.num_spec_tokens == 3
|
||||
assert args.dataset_name == "hf"
|
||||
assert args.dataset_path == "philschmid/mt-bench"
|
||||
assert args.num_prompts == 80
|
||||
assert args.temp == 0
|
||||
assert args.top_p == 1.0
|
||||
assert args.top_k == -1
|
||||
assert args.enable_chunked_prefill
|
||||
|
||||
# check acceptance length is within 2% of expected value
|
||||
rtol = 0.02
|
||||
expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811
|
||||
|
||||
assert (
|
||||
acceptance_length <= (1 + rtol) * expected_acceptance_length
|
||||
and acceptance_length >= (1 - rtol) * expected_acceptance_length
|
||||
), (
|
||||
f"acceptance_length {acceptance_length} is not "
|
||||
f"within {rtol * 100}% of {expected_acceptance_length}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Test passed! Expected AL: "
|
||||
f"{expected_acceptance_length}, got {acceptance_length}"
|
||||
)
|
||||
|
||||
94
tests/distributed/test_nccl_symm_mem_allreduce.py
Normal file
94
tests/distributed/test_nccl_symm_mem_allreduce.py
Normal file
@ -0,0 +1,94 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
import typing
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.distributed.device_communicators.cuda_communicator import (
|
||||
CudaCommunicator)
|
||||
from vllm.distributed.device_communicators.pynccl import (
|
||||
register_nccl_symmetric_ops)
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
get_nccl_mem_pool, is_symmetric_memory_enabled)
|
||||
from vllm.distributed.parallel_state import (get_tp_group,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
torch.manual_seed(42)
|
||||
random.seed(44)
|
||||
|
||||
test_size_elements = 4 * 1024 * 1024
|
||||
|
||||
|
||||
def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
|
||||
monkeypatch = pytest.MonkeyPatch()
|
||||
with monkeypatch.context() as m:
|
||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
update_environment_variables({
|
||||
"RANK": str(local_rank),
|
||||
"LOCAL_RANK": str(local_rank),
|
||||
"WORLD_SIZE": str(world_size),
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
})
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
cuda_communicator = typing.cast(CudaCommunicator,
|
||||
get_tp_group().device_communicator)
|
||||
pynccl_comm = cuda_communicator.pynccl_comm
|
||||
if get_nccl_mem_pool() is None:
|
||||
pytest.skip("NCCL allocator compilation failed "
|
||||
"(probably missing NCCL headers).")
|
||||
if not is_symmetric_memory_enabled():
|
||||
pytest.skip("NCCL symmetric memory allreduce is disabled.")
|
||||
|
||||
register_nccl_symmetric_ops(pynccl_comm)
|
||||
input = torch.randint(1,
|
||||
23, (test_size_elements, ),
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
input_clone = input.clone()
|
||||
output = torch.ops.vllm.all_reduce_symmetric_with_copy(input)
|
||||
assert output is not None
|
||||
|
||||
group = get_tp_group().device_group
|
||||
dist.all_reduce(input_clone, group=group)
|
||||
torch.testing.assert_close(output, input_clone, atol=2.5, rtol=0.1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(),
|
||||
reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
|
||||
)
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||
reason="Only test on CUDA")
|
||||
def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
|
||||
if world_size > torch.cuda.device_count():
|
||||
pytest.skip("Not enough GPUs to run the test.")
|
||||
|
||||
# Enable SymmMemCommunicator
|
||||
monkeypatch.setenv("VLLM_USE_NCCL_SYMM_MEM", "1")
|
||||
monkeypatch.setenv("NCCL_NVLS_ENABLE", "1")
|
||||
monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1")
|
||||
|
||||
mp.spawn(nccl_symm_mem_allreduce_worker,
|
||||
args=(world_size, ),
|
||||
nprocs=world_size)
|
||||
cleanup_dist_env_and_memory()
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import queue
|
||||
import random
|
||||
import typing
|
||||
|
||||
@ -10,26 +11,31 @@ import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.device_communicators.cuda_communicator import (
|
||||
CudaCommunicator)
|
||||
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
|
||||
get_tp_group,
|
||||
from vllm.distributed.parallel_state import (get_tp_group,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
torch.manual_seed(42)
|
||||
random.seed(44)
|
||||
|
||||
test_size_elements = 4 * 1024 * 1024
|
||||
test_size_elements = 1024 * 1024
|
||||
|
||||
|
||||
def symm_mem_allreduce_worker(local_rank: int, world_size: int):
|
||||
def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
|
||||
monkeypatch = pytest.MonkeyPatch()
|
||||
with monkeypatch.context() as m:
|
||||
config = VllmConfig(parallel_config=ParallelConfig(
|
||||
tensor_parallel_size=world_size))
|
||||
|
||||
with monkeypatch.context() as m, set_current_vllm_config(config):
|
||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
@ -51,22 +57,26 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int):
|
||||
get_tp_group().device_communicator)
|
||||
symm_mem_comm = cuda_communicator.symm_mem_comm
|
||||
if symm_mem_comm is None or symm_mem_comm.disabled:
|
||||
pytest.skip("SymmMemCommunicator is not available or disabled.")
|
||||
# can't use skip under multiprocessing
|
||||
q.put("SymmMemCommunicator is not available or disabled.")
|
||||
return
|
||||
|
||||
inp_direct_symm_mem = torch.randint(1,
|
||||
23, (test_size_elements, ),
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
|
||||
pytest.skip(
|
||||
# can't use skip under multiprocessing
|
||||
q.put(
|
||||
"SymmMemCommunicator isn't used for this world and input size."
|
||||
)
|
||||
return
|
||||
|
||||
original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
|
||||
out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
|
||||
assert out_direct_symm_mem is not None
|
||||
|
||||
group = get_tensor_model_parallel_group().device_group
|
||||
group = get_tp_group().device_group
|
||||
dist.all_reduce(original_inp_direct_symm_mem, group=group)
|
||||
torch.testing.assert_close(out_direct_symm_mem,
|
||||
original_inp_direct_symm_mem,
|
||||
@ -100,9 +110,34 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
|
||||
world_size = tp_size * pipeline_parallel_size
|
||||
if world_size > torch.cuda.device_count():
|
||||
pytest.skip("Not enough GPUs to run the test.")
|
||||
q = mp.get_context('spawn').Queue()
|
||||
mp.spawn(symm_mem_allreduce_worker,
|
||||
args=(world_size, q),
|
||||
nprocs=world_size)
|
||||
try:
|
||||
val = q.get(timeout=1)
|
||||
except queue.Empty:
|
||||
val = None
|
||||
finally:
|
||||
cleanup_dist_env_and_memory()
|
||||
if val is not None:
|
||||
pytest.skip(val)
|
||||
|
||||
# Enable SymmMemCommunicator
|
||||
monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")
|
||||
|
||||
mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
|
||||
cleanup_dist_env_and_memory()
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(),
|
||||
reason="SymmMemAllreduce is only available for CUDA platforms.")
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||
reason="Only test on CUDA")
|
||||
def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
|
||||
world_size = 4
|
||||
if world_size > torch.cuda.device_count():
|
||||
pytest.skip("Not enough GPUs to run the test.")
|
||||
# Verify that the DataParallel runs without error
|
||||
engine_args = EngineArgs(model="distilbert/distilgpt2",
|
||||
enforce_eager=True,
|
||||
enable_prefix_caching=True,
|
||||
data_parallel_size=2,
|
||||
tensor_parallel_size=2,
|
||||
data_parallel_backend="mp")
|
||||
LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
@ -39,6 +39,8 @@ CUDA_DEVICES = [
|
||||
# We assume fp8 is always enabled for testing.
|
||||
KV_CACHE_DTYPE = ["auto", "fp8"]
|
||||
|
||||
RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
|
||||
@ -223,6 +225,7 @@ def test_reshape_and_cache(
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
|
||||
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
|
||||
@torch.inference_mode()
|
||||
def test_reshape_and_cache_flash(
|
||||
kv_cache_factory_flashinfer,
|
||||
@ -236,9 +239,13 @@ def test_reshape_and_cache_flash(
|
||||
device: str,
|
||||
kv_cache_dtype: str,
|
||||
kv_cache_layout: str,
|
||||
implementation: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
assert implementation in ["cuda", "triton"]
|
||||
if implementation == "triton" and kv_cache_layout == "HND":
|
||||
pytest.skip("Triton implementation only supports NHD layout.")
|
||||
|
||||
# fp8 conversion requires continugous memory buffer. Reduce the number of
|
||||
# blocks and tokens to consume less memory.
|
||||
@ -298,12 +305,20 @@ def test_reshape_and_cache_flash(
|
||||
cloned_key_cache = key_cache_compact.clone()
|
||||
cloned_value_cache = value_cache_compact.clone()
|
||||
# Call the reshape_and_cache kernel.
|
||||
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
|
||||
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
|
||||
k_scale, v_scale),
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype, k_scale, v_scale)
|
||||
if implementation == "cuda":
|
||||
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
|
||||
(key, value, key_cache, value_cache, slot_mapping,
|
||||
kv_cache_dtype, k_scale, v_scale),
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype, k_scale,
|
||||
v_scale)
|
||||
elif implementation == "triton":
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash)
|
||||
triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype, k_scale,
|
||||
v_scale)
|
||||
key_cache_compact = permute_and_compact(key_cache)
|
||||
value_cache_compact = permute_and_compact(value_cache)
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
cutlass_scaled_mm, get_col_major_tma_aligned_tensor,
|
||||
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||
per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8
|
||||
@ -90,7 +90,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
|
||||
out_dtype)
|
||||
out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
|
||||
out_dtype)
|
||||
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
|
||||
@ -20,9 +20,11 @@ from vllm.platforms import current_platform
|
||||
(8, 513, 64), # Non-divisible (native only)
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [42])
|
||||
@pytest.mark.parametrize("use_ue8m0", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
|
||||
group_size: int, seed: int) -> None:
|
||||
group_size: int, seed: int,
|
||||
use_ue8m0: bool) -> None:
|
||||
"""Test QuantFP8 group quantization with various configurations.
|
||||
|
||||
Tests both CUDA and native implementations, column-major scales,
|
||||
@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
|
||||
group_shape = GroupShape(1, group_size)
|
||||
quant_op = QuantFP8(static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=False)
|
||||
column_major_scales=False,
|
||||
use_ue8m0=use_ue8m0)
|
||||
|
||||
# 1. Test native implementation (always available)
|
||||
x_quant_native, scales_native = quant_op.forward_native(x.clone())
|
||||
@ -48,9 +51,15 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
|
||||
# 2. Test column-major scales configuration
|
||||
quant_op_col = QuantFP8(static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=True)
|
||||
column_major_scales=True,
|
||||
use_ue8m0=use_ue8m0)
|
||||
_, scales_col = quant_op_col.forward_native(x.clone())
|
||||
assert scales_col.shape == (expected_num_groups, batch_size)
|
||||
assert scales_col.shape == (batch_size, expected_num_groups)
|
||||
assert scales_col.stride(0) == 1
|
||||
assert scales_col.stride(1) == batch_size
|
||||
|
||||
# Test column-major scales consistency
|
||||
assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
|
||||
|
||||
# 3. Test CUDA implementation (only for divisible dimensions)
|
||||
if is_divisible:
|
||||
@ -68,8 +77,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", [42])
|
||||
@pytest.mark.parametrize("use_ue8m0", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_quantfp8_group_multidimensional(seed: int) -> None:
|
||||
def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
group_size = 64
|
||||
@ -82,7 +92,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
|
||||
group_shape = GroupShape(1, group_size)
|
||||
quant_op = QuantFP8(static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=False)
|
||||
column_major_scales=False,
|
||||
use_ue8m0=use_ue8m0)
|
||||
|
||||
x_quant, scales = quant_op.forward_native(x_3d.clone())
|
||||
assert x_quant.shape == x_3d.shape
|
||||
@ -91,7 +102,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
|
||||
# Test column_major_scales with multi-dim
|
||||
quant_op_col = QuantFP8(static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=True)
|
||||
column_major_scales=True,
|
||||
use_ue8m0=use_ue8m0)
|
||||
_, scales_col = quant_op_col.forward_native(x_3d.clone())
|
||||
assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)
|
||||
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
rocm_per_tensor_w8a8_scaled_mm_impl)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float16]
|
||||
@ -49,6 +49,7 @@ NKM_FACTORS_WVSPLITK_FP8 = [
|
||||
(2, 512, 512),
|
||||
(3, 2048, 2048),
|
||||
(4, 4096, 4096),
|
||||
(4, 16400, 2048),
|
||||
# Extended FP8 dimensions not covered by WVSPLITK
|
||||
(1, 14336, 1024),
|
||||
(2, 24576, 2048),
|
||||
@ -67,6 +68,9 @@ SEEDS = [0]
|
||||
@torch.inference_mode()
|
||||
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
|
||||
torch.manual_seed(seed)
|
||||
#TODO: Zero-centering the inputs causes errors for LLMM1!
|
||||
# Without that the numbers quickly saturate, and may
|
||||
# be giving false matches.
|
||||
A = torch.rand(n, k, dtype=dtype, device="cuda")
|
||||
B = torch.rand(m, k, dtype=dtype, device="cuda")
|
||||
|
||||
@ -85,11 +89,51 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
cu_count = current_platform.get_cu_count()
|
||||
|
||||
A = torch.rand(n, k, dtype=dtype, device="cuda")
|
||||
B = torch.rand(m, k, dtype=dtype, device="cuda")
|
||||
A = torch.rand(n, k, dtype=dtype, device="cuda") - .5
|
||||
B = torch.rand(m, k, dtype=dtype, device="cuda") - .5
|
||||
|
||||
ref_out = torch.matmul(A, B.t())
|
||||
out = ops.wvSplitK(B, A, cu_count)
|
||||
ref_out = torch.nn.functional.linear(A, B)
|
||||
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||
reason="only test for rocm")
|
||||
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
cu_count = current_platform.get_cu_count()
|
||||
|
||||
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
|
||||
A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
|
||||
B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
|
||||
BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
|
||||
|
||||
ref_out = torch.nn.functional.linear(A, B, BIAS)
|
||||
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||
reason="only test for rocm")
|
||||
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
cu_count = current_platform.get_cu_count()
|
||||
|
||||
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
|
||||
A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
|
||||
B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
|
||||
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - .5
|
||||
|
||||
ref_out = torch.nn.functional.linear(A, B, BIAS)
|
||||
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
|
||||
@ -103,8 +147,8 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
|
||||
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
A = torch.rand(n, k, device="cuda")
|
||||
B = torch.rand(m, k, device="cuda")
|
||||
A = torch.rand(n, k, device="cuda") - 0.5
|
||||
B = torch.rand(m, k, device="cuda") - 0.5
|
||||
|
||||
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
|
||||
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
|
||||
@ -123,27 +167,27 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
not (current_platform.is_rocm() and current_platform.supports_fp8()),
|
||||
reason="only test for rocm fp8")
|
||||
def test_rocm_per_tensor_w8a8_scaled_mm_impl(n, k, m, dtype, seed, use_bias):
|
||||
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
A = torch.rand(n, k, device="cuda")
|
||||
B = torch.rand(m, k, device="cuda")
|
||||
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
|
||||
A = (torch.rand(n, k, device="cuda") - .5) * xavier
|
||||
B = (torch.rand(m, k, device="cuda") - .5) * xavier
|
||||
BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
|
||||
|
||||
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
|
||||
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
|
||||
|
||||
bias = torch.rand(1, m, dtype=dtype, device="cuda") if use_bias else None
|
||||
|
||||
output = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a,
|
||||
scale_b, bias)
|
||||
ref_out = torch._scaled_mm(A,
|
||||
B.t(),
|
||||
out_dtype=dtype,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=bias)
|
||||
assert torch.allclose(output, ref_out, rtol=0.01)
|
||||
bias=BIAS)
|
||||
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
|
||||
current_platform.get_cu_count(), BIAS)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
|
||||
@ -17,8 +17,6 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
from vllm.model_executor.layers.layernorm import (RMSNorm,
|
||||
dispatch_rocm_rmsnorm_func,
|
||||
fused_add_rms_norm, rms_norm)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
|
||||
@ -111,34 +109,6 @@ def test_enabled_ops_invalid(env: str):
|
||||
RMSNorm(1024).enabled()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
|
||||
reason="AITER is a feature exclusive for ROCm and FP8_FNUZ")
|
||||
@pytest.mark.parametrize("use_cutlass", [True, False])
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
|
||||
def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
|
||||
use_rocm_aiter_gemm_w8a8_blockscale: str,
|
||||
monkeypatch):
|
||||
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
|
||||
use_rocm_aiter_gemm_w8a8_blockscale)
|
||||
|
||||
use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
|
||||
int(use_rocm_aiter_gemm_w8a8_blockscale)))
|
||||
block_scale_func = dispatch_w8a8_blockscale_func(
|
||||
use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)
|
||||
if use_cutlass:
|
||||
assert block_scale_func == cutlass_scaled_mm
|
||||
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
|
||||
use_rocm_aiter_gemm_w8a8_blockscale):
|
||||
assert block_scale_func == (
|
||||
torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
|
||||
else:
|
||||
assert block_scale_func == w8a8_block_fp8_matmul
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||
|
||||
@ -18,6 +18,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
||||
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
|
||||
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
@ -742,3 +745,35 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt,
|
||||
perplexity = llm.generate_prompt_perplexity([prompt])[0]
|
||||
print(perplexity)
|
||||
assert perplexity <= exp_perplexity
|
||||
|
||||
|
||||
def test_compressed_tensors_fp8_block_enabled(vllm_runner):
|
||||
model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK"
|
||||
with vllm_runner(model_path) as llm:
|
||||
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
assert isinstance(qkv_proj.quant_method,
|
||||
CompressedTensorsLinearMethod)
|
||||
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
|
||||
assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear,
|
||||
W8A8BlockFp8LinearOp)
|
||||
|
||||
assert qkv_proj.weight.dtype is fp8_dtype
|
||||
assert qkv_proj.weight_scale.dtype is torch.float32
|
||||
assert len(qkv_proj.weight.shape) == 2
|
||||
assert len(qkv_proj.weight_scale.shape) == 2
|
||||
|
||||
input_quant_op = \
|
||||
qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
|
||||
assert isinstance(input_quant_op, QuantFP8)
|
||||
assert input_quant_op._forward_method == input_quant_op.forward_cuda
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=20)
|
||||
assert output
|
||||
|
||||
@ -5,11 +5,12 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.test_attention_backends import BATCH_SPECS
|
||||
from tests.v1.attention.utils import create_common_attn_metadata
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm.v1.attention.backends.utils import (UBatchSlice,
|
||||
_make_metadata_with_slice,
|
||||
slice_query_start_locs,
|
||||
split_attn_metadata)
|
||||
from vllm.v1.worker.ubatch_utils import create_ubatch_slices
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -155,3 +156,83 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
|
||||
assert results[1].num_reqs == mid_point
|
||||
assert results[1].num_actual_tokens == mid_point
|
||||
assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
|
||||
[
|
||||
# Split in the middle of request 1
|
||||
([32, 40], [8, 8], 12, 2, 1),
|
||||
# Split inside the first request
|
||||
([32, 40], [8, 8], 4, 1, 2),
|
||||
],
|
||||
)
|
||||
def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point,
|
||||
expected_first_reqs,
|
||||
expected_second_reqs):
|
||||
"""Test splitting a prefill across ubatches"""
|
||||
import numpy as np
|
||||
|
||||
device = torch.device("cpu")
|
||||
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens)
|
||||
common = create_common_attn_metadata(batch_spec,
|
||||
block_size=16,
|
||||
device=device)
|
||||
|
||||
num_scheduled_tokens = np.array(query_lens, dtype=np.int32)
|
||||
qsl_np = common.query_start_loc_cpu.numpy()
|
||||
num_tokens = common.num_actual_tokens
|
||||
|
||||
ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point)
|
||||
assert len(ubatch_slices) == 2
|
||||
|
||||
first_meta = _make_metadata_with_slice(ubatch_slices[0], common)
|
||||
second_meta = _make_metadata_with_slice(ubatch_slices[1], common)
|
||||
|
||||
# Token counts match the split
|
||||
assert first_meta.num_actual_tokens == split_point
|
||||
assert second_meta.num_actual_tokens == num_tokens - split_point
|
||||
|
||||
# Number of requests per ubatch
|
||||
assert first_meta.num_reqs == expected_first_reqs
|
||||
assert second_meta.num_reqs == expected_second_reqs
|
||||
|
||||
# Identify which request is split and how many tokens are in the first chunk
|
||||
split_req_idx = int(np.searchsorted(qsl_np, split_point, side="right") - 1)
|
||||
tokens_in_first_chunk = split_point - int(qsl_np[split_req_idx])
|
||||
orig_q_lens = (common.query_start_loc_cpu[1:] -
|
||||
common.query_start_loc_cpu[:-1])
|
||||
|
||||
# Check query length continuity: first-chunk + second-chunk == original qlen
|
||||
# First ubatch last request query length
|
||||
qlen_first_last = int(first_meta.query_start_loc_cpu[-1] -
|
||||
first_meta.query_start_loc_cpu[-2])
|
||||
# Second ubatch first request query length
|
||||
qlen_second_first = int(second_meta.query_start_loc_cpu[1] -
|
||||
second_meta.query_start_loc_cpu[0])
|
||||
assert qlen_first_last == tokens_in_first_chunk
|
||||
assert qlen_first_last + qlen_second_first == int(
|
||||
orig_q_lens[split_req_idx])
|
||||
|
||||
# Check seq_lens adjustments
|
||||
# Context lengths per original request
|
||||
context_lens = [s - q for s, q in zip(seq_lens, query_lens)]
|
||||
|
||||
# First ubatch: last request's seq_len should be
|
||||
# context + tokens_in_first_chunk
|
||||
expected_seqlen = context_lens[split_req_idx] + tokens_in_first_chunk
|
||||
assert int(first_meta.seq_lens[-1]) == expected_seqlen
|
||||
|
||||
# For full preceding requests in first ubatch, seq_lens should match
|
||||
# originals
|
||||
for i in range(first_meta.num_reqs - 1):
|
||||
assert int(first_meta.seq_lens[i]) == seq_lens[i]
|
||||
|
||||
# Second ubatch: first request (continuation) seq_len should be full
|
||||
# original
|
||||
assert int(second_meta.seq_lens[0]) == seq_lens[split_req_idx]
|
||||
# Any following full requests in second ubatch should match originals
|
||||
for j in range(1, second_meta.num_reqs):
|
||||
# Map to original request index
|
||||
orig_idx = split_req_idx + j
|
||||
assert int(second_meta.seq_lens[j]) == seq_lens[orig_idx]
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import fields
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@ -21,7 +22,8 @@ from vllm.entrypoints.llm import LLM
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.sampling_params import (GuidedDecodingParams, SamplingParams,
|
||||
StructuredOutputsParams)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import TokenizerMode
|
||||
@ -89,6 +91,26 @@ def _load_json(s: str, backend: str) -> str:
|
||||
return json.loads(s)
|
||||
|
||||
|
||||
def test_guided_decoding_deprecated():
|
||||
with pytest.warns(DeprecationWarning,
|
||||
match="GuidedDecodingParams is deprecated.*"):
|
||||
guided_decoding = GuidedDecodingParams(json_object=True)
|
||||
|
||||
structured_outputs = StructuredOutputsParams(json_object=True)
|
||||
assert fields(guided_decoding) == fields(structured_outputs)
|
||||
|
||||
with pytest.warns(DeprecationWarning,
|
||||
match="guided_decoding is deprecated.*"):
|
||||
sp1 = SamplingParams(guided_decoding=guided_decoding)
|
||||
|
||||
with pytest.warns(DeprecationWarning,
|
||||
match="guided_decoding is deprecated.*"):
|
||||
sp2 = SamplingParams.from_optional(guided_decoding=guided_decoding)
|
||||
|
||||
assert sp1 == sp2
|
||||
assert sp1.structured_outputs == guided_decoding
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, backend, tokenizer_mode, speculative_config",
|
||||
|
||||
@ -532,9 +532,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
# Mock runner for attention metadata building
|
||||
proposer.runner = mock.MagicMock()
|
||||
proposer.runner.attn_groups.append([mock.MagicMock()])
|
||||
proposer.runner.attn_groups[0][0].metadata_builders = [
|
||||
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
|
||||
attn_metadata_builder
|
||||
]
|
||||
proposer._get_attention_metadata_builder = mock.MagicMock(
|
||||
return_value=attn_metadata_builder)
|
||||
|
||||
result = proposer.propose(target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
@ -659,9 +660,10 @@ def test_propose_tree(spec_token_tree):
|
||||
# Mock runner for attention metadata building.
|
||||
proposer.runner = mock.MagicMock()
|
||||
proposer.runner.attn_groups.append([mock.MagicMock()])
|
||||
proposer.runner.attn_groups[0][0].metadata_builders = [
|
||||
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
|
||||
attn_metadata_builder
|
||||
]
|
||||
proposer._get_attention_metadata_builder = mock.MagicMock(
|
||||
return_value=attn_metadata_builder)
|
||||
|
||||
# Setup inputs for the proposer.
|
||||
target_token_ids = torch.randint(0,
|
||||
|
||||
@ -1447,17 +1447,24 @@ def LLMM1(a: torch.Tensor, b: torch.Tensor,
|
||||
return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)
|
||||
|
||||
|
||||
def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor:
|
||||
return torch.ops._rocm_C.wvSplitK(a, b, cu_count)
|
||||
def wvSplitK(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
cu_count: int,
|
||||
bias: torch.Tensor = None) -> torch.Tensor:
|
||||
return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count)
|
||||
|
||||
|
||||
def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor,
|
||||
cu_count: int) -> torch.Tensor:
|
||||
def wvSplitKQ(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
cu_count: int,
|
||||
bias: torch.Tensor = None) -> torch.Tensor:
|
||||
out = torch.empty((b.shape[0], a.shape[0]),
|
||||
dtype=out_dtype,
|
||||
device=b.device)
|
||||
torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count)
|
||||
torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
176
vllm/attention/ops/triton_reshape_and_cache_flash.py
Normal file
176
vllm/attention/ops/triton_reshape_and_cache_flash.py
Normal file
@ -0,0 +1,176 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@triton.jit
|
||||
def reshape_and_cache_kernel_flash(
|
||||
key_ptr, # [num_tokens, num_heads, head_size]
|
||||
value_ptr, # [num_tokens, num_heads, head_size]
|
||||
key_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
|
||||
value_cache_ptr, # [num_blocks, block_size, num_heads, head_size]
|
||||
slot_mapping_ptr, # [num_tokens]
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
# strides
|
||||
key_stride: tl.int64,
|
||||
value_stride: tl.int64,
|
||||
block_stride: tl.int64,
|
||||
page_stride: tl.int64,
|
||||
num_heads: tl.constexpr,
|
||||
head_size: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE: tl.constexpr,
|
||||
# tune parameters
|
||||
TILE_SIZE: tl.constexpr,
|
||||
):
|
||||
|
||||
token_idx = tl.program_id(axis=0)
|
||||
slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
|
||||
if slot_idx < 0:
|
||||
# Padding token that should be ignored.
|
||||
return
|
||||
|
||||
tile_i = tl.program_id(axis=1)
|
||||
tile_offs = tl.arange(0, TILE_SIZE)
|
||||
tile_pos = tile_i * TILE_SIZE + tile_offs
|
||||
|
||||
block_idx = slot_idx // block_size
|
||||
block_offset = slot_idx % block_size
|
||||
|
||||
src_key_idx = token_idx * key_stride
|
||||
src_value_idx = token_idx * value_stride
|
||||
|
||||
tgt_idx = block_idx * block_stride + block_offset * page_stride
|
||||
|
||||
# [TILE_SIZE]
|
||||
key_load = tl.load(key_ptr + src_key_idx + tile_pos,
|
||||
mask=tile_pos < (num_heads * head_size))
|
||||
if FP8_KV_CACHE:
|
||||
if key_load.dtype.is_fp8():
|
||||
key_tile = key_load
|
||||
else:
|
||||
# tl.store will do the correct implicit cast to fp8,
|
||||
# based on the key_cache_ptr.dtype.element_ty
|
||||
key_tile = key_load / tl.load(k_scale)
|
||||
else:
|
||||
key_tile = key_load
|
||||
|
||||
# [TILE_SIZE]
|
||||
value_load = tl.load(value_ptr + src_value_idx + tile_pos,
|
||||
mask=tile_pos < (num_heads * head_size))
|
||||
if FP8_KV_CACHE:
|
||||
if value_load.dtype.is_fp8():
|
||||
value_tile = value_load
|
||||
else:
|
||||
# tl.store will do the correct implicit cast to fp8,
|
||||
# based on the value_cache_ptr.dtype.element_ty
|
||||
value_tile = value_load / tl.load(v_scale)
|
||||
else:
|
||||
value_tile = value_load
|
||||
|
||||
tl.store(
|
||||
key_cache_ptr + tgt_idx + tile_pos,
|
||||
key_tile,
|
||||
mask=tile_pos < (num_heads * head_size),
|
||||
)
|
||||
tl.store(
|
||||
value_cache_ptr + tgt_idx + tile_pos,
|
||||
value_tile,
|
||||
mask=tile_pos < (num_heads * head_size),
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def triton_reshape_and_cache_flash(
|
||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
# [num_blocks, block_size, num_heads, head_size]
|
||||
key_cache: torch.Tensor,
|
||||
# [num_blocks, block_size, num_heads, head_size]
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor, # [num_tokens]
|
||||
kv_cache_dtype: str, # "auto", "fp8"
|
||||
k_scale: torch.Tensor, # float32
|
||||
v_scale: torch.Tensor, # float32
|
||||
):
|
||||
num_tokens = key.shape[0]
|
||||
num_heads = key.shape[1]
|
||||
head_size = key.shape[2]
|
||||
block_size = key_cache.shape[1]
|
||||
n = num_heads * head_size
|
||||
|
||||
key_stride = key.stride()[0]
|
||||
value_stride = value.stride()[0]
|
||||
block_stride = key_cache.stride()[0]
|
||||
page_stride = key_cache.stride()[1]
|
||||
|
||||
head_stride = key_cache.stride()[2]
|
||||
assert head_stride == head_size, "only continous heads are supported"
|
||||
|
||||
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \
|
||||
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
|
||||
kv_cache_torch_dtype = current_platform.fp8_dtype() if \
|
||||
kv_cache_dtype.startswith("fp8") else key_cache.dtype
|
||||
|
||||
if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith(
|
||||
"fp8"):
|
||||
# to avoid erounous implicit cast in triton kernel (tl.store to uint8)
|
||||
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
|
||||
key_cache = key_cache.view(kv_cache_torch_dtype)
|
||||
value_cache = value_cache.view(kv_cache_torch_dtype)
|
||||
assert kv_cache_dtype != torch.uint8, "explicit fp8 cast and store to "\
|
||||
"uint8 is not supported by triton reshape_and_cache_flash"
|
||||
|
||||
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
|
||||
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
|
||||
torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8,
|
||||
torch.float8_e4m3fnuz], \
|
||||
"unsupported dtype of KV cache tensor, got "\
|
||||
"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \
|
||||
"fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
|
||||
|
||||
# heuristics instead of autotuning
|
||||
TILE_SIZE = min(2048, triton.next_power_of_2(n))
|
||||
if torch.version.hip:
|
||||
num_stages = 4
|
||||
num_warps = 8
|
||||
else: # cuda
|
||||
num_stages = 10
|
||||
num_warps = 16
|
||||
if torch.cuda.get_device_capability(key.device)[0] < 9:
|
||||
TILE_SIZE = min(512, TILE_SIZE)
|
||||
|
||||
# TODO(ngl): maybe replace with static launch grid to avoid overhead if
|
||||
# using cudagraphs
|
||||
grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))
|
||||
|
||||
reshape_and_cache_kernel_flash[grid](
|
||||
key_ptr=key,
|
||||
value_ptr=value,
|
||||
key_cache_ptr=key_cache,
|
||||
value_cache_ptr=value_cache,
|
||||
slot_mapping_ptr=slot_mapping,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
# strides
|
||||
key_stride=key_stride,
|
||||
value_stride=value_stride,
|
||||
block_stride=block_stride,
|
||||
page_stride=page_stride,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
block_size=block_size,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE=FP8_KV_CACHE,
|
||||
# autotune parameters
|
||||
TILE_SIZE=TILE_SIZE,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
@ -12,6 +12,8 @@ import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
set_graph_pool_id)
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@ -154,6 +156,10 @@ class CUDAGraphWrapper:
|
||||
stack.enter_context(
|
||||
patch("torch.cuda.empty_cache", lambda: None))
|
||||
|
||||
if self.graph_pool is not None:
|
||||
set_graph_pool_id(self.graph_pool)
|
||||
else:
|
||||
set_graph_pool_id(current_platform.graph_pool_handle())
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's cudagraph pool
|
||||
|
||||
@ -509,8 +509,15 @@ class VllmConfig:
|
||||
if self.compilation_config.cudagraph_mode is None:
|
||||
if envs.VLLM_USE_V1 and self.compilation_config.level \
|
||||
== CompilationLevel.PIECEWISE:
|
||||
# default to full and piecewise for most models
|
||||
self.compilation_config.cudagraph_mode = \
|
||||
CUDAGraphMode.PIECEWISE
|
||||
CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
|
||||
# pooling model does not support full cudagraphs
|
||||
if self.model_config is not None and \
|
||||
self.model_config.pooler_config is not None:
|
||||
self.compilation_config.cudagraph_mode = \
|
||||
CUDAGraphMode.PIECEWISE
|
||||
else:
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
@ -638,11 +645,13 @@ class VllmConfig:
|
||||
|
||||
if self.parallel_config.enable_dbo:
|
||||
a2a_backend = envs.VLLM_ALL2ALL_BACKEND
|
||||
assert a2a_backend == "deepep_low_latency", \
|
||||
"Microbatching currently only supports the deepep_low_latency "\
|
||||
f"all2all backend. {a2a_backend} is not supported. To fix set "\
|
||||
"the VLLM_ALL2ALL_BACKEND environment variable to "\
|
||||
"deepep_low_latency and install the DeepEP kerenls."
|
||||
assert a2a_backend in \
|
||||
["deepep_low_latency", "deepep_high_throughput"], \
|
||||
"Microbatching currently only supports the deepep_low_latency and "\
|
||||
f"deepep_high_throughput all2all backend. {a2a_backend} is not "\
|
||||
"supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\
|
||||
"variable to deepep_low_latency or deepep_high_throughput and "\
|
||||
"install the DeepEP kernels."
|
||||
|
||||
if not self.instance_id:
|
||||
self.instance_id = random_uuid()[:5]
|
||||
@ -685,6 +694,23 @@ class VllmConfig:
|
||||
# local attention.
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = True
|
||||
|
||||
def has_blocked_weights():
|
||||
if self.quant_config is not None:
|
||||
if hasattr(self.quant_config, "weight_block_size"):
|
||||
return self.quant_config.weight_block_size is not None
|
||||
elif hasattr(self.quant_config, "has_blocked_weights"):
|
||||
return self.quant_config.has_blocked_weights()
|
||||
return False
|
||||
|
||||
# Enable quant_fp8 CUDA ops (TODO disable in follow up)
|
||||
# On H100 the CUDA kernel is faster than
|
||||
# native implementation
|
||||
# https://github.com/vllm-project/vllm/issues/25094
|
||||
if has_blocked_weights():
|
||||
custom_ops = self.compilation_config.custom_ops
|
||||
if "none" not in custom_ops and "-quant_fp8" not in custom_ops:
|
||||
custom_ops.append("+quant_fp8")
|
||||
|
||||
def update_sizes_for_sequence_parallelism(self,
|
||||
possible_sizes: list) -> list:
|
||||
# remove the sizes that not multiple of tp_size when
|
||||
|
||||
@ -228,15 +228,14 @@ class CompilationConfig:
|
||||
The mode of the cudagraph:
|
||||
|
||||
- NONE, no cudagraph capture.
|
||||
- PIECEWISE. (v1 default)
|
||||
- PIECEWISE.
|
||||
- FULL.
|
||||
- FULL_DECODE_ONLY.
|
||||
- FULL_AND_PIECEWISE.
|
||||
- FULL_AND_PIECEWISE. (v1 default)
|
||||
|
||||
PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph
|
||||
incompatible ops (i.e. some attention ops) outside the cudagraph
|
||||
for general flexibility.
|
||||
This is the default mode.
|
||||
|
||||
FULL mode: Capture full cudagraph for all batches. Can be good for small
|
||||
models or workloads with small prompts; not supported by many backends.
|
||||
@ -249,7 +248,7 @@ class CompilationConfig:
|
||||
|
||||
FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
|
||||
piecewise cudagraph for prefill and mixed prefill-decode batches.
|
||||
This is like the most performant mode for most models.
|
||||
This is the most performant mode for most models and is the default.
|
||||
|
||||
Currently, the cudagraph mode is only used for the v1 engine.
|
||||
Note that the cudagraph logic is generally orthogonal to the
|
||||
|
||||
@ -1003,6 +1003,7 @@ class ModelConfig:
|
||||
self.quantization = quantization_override
|
||||
break
|
||||
|
||||
quant_method = quant_method if quant_method != "" else None
|
||||
# Verify quantization configurations.
|
||||
if self.quantization is None:
|
||||
self.quantization = quant_method
|
||||
|
||||
@ -139,12 +139,18 @@ class ParallelConfig:
|
||||
"""Disable the custom all-reduce kernel and fall back to NCCL."""
|
||||
|
||||
enable_dbo: bool = False
|
||||
"""Enable microbatching for the model executor."""
|
||||
"""Enable dual batch overlap for the model executor."""
|
||||
|
||||
dbo_decode_token_threshold: int = 32
|
||||
"""The threshold for microbatching. If the number of tokens in the
|
||||
request is greater than this threshold, microbatching will be used.
|
||||
Otherwise, the request will be processed in a single batch."""
|
||||
"""The threshold for dual batch overlap for batches only containing decodes.
|
||||
If the number of tokens in the request is greater than this threshold,
|
||||
microbatching will be used. Otherwise, the request will be processed in a
|
||||
single batch."""
|
||||
dbo_prefill_token_threshold: int = 512 # TODO(lucas): tune
|
||||
"""The threshold for dual batch overlap for batches that contain one or more
|
||||
prefills. If the number of tokens in the request is greater than this
|
||||
threshold, microbatching will be used. Otherwise, the request will be
|
||||
processed in a single batch."""
|
||||
|
||||
ray_workers_use_nsight: bool = False
|
||||
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
@ -200,12 +201,12 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
|
||||
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = 1024 * 1024 * 1024
|
||||
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
|
||||
num_rdma_bytes = None
|
||||
num_qps_per_rank = None
|
||||
|
||||
if self.internode:
|
||||
num_rdma_bytes = 1024 * 1024 * 1024
|
||||
num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
|
||||
num_qps_per_rank = self.num_sms // 2
|
||||
else:
|
||||
num_rdma_bytes = 0
|
||||
@ -230,13 +231,18 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer)
|
||||
# It is dangerous to set num sms outside this function. num_sms is not
|
||||
# a part of the hash-key that identifies this object. If we are in a
|
||||
# situation where we make objects with different num_sms, the hash key
|
||||
# in get_or_create must be updated.
|
||||
handle.set_num_sms(self.num_sms)
|
||||
return handle
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
import deep_ep
|
||||
|
||||
# Right now the buffers are sized for only what the kernels were
|
||||
# created with. So we can only reduce the number of SMS used
|
||||
# but not increase it.
|
||||
if num_sms > self.num_sms:
|
||||
num_sms = self.num_sms
|
||||
deep_ep.Buffer.set_num_sms(num_sms)
|
||||
|
||||
|
||||
class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
"""
|
||||
@ -265,7 +271,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
import deep_ep
|
||||
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = 1024 * 1024 * 1024
|
||||
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
|
||||
num_qps_per_rank = num_local_experts
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
|
||||
@ -291,3 +297,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer)
|
||||
return handle
|
||||
|
||||
# DeepEP LL uses RDMA so no SMs are used for communication
|
||||
def max_sms_used(self) -> Optional[int]:
|
||||
return 0
|
||||
@ -10,8 +10,9 @@ import sys
|
||||
import tempfile
|
||||
from collections.abc import Sequence
|
||||
from itertools import product
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
@ -56,6 +57,30 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
|
||||
}
|
||||
}
|
||||
|
||||
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
|
||||
"min_world_size": 4,
|
||||
"thresholds": {
|
||||
4: 2 * MiB, # 2 MB
|
||||
8: 1 * MiB, # 1 MB
|
||||
},
|
||||
"always_use_above_world_size": 8 # Always use symm mem for world_size > 8
|
||||
}
|
||||
|
||||
|
||||
def should_nccl_symm_mem_allreduce(world_size: int,
|
||||
input_tensor: torch.Tensor) -> bool:
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
is_symmetric_memory_enabled)
|
||||
if not is_symmetric_memory_enabled():
|
||||
return False
|
||||
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
|
||||
return False
|
||||
threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
|
||||
if threshold is not None and input_tensor.nbytes >= threshold:
|
||||
return True
|
||||
return (world_size
|
||||
> NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"])
|
||||
|
||||
|
||||
def producer(batch_src: Sequence[int],
|
||||
producer_queue,
|
||||
|
||||
@ -60,6 +60,12 @@ class All2AllManagerBase:
|
||||
# and reuse it for the same config.
|
||||
raise NotImplementedError
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
pass
|
||||
|
||||
def max_sms_used(self) -> Optional[int]:
|
||||
return None # None means it could use the whole GPU
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
@ -7,6 +7,12 @@ import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.all_reduce_utils import (
|
||||
should_nccl_symm_mem_allreduce)
|
||||
from vllm.distributed.device_communicators.pynccl import (
|
||||
register_nccl_symmetric_ops)
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
is_symmetric_memory_enabled)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -24,18 +30,21 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
unique_name: str = ""):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
if "tp" not in unique_name:
|
||||
# only tp uses custom allreduce
|
||||
# custom allreduce or torch symm mem can be used only by tp
|
||||
use_custom_allreduce = False
|
||||
use_torch_symm_mem = False
|
||||
else:
|
||||
from vllm.distributed.parallel_state import (
|
||||
_ENABLE_CUSTOM_ALL_REDUCE)
|
||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
|
||||
|
||||
# ep does not use pynccl
|
||||
use_pynccl = "ep" not in unique_name
|
||||
|
||||
self.use_pynccl = use_pynccl
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
self.use_torch_symm_mem = use_torch_symm_mem
|
||||
|
||||
# lazy import to avoid documentation build error
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||
@ -53,11 +62,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
)
|
||||
if is_symmetric_memory_enabled():
|
||||
register_nccl_symmetric_ops(self.pynccl_comm)
|
||||
|
||||
self.ca_comm: Optional[CustomAllreduce] = None
|
||||
self.qr_comm: Optional[QuickAllReduce] = None
|
||||
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
|
||||
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
|
||||
if use_torch_symm_mem and current_platform.is_cuda():
|
||||
self.symm_mem_comm = SymmMemCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
@ -107,6 +118,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
|
||||
|
||||
def all_reduce(self, input_):
|
||||
# since currently we perform copy input -> symm_input -> out-of-place AR
|
||||
# return symm_output, we don't need to check if input is symmetric
|
||||
if self.pynccl_comm is not None and \
|
||||
should_nccl_symm_mem_allreduce(self.pynccl_comm.world_size,input_):
|
||||
out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
|
||||
if out is not None:
|
||||
return out
|
||||
# always try quick reduce first, then custom allreduce,
|
||||
# and then pynccl. (quick reduce just for ROCM MI3*)
|
||||
qr_comm = self.qr_comm
|
||||
|
||||
@ -17,6 +17,39 @@ from vllm.utils import current_stream
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_NCCL_SYMM_OPS_REGISTERED = False
|
||||
|
||||
|
||||
def register_nccl_symmetric_ops(pynccl_comm):
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
nccl_symm_mem_context)
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
global _NCCL_SYMM_OPS_REGISTERED
|
||||
if _NCCL_SYMM_OPS_REGISTERED:
|
||||
return
|
||||
_NCCL_SYMM_OPS_REGISTERED = True
|
||||
|
||||
def all_reduce_symmetric_with_copy_impl(
|
||||
input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
with nccl_symm_mem_context(pynccl_comm):
|
||||
symm_input = torch.empty_like(input_tensor)
|
||||
symm_output = torch.empty_like(input_tensor)
|
||||
symm_input.copy_(input_tensor)
|
||||
symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
|
||||
return symm_output
|
||||
|
||||
def all_reduce_symmetric_with_copy_fake(
|
||||
input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(input_tensor)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="all_reduce_symmetric_with_copy",
|
||||
op_func=all_reduce_symmetric_with_copy_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=all_reduce_symmetric_with_copy_fake,
|
||||
)
|
||||
|
||||
|
||||
class PyNcclCommunicator:
|
||||
|
||||
@ -67,6 +100,7 @@ class PyNcclCommunicator:
|
||||
self.available = True
|
||||
self.disabled = False
|
||||
|
||||
self.nccl_version = self.nccl.ncclGetRawVersion()
|
||||
logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
|
||||
|
||||
if self.rank == 0:
|
||||
@ -109,6 +143,7 @@ class PyNcclCommunicator:
|
||||
|
||||
def all_reduce(self,
|
||||
in_tensor: torch.Tensor,
|
||||
out_tensor: torch.Tensor = None,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None) -> torch.Tensor:
|
||||
if self.disabled:
|
||||
@ -120,7 +155,8 @@ class PyNcclCommunicator:
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {in_tensor.device}")
|
||||
|
||||
out_tensor = torch.empty_like(in_tensor)
|
||||
if out_tensor is None:
|
||||
out_tensor = torch.empty_like(in_tensor)
|
||||
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
@ -288,3 +324,18 @@ class PyNcclCommunicator:
|
||||
|
||||
def group_end(self):
|
||||
self.nccl.ncclGroupEnd()
|
||||
|
||||
def register_comm_window(self, tensor: torch.Tensor):
|
||||
return self.nccl.ncclCommWindowRegister(
|
||||
self.comm,
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel() * tensor.element_size(),
|
||||
1,
|
||||
)
|
||||
|
||||
def register_comm_window_raw(self, ptr: int, size: int):
|
||||
return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr),
|
||||
size, 1)
|
||||
|
||||
def deregister_comm_window(self, window):
|
||||
return self.nccl.ncclCommWindowDeregister(self.comm, window)
|
||||
|
||||
186
vllm/distributed/device_communicators/pynccl_allocator.py
Normal file
186
vllm/distributed/device_communicators/pynccl_allocator.py
Normal file
@ -0,0 +1,186 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import atexit
|
||||
import contextlib
|
||||
import tempfile
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.cuda.memory import CUDAPluggableAllocator
|
||||
from torch.utils.cpp_extension import load_inline
|
||||
|
||||
from vllm import envs
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import find_nccl_include_paths
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
nccl_allocator_source = """
|
||||
#include <nccl.h>
|
||||
extern "C" {
|
||||
|
||||
void* nccl_alloc_plug(size_t size, int device, void* stream) {
|
||||
void* ptr;
|
||||
ncclResult_t err = ncclMemAlloc(&ptr, size);
|
||||
return ptr;
|
||||
|
||||
}
|
||||
|
||||
void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
|
||||
ncclResult_t err = ncclMemFree(ptr);
|
||||
}
|
||||
|
||||
}
|
||||
"""
|
||||
|
||||
_allocator = None
|
||||
_allocator_wrapper = None
|
||||
_mem_pool = None
|
||||
_registered_base_addrs = set()
|
||||
_graph_pool_id = None
|
||||
_nccl_allocator_failed_to_compile = False
|
||||
_cached_pool_snapshot = None
|
||||
|
||||
|
||||
def is_symmetric_memory_enabled():
|
||||
global _nccl_allocator_failed_to_compile
|
||||
return envs.VLLM_USE_NCCL_SYMM_MEM and not _nccl_allocator_failed_to_compile
|
||||
|
||||
|
||||
def is_symmetric_memory_tensor(tensor: torch.Tensor):
|
||||
if not is_symmetric_memory_enabled() or _cached_pool_snapshot is None:
|
||||
return False
|
||||
for segment in _cached_pool_snapshot:
|
||||
for block in segment["blocks"]:
|
||||
if block["address"] == tensor.untyped_storage().data_ptr():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def set_graph_pool_id(graph_pool_id):
|
||||
global _graph_pool_id
|
||||
_graph_pool_id = graph_pool_id
|
||||
|
||||
|
||||
def compile_nccl_allocator():
|
||||
global _allocator, _allocator_wrapper, _nccl_allocator_failed_to_compile
|
||||
if not current_platform.is_cuda():
|
||||
_nccl_allocator_failed_to_compile = True
|
||||
return
|
||||
try:
|
||||
out_dir = tempfile.gettempdir()
|
||||
nccl_allocator_libname = "nccl_allocator"
|
||||
nccl_include_paths = find_nccl_include_paths()
|
||||
load_inline(
|
||||
name=nccl_allocator_libname,
|
||||
cpp_sources=nccl_allocator_source,
|
||||
with_cuda=True,
|
||||
extra_ldflags=["-lnccl"],
|
||||
verbose=envs.VLLM_LOGGING_LEVEL == "DEBUG",
|
||||
is_python_module=False,
|
||||
build_directory=out_dir,
|
||||
extra_include_paths=nccl_include_paths,
|
||||
)
|
||||
_allocator_wrapper = CUDAPluggableAllocator(
|
||||
f"{out_dir}/{nccl_allocator_libname}.so",
|
||||
"nccl_alloc_plug",
|
||||
"nccl_free_plug",
|
||||
)
|
||||
_allocator = _allocator_wrapper.allocator()
|
||||
except Exception as e:
|
||||
_nccl_allocator_failed_to_compile = True
|
||||
logger.warning(
|
||||
"Failed to compile NCCL memory allocator. "
|
||||
"Symmetric memory will be disabled. "
|
||||
"This is expected if NCCL headers are not available. "
|
||||
"optionally set VLLM_NCCL_INCLUDE_PATH to point to a directory "
|
||||
"containing the NCCL header. "
|
||||
"Error: %s", str(e))
|
||||
|
||||
|
||||
def get_nccl_mem_pool():
|
||||
global _mem_pool, _nccl_allocator_failed_to_compile
|
||||
if _mem_pool is None and not _nccl_allocator_failed_to_compile:
|
||||
compile_nccl_allocator()
|
||||
if _allocator is not None:
|
||||
_mem_pool = torch.cuda.MemPool(_allocator)
|
||||
return _mem_pool
|
||||
|
||||
|
||||
def _cleanup_nccl_mem_pool():
|
||||
global _mem_pool
|
||||
_mem_pool = None
|
||||
|
||||
|
||||
def _cleanup_nccl_allocator_wrapper():
|
||||
global _allocator_wrapper
|
||||
_allocator_wrapper = None
|
||||
|
||||
|
||||
atexit.register(_cleanup_nccl_mem_pool)
|
||||
atexit.register(_cleanup_nccl_allocator_wrapper)
|
||||
|
||||
|
||||
class nccl_symm_mem_context:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pynccl_comm: PyNcclCommunicator,
|
||||
disabled: bool = False,
|
||||
):
|
||||
self.disabled = (disabled or not is_symmetric_memory_enabled()
|
||||
or pynccl_comm.world_size == 1
|
||||
or not current_platform.is_cuda()
|
||||
or get_nccl_mem_pool() is None or version.parse(
|
||||
torch.__version__) < version.parse("2.8.0.a0"))
|
||||
if self.disabled:
|
||||
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
||||
self._mem_pool_ctx: contextlib.AbstractContextManager[
|
||||
Any] = contextlib.nullcontext()
|
||||
self.is_graph_capture = None
|
||||
self.device = None
|
||||
else:
|
||||
self.pynccl_comm = pynccl_comm
|
||||
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
|
||||
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
|
||||
self.device = torch.cuda.current_device()
|
||||
|
||||
def __enter__(self):
|
||||
if self.disabled:
|
||||
return self
|
||||
assert (
|
||||
self.pynccl_comm
|
||||
is not None), "Symmetric memory requires pynccl to be initalized"
|
||||
assert (
|
||||
self.pynccl_comm.nccl_version >= 22703
|
||||
), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
|
||||
if self.is_graph_capture:
|
||||
assert (
|
||||
_graph_pool_id
|
||||
is not None), "graph_pool_id is not set under graph capture"
|
||||
# Pause graph memory pool to use symmetric memory with cuda graph
|
||||
torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
|
||||
self._mem_pool_ctx.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.disabled:
|
||||
return
|
||||
global _cached_pool_snapshot
|
||||
global _registered_base_addrs
|
||||
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
|
||||
_pool = get_nccl_mem_pool()
|
||||
assert _pool is not None
|
||||
_cached_pool_snapshot = _pool.snapshot()
|
||||
assert self.pynccl_comm is not None
|
||||
for segment in _cached_pool_snapshot:
|
||||
if segment["address"] not in _registered_base_addrs:
|
||||
self.pynccl_comm.register_comm_window_raw(
|
||||
segment["address"], segment["total_size"])
|
||||
_registered_base_addrs.add(segment["address"])
|
||||
if self.is_graph_capture:
|
||||
torch._C._cuda_beginAllocateCurrentThreadToPool(
|
||||
self.device, _graph_pool_id)
|
||||
@ -41,6 +41,7 @@ logger = init_logger(__name__)
|
||||
|
||||
ncclResult_t = ctypes.c_int
|
||||
ncclComm_t = ctypes.c_void_p
|
||||
ncclWindow_t = ctypes.c_void_p
|
||||
|
||||
|
||||
class ncclUniqueId(ctypes.Structure):
|
||||
@ -222,6 +223,24 @@ class NCCLLibrary:
|
||||
Function("ncclGroupStart", ncclResult_t, []),
|
||||
# ncclResult_t ncclGroupEnd();
|
||||
Function("ncclGroupEnd", ncclResult_t, []),
|
||||
# ncclResult_t ncclCommWindowRegister(
|
||||
# ncclComm_t comm, void* buff, size_t size,
|
||||
# ncclWindow_t* win, int winFlags);
|
||||
Function(
|
||||
"ncclCommWindowRegister",
|
||||
ncclResult_t,
|
||||
[
|
||||
ncclComm_t,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ctypes.POINTER(ncclWindow_t),
|
||||
ctypes.c_int,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclCommWindowDeregister(
|
||||
# ncclComm_t comm, ncclWindow_t win);
|
||||
Function("ncclCommWindowDeregister", ncclResult_t,
|
||||
[ncclComm_t, ncclWindow_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
@ -271,10 +290,14 @@ class NCCLLibrary:
|
||||
error_str = self.ncclGetErrorString(result)
|
||||
raise RuntimeError(f"NCCL error: {error_str}")
|
||||
|
||||
def ncclGetVersion(self) -> str:
|
||||
def ncclGetRawVersion(self) -> int:
|
||||
version = ctypes.c_int()
|
||||
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
|
||||
version_str = str(version.value)
|
||||
# something like 21903
|
||||
return version.value
|
||||
|
||||
def ncclGetVersion(self) -> str:
|
||||
version_str = str(self.ncclGetRawVersion())
|
||||
# something like 21903 --> "2.19.3"
|
||||
major = version_str[0].lstrip("0")
|
||||
minor = version_str[1:3].lstrip("0")
|
||||
@ -375,6 +398,17 @@ class NCCLLibrary:
|
||||
def ncclGroupEnd(self) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
|
||||
|
||||
def ncclCommWindowRegister(self, comm: ncclComm_t, buff: buffer_type,
|
||||
size: int, win_flags: int) -> ncclWindow_t:
|
||||
window = ncclWindow_t()
|
||||
self.NCCL_CHECK(self._funcs["ncclCommWindowRegister"](
|
||||
comm, buff, size, ctypes.byref(window), win_flags))
|
||||
return window
|
||||
|
||||
def ncclCommWindowDeregister(self, comm: ncclComm_t,
|
||||
window: ncclWindow_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
|
||||
|
||||
@ -330,6 +330,8 @@ class EngineArgs:
|
||||
enable_dbo: bool = ParallelConfig.enable_dbo
|
||||
dbo_decode_token_threshold: int = \
|
||||
ParallelConfig.dbo_decode_token_threshold
|
||||
dbo_prefill_token_threshold: int = \
|
||||
ParallelConfig.dbo_prefill_token_threshold
|
||||
eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
|
||||
enable_eplb: bool = ParallelConfig.enable_eplb
|
||||
expert_placement_strategy: ExpertPlacementStrategy = \
|
||||
@ -698,6 +700,9 @@ class EngineArgs:
|
||||
parallel_group.add_argument(
|
||||
"--dbo-decode-token-threshold",
|
||||
**parallel_kwargs["dbo_decode_token_threshold"])
|
||||
parallel_group.add_argument(
|
||||
"--dbo-prefill-token-threshold",
|
||||
**parallel_kwargs["dbo_prefill_token_threshold"])
|
||||
parallel_group.add_argument("--enable-eplb",
|
||||
**parallel_kwargs["enable_eplb"])
|
||||
parallel_group.add_argument("--eplb-config",
|
||||
@ -1316,6 +1321,7 @@ class EngineArgs:
|
||||
enable_expert_parallel=self.enable_expert_parallel,
|
||||
enable_dbo=self.enable_dbo,
|
||||
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
|
||||
dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
|
||||
enable_eplb=self.enable_eplb,
|
||||
eplb_config=self.eplb_config,
|
||||
expert_placement_strategy=self.expert_placement_strategy,
|
||||
|
||||
@ -317,7 +317,8 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]:
|
||||
)
|
||||
output_items.append(response_item)
|
||||
elif recipient is not None and (recipient.startswith("python")
|
||||
or recipient.startswith("browser")):
|
||||
or recipient.startswith("browser")
|
||||
or recipient.startswith("container")):
|
||||
for content in message.content:
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
|
||||
26
vllm/envs.py
26
vllm/envs.py
@ -182,15 +182,19 @@ if TYPE_CHECKING:
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
|
||||
VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
|
||||
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
|
||||
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
|
||||
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
|
||||
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
|
||||
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
||||
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
||||
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
|
||||
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
|
||||
VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024
|
||||
VLLM_DBO_COMM_SMS: int = 20
|
||||
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
|
||||
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
|
||||
VLLM_USE_NCCL_SYMM_MEM: bool = False
|
||||
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -1366,7 +1370,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
|
||||
# Whether to use pytorch symmetric memory for allreduce
|
||||
"VLLM_ALLREDUCE_USE_SYMM_MEM":
|
||||
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
|
||||
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))),
|
||||
|
||||
# Allows vllm to find tuned config under customized folder
|
||||
"VLLM_TUNED_CONFIG_FOLDER":
|
||||
@ -1392,6 +1396,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
|
||||
"VLLM_OBJECT_STORAGE_SHM_BUFFER"),
|
||||
|
||||
# The size in MB of the buffers (NVL and RDMA) used by DeepEP
|
||||
"VLLM_DEEPEP_BUFFER_SIZE_MB":
|
||||
lambda: int(os.getenv("VLLM_DEEPEP_BUFFER_SIZE_MB", "1024")),
|
||||
|
||||
# The number of SMs to allocate for communication kernels when running DBO
|
||||
# the rest of the SMs on the device will be allocated to compute
|
||||
"VLLM_DBO_COMM_SMS":
|
||||
lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")),
|
||||
|
||||
# Valid values are container,code_interpreter,web_search_preview
|
||||
# ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
|
||||
"GPT_OSS_SYSTEM_TOOL_MCP_LABELS":
|
||||
@ -1399,6 +1412,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
["container",
|
||||
"code_interpreter",
|
||||
"web_search_preview"]),
|
||||
|
||||
# Flag to enable NCCL symmetric memory allocation and registration
|
||||
"VLLM_USE_NCCL_SYMM_MEM":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))),
|
||||
|
||||
# NCCL header path
|
||||
"VLLM_NCCL_INCLUDE_PATH":
|
||||
lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None),
|
||||
|
||||
}
|
||||
|
||||
# --8<-- [end:env-vars-definition]
|
||||
|
||||
@ -24,11 +24,12 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.input_size = self.base_layer.input_size
|
||||
# Ensure tp_size and tp_rank consistency with the base_layer.
|
||||
self.tp_size = self.base_layer.tp_size
|
||||
self.tp_rank = self.base_layer.tp_rank
|
||||
self.device = _get_lora_device(self.base_layer)
|
||||
self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
|
||||
|
||||
self.output_slices: tuple[int, ...]
|
||||
self.tp_size: int
|
||||
self.output_size: int
|
||||
self.n_slices: int
|
||||
|
||||
|
||||
@ -8,9 +8,7 @@ import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.distributed import tensor_model_parallel_all_gather
|
||||
from vllm.distributed.utils import divide
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
@ -85,7 +83,6 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
# inconsistent when TP is greater than 1.
|
||||
self.is_merged_col_linear = type(
|
||||
base_layer) is MergedColumnParallelLinear
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.output_size = self.base_layer.output_size_per_partition
|
||||
# There is only one LoRA layer
|
||||
self.n_slices = 1
|
||||
@ -97,22 +94,20 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
# Applicable to cases where the base_layer is
|
||||
# MergedColumnParallelLinear.
|
||||
if self.is_merged_col_linear:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.output_size // 2
|
||||
offset = lora_b.shape[0] // 2
|
||||
|
||||
left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) *
|
||||
left_weight = lora_b[self.tp_rank * shard_size:(self.tp_rank + 1) *
|
||||
shard_size, :]
|
||||
right_weight = lora_b[offset + tp_rank * shard_size:offset +
|
||||
(tp_rank + 1) * shard_size, :]
|
||||
right_weight = lora_b[offset + self.tp_rank * shard_size:offset +
|
||||
(self.tp_rank + 1) * shard_size, :]
|
||||
lora_b = torch.cat([left_weight, right_weight], dim=0)
|
||||
# Applicable to cases where the base_layer is
|
||||
# ColumnParallelLinear.
|
||||
else:
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.output_size
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
lora_b = lora_b[start_idx:end_idx, :]
|
||||
return lora_b
|
||||
|
||||
@ -120,10 +115,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
# TODO: Fix the slicing logic of bias.
|
||||
if bias is None:
|
||||
return bias
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.output_size
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
bias = bias[start_idx:end_idx]
|
||||
return bias
|
||||
|
||||
@ -144,7 +138,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply(input_, bias)
|
||||
if self.base_layer.gather_output:
|
||||
if self.base_layer.gather_output and self.tp_size > 1:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
@ -185,8 +179,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
QKVParallelLinear]) -> None:
|
||||
super().__init__(base_layer)
|
||||
# There are two LoRA layers
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
# the output_sizes in MergedColumnParallelLinear is not sharded by tp
|
||||
# we need to divide it by the tp_size to get correct slices size
|
||||
output_sizes = self.base_layer.output_sizes
|
||||
@ -341,9 +333,9 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
self.n_slices = 1
|
||||
|
||||
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
self.q_shard_id = tp_rank
|
||||
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
|
||||
|
||||
self.q_shard_id = self.tp_rank
|
||||
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
|
||||
lora_b_q = lora_b[self.q_proj_shard_size *
|
||||
self.q_shard_id:self.q_proj_shard_size *
|
||||
(self.q_shard_id + 1), :]
|
||||
@ -397,8 +389,6 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
|
||||
super().__init__(base_layer)
|
||||
# There are three LoRA layer.
|
||||
self.n_slices = len(self.base_layer.output_sizes)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
self.q_proj_shard_size = (self.base_layer.num_heads *
|
||||
self.base_layer.head_size)
|
||||
@ -461,9 +451,8 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
|
||||
# Therefore, the sharding of `lora_a` only needs to correspond with the
|
||||
# gather operation.
|
||||
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.lora_a_stacked[0].shape[2]
|
||||
start_idx = tp_rank * shard_size
|
||||
start_idx = self.tp_rank * shard_size
|
||||
lora_a = lora_a[start_idx:start_idx + shard_size, :]
|
||||
return lora_a
|
||||
|
||||
@ -547,9 +536,8 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
|
||||
"""
|
||||
|
||||
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.lora_a_stacked[0].shape[2]
|
||||
start_idx = tp_rank * shard_size
|
||||
start_idx = self.tp_rank * shard_size
|
||||
lora_a = lora_a[start_idx:start_idx + shard_size, :]
|
||||
return lora_a
|
||||
|
||||
|
||||
@ -18,7 +18,6 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
def __init__(self, base_layer: ReplicatedLinear) -> None:
|
||||
super().__init__(base_layer, )
|
||||
# To ensure interface compatibility, set to 1 always.
|
||||
self.tp_size = 1
|
||||
self.output_size = self.base_layer.output_size
|
||||
self.n_slices = 1
|
||||
|
||||
|
||||
@ -8,9 +8,7 @@ import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
from vllm.distributed import (split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_reduce)
|
||||
# yapf: disable
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
@ -25,12 +23,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
def __init__(self, base_layer: RowParallelLinear) -> None:
|
||||
super().__init__(base_layer)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
# reset input_size
|
||||
self.input_size = self.base_layer.input_size_per_partition
|
||||
self.output_size = self.base_layer.output_size
|
||||
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
# There is only one LoRA layer.
|
||||
self.n_slices = 1
|
||||
|
||||
@ -68,12 +63,12 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
else:
|
||||
# TODO: simplify code below
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.base_layer.tp_size)
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply(input_parallel)
|
||||
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
|
||||
if self.base_layer.reduce_results and self.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output_ = output_parallel
|
||||
@ -154,8 +149,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
|
||||
buffer, x, self.lora_a_stacked, 1.0)
|
||||
if not current_platform.can_update_inplace():
|
||||
buffer = shrunk_buffer
|
||||
|
||||
buffer = tensor_model_parallel_all_reduce(buffer)
|
||||
if self.tp_size>1:
|
||||
buffer = tensor_model_parallel_all_reduce(buffer)
|
||||
|
||||
# following S-LoRA, allows the fusing of all_gather and all_reduce
|
||||
# by adding the column partitioned lora output to a slice of output
|
||||
|
||||
@ -48,11 +48,11 @@ class LoRALayerWeights:
|
||||
|
||||
@property
|
||||
def input_dim(self) -> int:
|
||||
return self.lora_a.shape[0]
|
||||
return self.lora_a.shape[1]
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
return self.lora_b.shape[1]
|
||||
return self.lora_b.shape[0]
|
||||
|
||||
@property
|
||||
def is_packed(self) -> bool:
|
||||
|
||||
@ -12,6 +12,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
from vllm.utils import round_up
|
||||
from vllm.v1.worker.ubatching import (
|
||||
dbo_current_ubatch_id, dbo_enabled, dbo_switch_to_comm,
|
||||
dbo_switch_to_compute, dbo_switch_to_compute_sync,
|
||||
dbo_yield_and_switch_from_comm_to_compute,
|
||||
dbo_yield_and_switch_from_compute_to_comm)
|
||||
|
||||
|
||||
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
@ -46,9 +51,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
self.async_prepare = True
|
||||
|
||||
# The dispatch function returns a handle that the combine function
|
||||
# requires. We store the handle here so it is available to the
|
||||
# combine function.
|
||||
self.handle = None
|
||||
# requires. Under DBO microbatching we must track one handle per
|
||||
# micro-batch to avoid races between threads.
|
||||
self.handles = [None, None]
|
||||
|
||||
# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
|
||||
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]
|
||||
@ -89,6 +94,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
has_scales = token_scales is not None
|
||||
|
||||
# We yield before launching the dispatch kernel since the dispatch
|
||||
# kernel will block the CPU so we want to queue up all the compute
|
||||
# for the other ubatch before the dispatch kernel starts.
|
||||
dbo_yield_and_switch_from_compute_to_comm()
|
||||
|
||||
(num_tokens_per_rank, num_tokens_per_rdma_rank,
|
||||
dispatch_expert_num_tokens, is_token_in_rank,
|
||||
event) = self.buffer.get_dispatch_layout(
|
||||
@ -104,7 +114,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
(
|
||||
token_data, expert_topk_ids, expert_topk_weights,
|
||||
expert_num_tokens_per_expert_list, self.handle, event
|
||||
expert_num_tokens_per_expert_list, handle, event
|
||||
) = self.buffer.dispatch(
|
||||
x=token_data,
|
||||
handle=None,
|
||||
@ -119,9 +129,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_alignment=1,
|
||||
config=self._get_dispatch_config(),
|
||||
previous_event=None,
|
||||
async_finish=self.async_prepare,
|
||||
async_finish=self.async_prepare and not dbo_enabled(),
|
||||
allocate_on_comm_stream=False)
|
||||
|
||||
# record the handle for this ubatch
|
||||
a2a_idx = dbo_current_ubatch_id()
|
||||
self.handles[a2a_idx] = handle
|
||||
|
||||
dbo_switch_to_compute_sync()
|
||||
|
||||
return lambda: self._receiver(
|
||||
event,
|
||||
has_scales,
|
||||
@ -146,7 +162,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
if self.async_prepare:
|
||||
if event.event is not None:
|
||||
event.current_stream_wait()
|
||||
|
||||
if has_scales:
|
||||
@ -207,7 +223,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[Callable, mk.ReceiverType]:
|
||||
) -> mk.ReceiverType:
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
@ -233,14 +249,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
a1q_scale = None
|
||||
a1_post_scale = quant_config.a1_scale
|
||||
|
||||
return (lambda *args: None,
|
||||
self._do_dispatch(tokens=a1q,
|
||||
token_scales=a1q_scale,
|
||||
rank_topk_ids=topk_ids,
|
||||
rank_topk_weights=topk_weights,
|
||||
num_experts=num_experts,
|
||||
a1_scale=a1_post_scale,
|
||||
quant_config=quant_config))
|
||||
return self._do_dispatch(tokens=a1q,
|
||||
token_scales=a1q_scale,
|
||||
rank_topk_ids=topk_ids,
|
||||
rank_topk_weights=topk_weights,
|
||||
num_experts=num_experts,
|
||||
a1_scale=a1_post_scale,
|
||||
quant_config=quant_config)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
@ -252,10 +267,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
(_, receiver) = self.prepare_async(a1, topk_weights, topk_ids,
|
||||
num_experts, expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config)
|
||||
receiver = self.prepare_async(a1, topk_weights, topk_ids, num_experts,
|
||||
expert_map, apply_router_weight_on_input,
|
||||
quant_config)
|
||||
return receiver()
|
||||
|
||||
def _finalize(
|
||||
@ -269,7 +283,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
do_async: bool,
|
||||
) -> Optional[Callable]:
|
||||
|
||||
assert self.handle is not None
|
||||
a2a_idx = dbo_current_ubatch_id()
|
||||
handle = self.handles[a2a_idx]
|
||||
assert handle is not None
|
||||
|
||||
# fused_expert_output can have 0 tokens - This happens when none of the
|
||||
# tokens from the all2all reach this EP rank.
|
||||
@ -283,25 +299,35 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
dbo_yield_and_switch_from_compute_to_comm()
|
||||
combined_x, _, event = self.buffer.combine(
|
||||
x=fused_expert_output,
|
||||
handle=self.handle,
|
||||
handle=handle,
|
||||
topk_weights=None,
|
||||
config=self._get_combine_config(),
|
||||
previous_event=None,
|
||||
async_finish=do_async,
|
||||
async_finish=do_async and not dbo_enabled(),
|
||||
allocate_on_comm_stream=False)
|
||||
|
||||
dbo_switch_to_compute()
|
||||
|
||||
if do_async:
|
||||
|
||||
def _receiver():
|
||||
event.current_stream_wait()
|
||||
if event.event is not None:
|
||||
event.current_stream_wait()
|
||||
dbo_switch_to_comm()
|
||||
# Respect inplace outputs.
|
||||
output.copy_(combined_x, non_blocking=True)
|
||||
|
||||
return lambda: _receiver()
|
||||
# TODO(lucas): refactor the modular kernel so this will be
|
||||
# handled there
|
||||
dbo_yield_and_switch_from_comm_to_compute()
|
||||
|
||||
return _receiver
|
||||
else:
|
||||
# TODO(lucas): support this case with the refactored modular kernel
|
||||
assert not dbo_enabled()
|
||||
# Respect inplace outputs.
|
||||
output.copy_(combined_x, non_blocking=True)
|
||||
return None
|
||||
|
||||
@ -206,7 +206,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
do_async: bool,
|
||||
) -> Optional[Callable]:
|
||||
) -> tuple[Callable, Callable]:
|
||||
assert isinstance(
|
||||
weight_and_reduce_impl, TopKWeightAndReduceDelegate
|
||||
), ("Weight application and reduction happens in the combine kernel.")
|
||||
@ -233,7 +233,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
return_recv_hook=do_recv_hook,
|
||||
out=output)
|
||||
|
||||
return recv_hook
|
||||
return recv_hook, lambda: None
|
||||
|
||||
def finalize_async(
|
||||
self,
|
||||
@ -243,8 +243,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> Callable:
|
||||
recv_hook = self._finalize(
|
||||
) -> tuple[Callable, Callable]:
|
||||
return self._finalize(
|
||||
output,
|
||||
fused_expert_output,
|
||||
topk_weights,
|
||||
@ -253,8 +253,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
weight_and_reduce_impl,
|
||||
do_async=True,
|
||||
)
|
||||
assert recv_hook is not None
|
||||
return recv_hook
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
|
||||
@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
|
||||
_resize_cache, count_expert_num_tokens)
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook,
|
||||
from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
|
||||
dbo_maybe_run_recv_hook,
|
||||
dbo_register_recv_hook, dbo_yield)
|
||||
|
||||
#
|
||||
@ -223,7 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[Callable, ReceiverType]:
|
||||
) -> Union[tuple[Callable, ReceiverType], ReceiverType]:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed for this kernel
|
||||
but do not wait for results from other workers.
|
||||
@ -239,10 +240,21 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
- apply_router_weight_on_input: When True, apply the weights to the
|
||||
activations, before quantization + dispatching.
|
||||
|
||||
Returns a callback that when invoked waits for results from other
|
||||
workers and has the same return signature as `prepare`, e.g.
|
||||
Returns a callback or a hook callback pair that when invoked waits for
|
||||
results from other workers and has the same return signature as
|
||||
`prepare`, if a hook is returned this is more lightweight check that
|
||||
the recv is complete without doing extra work (used by DBO, will be
|
||||
refactored in the very near future)
|
||||
|
||||
e.g.
|
||||
|
||||
receiver = obj.prepare_async(...)
|
||||
ret = obj.prepare_async(...)
|
||||
|
||||
if isinstance(ret, tuple):
|
||||
hook, receiver = ret
|
||||
hook()
|
||||
|
||||
if hook is not None:
|
||||
a, a_scales, expert_meta, topk_ids, topk_weights = receiver()
|
||||
|
||||
is equivalent to:
|
||||
@ -284,7 +296,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: TopKWeightAndReduce,
|
||||
) -> Callable:
|
||||
) -> Union[tuple[Callable, Callable], Callable]:
|
||||
"""
|
||||
Perform any combine plus apply weights and perform a reduction on the
|
||||
fused experts output but do not wait for results from other workers.
|
||||
@ -298,11 +310,17 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
- weight_and_reduce_impl: An optional TopKWeightAndReduce
|
||||
implementation.
|
||||
|
||||
Returns a callback that when invoked waits for results from other
|
||||
workers and has the same return signature as `finalize`, e.g.
|
||||
Returns a callback or a hook callback pair that when invoked waits for
|
||||
results from other workers and has the same return signature as
|
||||
`finalize`, if a hook is returned this is more lightweight check that
|
||||
the recv is complete without doing extra work (used by DBO, will be
|
||||
refactored in the very near future)
|
||||
|
||||
receiver = obj.finalize_async(output, ...)
|
||||
ret = obj.finalize_async(output, ...)
|
||||
... output not valid yet ...
|
||||
if isinstance(ret, tuple):
|
||||
hook, receiver = ret
|
||||
hook()
|
||||
receiver()
|
||||
... output valid here ...
|
||||
|
||||
@ -600,9 +618,23 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
layer due to any layer specific state that may be used by the component
|
||||
objects.
|
||||
"""
|
||||
fused_out_buffer = SharedResizableBuffer()
|
||||
workspace13_buffer = SharedResizableBuffer()
|
||||
workspace2_buffer = SharedResizableBuffer()
|
||||
|
||||
class SharedBuffers:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.fused_out = SharedResizableBuffer()
|
||||
self.workspace13 = SharedResizableBuffer()
|
||||
self.workspace2 = SharedResizableBuffer()
|
||||
|
||||
# Persistent buffers that are shared across `FusedMoEModularKernel`
|
||||
# instances (layers), to save memory and allocattions.
|
||||
#
|
||||
# We have two sets of buffers to support dual batch overlap (DBO) where each
|
||||
# microbatch (ubatch) should use its own set of buffers to avoid
|
||||
# cross-ubatch contimination.
|
||||
# NOTE that memory is lazily allocated for these buffers, meaning that if
|
||||
# DBO isn't being used, the second SharedBuffers will be empty.
|
||||
shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -647,14 +679,18 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
|
||||
expert_tokens_meta)
|
||||
|
||||
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
|
||||
ubatch_idx = dbo_current_ubatch_id()
|
||||
buffers = self.shared_buffers[ubatch_idx]
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = self.workspace13_buffer.get(workspace13_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = self.workspace2_buffer.get(workspace2_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace13 = buffers.workspace13.get(workspace13_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = buffers.workspace2.get(workspace2_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
|
||||
assert fused_out is None or fused_out.shape == fused_out_shape, (
|
||||
f"fused_out {fused_out.shape} but expected {fused_out_shape}")
|
||||
@ -733,9 +769,11 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
|
||||
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
|
||||
expert_tokens_meta)
|
||||
fused_out = self.fused_out_buffer.get(fused_out_shape,
|
||||
device=a1q.device,
|
||||
dtype=a1.dtype)
|
||||
ubatch_idx = dbo_current_ubatch_id()
|
||||
buffers = self.shared_buffers[ubatch_idx]
|
||||
fused_out = buffers.fused_out.get(fused_out_shape,
|
||||
device=a1q.device,
|
||||
dtype=a1.dtype)
|
||||
|
||||
def slice_input_tensors(
|
||||
chunk_idx: int
|
||||
@ -868,6 +906,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
if not self.prepare_finalize.supports_async():
|
||||
# We shouldn't be running an a2a kernel that doesn't
|
||||
# support async prepare/finalize
|
||||
# TODO(lucas): enable in follow-up
|
||||
assert not dbo_enabled()
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
@ -883,7 +922,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
else:
|
||||
# Overlap shared expert compute with all2all dispatch.
|
||||
dbo_maybe_run_recv_hook()
|
||||
hook, receiver = self.prepare_finalize.prepare_async(
|
||||
prepare_ret = self.prepare_finalize.prepare_async(
|
||||
a1,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
@ -893,13 +932,21 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
|
||||
# If DBO is being used, register the hook with the ubatch context
|
||||
# and call it in dbo_maybe_run_recv_hook instead of passing it to
|
||||
# the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
if not dbo_enabled():
|
||||
hook()
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
# currently unpack if we have hook + receiver pair or just
|
||||
# receiver (see finalize_async docstring)
|
||||
hook, receiver = prepare_ret \
|
||||
if isinstance(prepare_ret, tuple) else (None, prepare_ret)
|
||||
|
||||
if hook is not None:
|
||||
if dbo_enabled():
|
||||
# If DBO is being used, register the hook with the ubatch
|
||||
# context and call it in dbo_maybe_run_recv_hook instead of
|
||||
# passing it to the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
else:
|
||||
hook()
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = receiver()
|
||||
@ -952,7 +999,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
else:
|
||||
recv_hook = self.prepare_finalize.finalize_async(
|
||||
finalize_ret = self.prepare_finalize.finalize_async(
|
||||
output,
|
||||
fused_out,
|
||||
topk_weights,
|
||||
@ -964,11 +1011,23 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
assert recv_hook is not None
|
||||
dbo_register_recv_hook(recv_hook)
|
||||
dbo_yield()
|
||||
if not dbo_enabled():
|
||||
recv_hook()
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
# currently unpack if we have hook + receiver pair or just
|
||||
# receiver (see finalize_async docstring)
|
||||
hook, receiver = finalize_ret \
|
||||
if isinstance(finalize_ret, tuple) else (None, finalize_ret)
|
||||
|
||||
if hook is not None:
|
||||
if dbo_enabled():
|
||||
# If DBO is being used, register the hook with the ubatch
|
||||
# context and call it in dbo_maybe_run_recv_hook instead of
|
||||
# passing it to the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
else:
|
||||
hook()
|
||||
|
||||
receiver()
|
||||
|
||||
if self.shared_experts is None:
|
||||
return output
|
||||
|
||||
@ -644,6 +644,14 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
# If no matches, return None
|
||||
return None
|
||||
|
||||
def has_blocked_weights(self) -> bool:
|
||||
for scheme in self.target_scheme_map.values():
|
||||
weight_quant = scheme.get("weights")
|
||||
if (weight_quant is not None
|
||||
and weight_quant.strategy == QuantizationStrategy.BLOCK):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def supports_cutlass_24(
|
||||
weight_quant: Optional[QuantizationArgs],
|
||||
|
||||
@ -13,6 +13,7 @@ from compressed_tensors.quantization import (ActivationOrdering,
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
|
||||
@ -31,6 +32,9 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
|
||||
select_nvfp4_gemm_impl)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
expert_weight_is_col_major, get_col_major_tma_aligned_tensor,
|
||||
requant_weight_ue8m0_inplace)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_moe_marlin_supports_layer, marlin_make_workspace_new,
|
||||
marlin_moe_permute_scales)
|
||||
@ -45,6 +49,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -505,10 +510,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
|
||||
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
|
||||
if not (per_tensor or per_channel):
|
||||
raise ValueError(
|
||||
"For FP8 Fused MoE layers, we require per tensor "
|
||||
"or channelwise, dynamic per token quantization. Found "
|
||||
f"{self.weight_quant}, {self.input_quant}")
|
||||
assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
|
||||
self.weight_block_size = self.weight_quant.block_structure
|
||||
assert self.weight_quant.dynamic is not None
|
||||
else:
|
||||
self.weight_block_size = None
|
||||
self.block_quant = self.weight_block_size is not None
|
||||
|
||||
self.static_input_scales = not self.input_quant.dynamic
|
||||
if self.static_input_scales and per_channel:
|
||||
@ -519,7 +526,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = (not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN
|
||||
and not self.block_quant)
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
@ -531,8 +539,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
# cutlass path
|
||||
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
|
||||
self.weight_quant, self.input_quant)
|
||||
self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
|
||||
self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
|
||||
self.use_cutlass = not self.block_quant and (
|
||||
quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
|
||||
or self.is_fp8_w8a8_sm100)
|
||||
self.disable_expert_map = False
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
@ -547,6 +556,31 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
layer.weight_block_size = self.weight_block_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
block_n, block_k = (
|
||||
self.weight_block_size[0],
|
||||
self.weight_block_size[1],
|
||||
)
|
||||
# NOTE: To ensure proper alignment of the block-wise quantization
|
||||
# scales, the output_size of the weights for both the gate and up
|
||||
# layers must be divisible by block_n.
|
||||
# Required by column parallel or enabling merged weights
|
||||
if intermediate_size_per_partition % block_n != 0:
|
||||
raise ValueError(
|
||||
f"The output_size of gate's and up's weight = "
|
||||
f"{intermediate_size_per_partition} is not divisible by "
|
||||
f"weight quantization block_n = {block_n}.")
|
||||
if (tp_size > 1
|
||||
and intermediate_size_per_partition % block_k != 0):
|
||||
# Required by row parallel
|
||||
raise ValueError(
|
||||
f"The input_size of down's weight = "
|
||||
f"{intermediate_size_per_partition} is not divisible by "
|
||||
f"weight quantization block_k = {block_k}.")
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
@ -602,6 +636,27 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
|
||||
w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts,
|
||||
2 *
|
||||
((intermediate_size_per_partition + block_n - 1) // block_n),
|
||||
(hidden_size + block_k - 1) // block_k,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, (hidden_size + block_n - 1) // block_n,
|
||||
(intermediate_size_per_partition + block_k - 1) // block_k,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
# INPUT_SCALES
|
||||
if self.static_input_scales:
|
||||
w13_input_scale = torch.nn.Parameter(torch.ones(
|
||||
@ -706,6 +761,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
del layer.w2_input_scale
|
||||
|
||||
if self.use_cutlass:
|
||||
assert self.weight_quant.strategy != QuantizationStrategy.BLOCK
|
||||
device = layer.w13_weight.device
|
||||
# ab_strides1 and c_strides2 are the same
|
||||
self.ab_strides1_c_strides2 = torch.full(
|
||||
@ -724,6 +780,29 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
|
||||
if is_deep_gemm_e8m0_used() and self.block_quant:
|
||||
assert layer.weight_block_size is not None
|
||||
# Re-quantise the expert weights so their scales are UE8M0.
|
||||
block_sz = tuple(layer.weight_block_size)
|
||||
requant_weight_ue8m0_inplace(
|
||||
layer.w13_weight.data,
|
||||
layer.w13_weight_scale.data,
|
||||
block_sz,
|
||||
)
|
||||
requant_weight_ue8m0_inplace(
|
||||
layer.w2_weight.data,
|
||||
layer.w2_weight_scale.data,
|
||||
block_sz,
|
||||
)
|
||||
|
||||
# Ensure column-major TMA alignment expected by DeepGEMM.
|
||||
if expert_weight_is_col_major(layer.w13_weight_scale):
|
||||
layer.w13_weight_scale = get_col_major_tma_aligned_tensor(
|
||||
layer.w13_weight_scale)
|
||||
if expert_weight_is_col_major(layer.w2_weight_scale):
|
||||
layer.w2_weight_scale = get_col_major_tma_aligned_tensor(
|
||||
layer.w2_weight_scale)
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if self.use_marlin or self.rocm_aiter_moe_enabled:
|
||||
@ -777,9 +856,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
return experts
|
||||
|
||||
# triton path
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
|
||||
assert not self.rocm_aiter_moe_enabled and not self.use_marlin
|
||||
|
||||
@ -790,14 +870,16 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
assert max_num_tokens_per_rank is not None
|
||||
|
||||
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
|
||||
return BatchedTritonExperts(
|
||||
return BatchedTritonOrDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
logger.debug("TritonExperts(%s)", self.__class__.__name__)
|
||||
return TritonExperts(self.moe_quant_config)
|
||||
logger.debug("TritonOrDeepGemmExperts(%s)",
|
||||
self.__class__.__name__)
|
||||
return TritonOrDeepGemmExperts(self.moe_quant_config,
|
||||
allow_deep_gemm=True)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
@ -816,6 +898,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
a2_scale=layer.w2_input_scale,
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_channel_quant,
|
||||
block_shape=layer.weight_block_size,
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
||||
@ -11,7 +11,7 @@ from torch.nn import Parameter
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
apply_fp8_block_linear, check_aiter_fp8_linear_support,
|
||||
W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
|
||||
create_fp8_input_scale, create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter, maybe_post_process_fp8_weight_block,
|
||||
process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy,
|
||||
@ -41,16 +41,30 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
self.strategy = weight_quant.strategy
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR \
|
||||
if is_static_input_scheme else GroupShape.PER_TOKEN
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.is_static_input_scheme,
|
||||
act_quant_group_shape=self.act_q_group_shape)
|
||||
|
||||
self.weight_block_size = self.weight_quant.block_structure
|
||||
if self.weight_block_size is not None:
|
||||
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
|
||||
else:
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR \
|
||||
if is_static_input_scheme else GroupShape.PER_TOKEN
|
||||
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
|
||||
|
||||
if self.weight_block_size is not None:
|
||||
assert not self.is_static_input_scheme
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
else:
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.is_static_input_scheme,
|
||||
act_quant_group_shape=self.act_q_group_shape)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
@ -141,13 +155,14 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if layer.weight_block_size is not None:
|
||||
return apply_fp8_block_linear(
|
||||
layer,
|
||||
if self.weight_block_size is not None:
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported)
|
||||
)
|
||||
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
|
||||
@ -43,7 +43,7 @@ def prepare_block_fp8_matmul_inputs(
|
||||
return M, N, K, C
|
||||
|
||||
|
||||
def w8a8_block_fp8_matmul_deepgemm(
|
||||
def w8a8_deepgemm_block_scaled_mm(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
@ -59,7 +59,7 @@ def w8a8_block_fp8_matmul_deepgemm(
|
||||
return C
|
||||
|
||||
|
||||
def w8a8_block_fp8_matmul_deepgemm_fake(
|
||||
def w8a8_deepgemm_block_scaled_mm_fake(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
@ -73,9 +73,9 @@ def w8a8_block_fp8_matmul_deepgemm_fake(
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="w8a8_block_fp8_matmul_deepgemm",
|
||||
op_func=w8a8_block_fp8_matmul_deepgemm,
|
||||
op_name="w8a8_deepgemm_block_scaled_mm",
|
||||
op_func=w8a8_deepgemm_block_scaled_mm,
|
||||
mutates_args=[],
|
||||
fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
|
||||
fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
@ -31,12 +31,12 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
|
||||
select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
apply_fp8_block_linear, check_aiter_fp8_linear_support,
|
||||
W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
|
||||
create_fp8_input_scale, create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter, get_col_major_tma_aligned_tensor,
|
||||
maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
|
||||
process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
|
||||
validate_fp8_block_shape)
|
||||
create_fp8_weight_parameter, expert_weight_is_col_major,
|
||||
get_col_major_tma_aligned_tensor, maybe_post_process_fp8_weight_block,
|
||||
process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy,
|
||||
requant_weight_ue8m0_inplace, validate_fp8_block_shape)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
|
||||
prepare_moe_fp8_layer_for_marlin)
|
||||
@ -64,12 +64,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _is_col_major(x: torch.Tensor) -> bool:
|
||||
assert x.dim() == 3
|
||||
b, m, n = x.shape
|
||||
return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
|
||||
|
||||
|
||||
class Fp8Config(QuantizationConfig):
|
||||
"""Config class for FP8."""
|
||||
|
||||
@ -240,15 +234,28 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
self.weight_block_size = self.quant_config.weight_block_size
|
||||
self.block_quant = self.weight_block_size is not None
|
||||
self.act_q_static = self.quant_config.activation_scheme == "static"
|
||||
# Use per-token quantization for better perf if dynamic and cutlass
|
||||
if not self.act_q_static and cutlass_fp8_supported():
|
||||
self.act_q_group_shape = GroupShape.PER_TOKEN
|
||||
if self.weight_block_size:
|
||||
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
|
||||
else:
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR
|
||||
# Use per-token quantization for better perf if dynamic and cutlass
|
||||
if not self.act_q_static and cutlass_fp8_supported():
|
||||
self.act_q_group_shape = GroupShape.PER_TOKEN
|
||||
else:
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.act_q_static,
|
||||
act_quant_group_shape=self.act_q_group_shape)
|
||||
if self.block_quant:
|
||||
assert not self.act_q_static
|
||||
assert self.weight_block_size is not None
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
else:
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.act_q_static,
|
||||
act_quant_group_shape=self.act_q_group_shape)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@ -397,12 +404,15 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
bias=bias)
|
||||
|
||||
if self.block_quant:
|
||||
return apply_fp8_block_linear(
|
||||
layer,
|
||||
assert self.weight_block_size is not None
|
||||
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported)
|
||||
)
|
||||
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
@ -660,10 +670,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
# DeepGemm scales need to be transposed and aligned. We try to do
|
||||
# it ahead of time for performance reasons.
|
||||
if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
|
||||
if _is_col_major(layer.w13_weight_scale_inv):
|
||||
if expert_weight_is_col_major(layer.w13_weight_scale_inv):
|
||||
layer.w13_weight_scale_inv = \
|
||||
get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
|
||||
if _is_col_major(layer.w2_weight_scale_inv):
|
||||
if expert_weight_is_col_major(layer.w2_weight_scale_inv):
|
||||
layer.w2_weight_scale_inv = \
|
||||
get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)
|
||||
|
||||
@ -811,10 +821,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
# Ensure column-major TMA alignment expected by DeepGEMM.
|
||||
if _is_col_major(layer.w13_weight_scale_inv):
|
||||
if expert_weight_is_col_major(layer.w13_weight_scale_inv):
|
||||
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
|
||||
layer.w13_weight_scale_inv)
|
||||
if _is_col_major(layer.w2_weight_scale_inv):
|
||||
if expert_weight_is_col_major(layer.w2_weight_scale_inv):
|
||||
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
|
||||
layer.w2_weight_scale_inv)
|
||||
|
||||
|
||||
@ -27,11 +27,14 @@ class QuantFP8(CustomOp):
|
||||
This CustomOp supports both static and dynamic quantization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
static: bool,
|
||||
group_shape: GroupShape,
|
||||
num_token_padding: Optional[int] = None,
|
||||
column_major_scales: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
static: bool,
|
||||
group_shape: GroupShape,
|
||||
num_token_padding: Optional[int] = None,
|
||||
column_major_scales: bool = False,
|
||||
use_ue8m0: Optional[bool] = None, # for Torch compile
|
||||
):
|
||||
"""
|
||||
:param static: static or dynamic quantization
|
||||
:param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
|
||||
@ -46,6 +49,7 @@ class QuantFP8(CustomOp):
|
||||
self.group_shape = group_shape
|
||||
self.num_token_padding = num_token_padding
|
||||
self.column_major_scales = column_major_scales
|
||||
self.use_ue8m0 = use_ue8m0
|
||||
|
||||
self.is_group_quant = group_shape.is_per_group()
|
||||
if self.is_group_quant:
|
||||
@ -70,7 +74,8 @@ class QuantFP8(CustomOp):
|
||||
x,
|
||||
group_size=self.group_size,
|
||||
column_major_scales=self.column_major_scales,
|
||||
dtype=_FP8_DTYPE)
|
||||
dtype=_FP8_DTYPE,
|
||||
use_ue8m0=self.use_ue8m0)
|
||||
|
||||
assert (scale is not None) == self.static
|
||||
assert scale_ub is None or (not self.static and self.group_shape
|
||||
@ -137,7 +142,10 @@ class QuantFP8(CustomOp):
|
||||
|
||||
x_grouped = x.view(-1, num_groups, self.group_size)
|
||||
absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
|
||||
scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
|
||||
scales_raw = absmax / _FP8_MAX
|
||||
if self.use_ue8m0:
|
||||
scales_raw = torch.exp2(torch.ceil(torch.log2(scales_raw)))
|
||||
scales = (scales_raw).clamp(min=_FP8_MIN_SCALING_FACTOR)
|
||||
|
||||
x_scaled = x_grouped / scales
|
||||
x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
|
||||
@ -151,6 +159,6 @@ class QuantFP8(CustomOp):
|
||||
scales = scales.reshape(orig_shape[:-1] + (num_groups, ))
|
||||
|
||||
if self.column_major_scales:
|
||||
scales = scales.transpose(-2, -1).contiguous()
|
||||
scales = scales.transpose(-2, -1).contiguous().transpose(-1, -2)
|
||||
|
||||
return x_quant, scales
|
||||
|
||||
@ -13,8 +13,9 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
group_broadcast)
|
||||
GroupShape, group_broadcast)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED)
|
||||
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
||||
@ -24,6 +25,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
should_use_deepgemm_for_fp8_linear)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -35,6 +37,8 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
|
||||
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
|
||||
|
||||
|
||||
# We need to pass in the is_hopper flag as argument because the function
|
||||
# current_platform.is_device_capability() is not supported by Torch compiler.
|
||||
def cutlass_scaled_mm(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
@ -42,15 +46,17 @@ def cutlass_scaled_mm(
|
||||
Bs: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
is_hopper: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
if is_hopper is None:
|
||||
is_hopper = current_platform.is_device_capability(90)
|
||||
return ops.cutlass_scaled_mm(
|
||||
A,
|
||||
B.T,
|
||||
out_dtype=output_dtype,
|
||||
scale_a=As,
|
||||
# SM90 block FP8 requires row-major scale_b, which we do ahead of time
|
||||
scale_b=Bs if block_size is not None
|
||||
and current_platform.is_device_capability(90) else Bs.T)
|
||||
scale_b=Bs if block_size is not None and is_hopper else Bs.T)
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_blockscale_impl(
|
||||
@ -98,122 +104,189 @@ if current_platform.is_rocm():
|
||||
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
|
||||
|
||||
|
||||
def dispatch_w8a8_blockscale_func(
|
||||
use_cutlass: bool, use_aiter_and_is_supported: bool
|
||||
) -> Callable[[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
list[int],
|
||||
torch.dtype,
|
||||
], torch.Tensor]:
|
||||
if use_cutlass:
|
||||
return cutlass_scaled_mm
|
||||
if (use_aiter_and_is_supported):
|
||||
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
|
||||
return w8a8_block_fp8_matmul
|
||||
# TODO we should be able to change the type of block_size to GroupShape
|
||||
# after we resolve GroupShape compilation issue
|
||||
# https://github.com/vllm-project/vllm/issues/25270
|
||||
def _w8a8_triton_block_scaled_mm_func(
|
||||
qx: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
return w8a8_triton_block_scaled_mm(qx, weight, x_scale, weight_scale,
|
||||
block_size, output_dtype)
|
||||
|
||||
|
||||
def _w8a8_triton_block_scaled_mm_fake(
|
||||
qx: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty((qx.size(0), weight.size(0)),
|
||||
dtype=output_dtype,
|
||||
device=qx.device)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
"w8a8_triton_block_scaled_mm_func",
|
||||
_w8a8_triton_block_scaled_mm_func,
|
||||
mutates_args=[],
|
||||
fake_impl=_w8a8_triton_block_scaled_mm_fake,
|
||||
dispatch_key="CUDA",
|
||||
)
|
||||
|
||||
|
||||
# TODO fix ROCm->Triton custom path:
|
||||
# https://github.com/vllm-project/vllm/issues/14397
|
||||
def apply_w8a8_block_fp8_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
block_size: list[int],
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
) -> torch.Tensor:
|
||||
assert input_scale is None
|
||||
# View input as 2D matrix for fp8 methods
|
||||
input_2d = input.view(-1, input.shape[-1])
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
output_dtype = input.dtype
|
||||
class W8A8BlockFp8LinearOp:
|
||||
"""
|
||||
This class executes a Blocked FP8 linear layer using cutlass if supported
|
||||
and torch.scaled_mm otherwise.
|
||||
"""
|
||||
|
||||
if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
|
||||
def __init__(
|
||||
self,
|
||||
weight_group_shape: GroupShape,
|
||||
act_quant_group_shape: GroupShape,
|
||||
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
):
|
||||
self.weight_group_shape = weight_group_shape
|
||||
self.act_quant_group_shape = act_quant_group_shape
|
||||
self.is_deep_gemm_supported = is_deep_gemm_supported()
|
||||
self.is_hopper = current_platform.is_device_capability(90)
|
||||
|
||||
# Get the correct blockscale mul and input quant operations.
|
||||
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
|
||||
# to use deepgemm because we don't know the shape of weights (and
|
||||
# whether deepgemm supports it) at the init time.
|
||||
self.w8a8_blockscale_op, self.input_quant_op = \
|
||||
self._dispatch_w8a8_blockscale_op(
|
||||
cutlass_block_fp8_supported, use_aiter_and_is_supported)
|
||||
self.deepgemm_input_quant_op = (QuantFP8(
|
||||
False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=True,
|
||||
use_ue8m0=is_deep_gemm_e8m0_used()) if self.is_deep_gemm_supported
|
||||
else None)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert input_scale is None
|
||||
# View input as 2D matrix for fp8 methods
|
||||
input_2d = input.view(-1, input.shape[-1])
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
output_dtype = input.dtype
|
||||
|
||||
q_input, x_scale = per_token_group_quant_fp8(
|
||||
input_2d,
|
||||
block_size[1],
|
||||
column_major_scales=True,
|
||||
)
|
||||
if should_use_deepgemm_for_fp8_linear(output_dtype, weight,
|
||||
self.is_deep_gemm_supported):
|
||||
output = self._run_deepgemm(input, weight, weight_scale)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(dtype=input.dtype).view(*output_shape)
|
||||
|
||||
output = self.w8a8_blockscale_op(input_2d, weight, weight_scale)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(dtype=input.dtype).view(*output_shape)
|
||||
|
||||
def _run_deepgemm(
|
||||
self,
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# ensure DeepGEMM-backed custom op is registered before use
|
||||
import vllm.model_executor.layers.quantization.deepgemm # noqa: F401
|
||||
|
||||
output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
|
||||
assert self.deepgemm_input_quant_op is not None
|
||||
q_input, x_scale = self.deepgemm_input_quant_op(input_2d)
|
||||
return torch.ops.vllm.w8a8_deepgemm_block_scaled_mm(
|
||||
q_input,
|
||||
weight,
|
||||
x_scale,
|
||||
weight_scale,
|
||||
block_size,
|
||||
output_dtype=output_dtype)
|
||||
if bias is not None:
|
||||
output += bias
|
||||
return output.to(dtype=output_dtype).view(*output_shape)
|
||||
self.weight_group_shape,
|
||||
output_dtype=input_2d.dtype)
|
||||
|
||||
w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
|
||||
cutlass_block_fp8_supported, use_aiter_and_is_supported)
|
||||
if cutlass_block_fp8_supported:
|
||||
num_pad = 0
|
||||
if current_platform.is_device_capability(90):
|
||||
# pad first dimension to be divisible by 4 due to
|
||||
# cutlass blockwise gemm limitation for hopper
|
||||
num_pad = 4 - (input_2d.shape[0] % 4)
|
||||
if num_pad > 0:
|
||||
input_2d = torch.nn.functional.pad(input_2d,
|
||||
(0, 0, 0, num_pad),
|
||||
"constant", 0)
|
||||
q_input, x_scale = per_token_group_quant_fp8(input_2d,
|
||||
block_size[1],
|
||||
column_major_scales=True)
|
||||
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
|
||||
block_size, input.dtype)
|
||||
if num_pad > 0:
|
||||
output = output[:-num_pad]
|
||||
else:
|
||||
if use_aiter_and_is_supported:
|
||||
q_input, x_scale = aiter_per1x128_quant(
|
||||
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
|
||||
def _run_cutlass(
|
||||
self,
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert self.input_quant_op is not None
|
||||
if self.is_hopper:
|
||||
# We pad unconditionally (even if shape is already divisible by 4)
|
||||
# to support dynamic shape for input_2d.shape[0] in torch.compile
|
||||
x = torch.nn.functional.pad(input_2d,
|
||||
(0, 0, 0, -input_2d.shape[0] % 4))
|
||||
else:
|
||||
q_input, x_scale = per_token_group_quant_fp8(
|
||||
input_2d, block_size[1], column_major_scales=False)
|
||||
x = input_2d
|
||||
|
||||
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
|
||||
block_size, input.dtype)
|
||||
q_input, x_scale = self.input_quant_op(x)
|
||||
output = cutlass_scaled_mm(q_input, weight, x_scale, weight_scale,
|
||||
list(self.weight_group_shape),
|
||||
input_2d.dtype, self.is_hopper)
|
||||
output = output[0:input_2d.shape[0], ...]
|
||||
return output
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(dtype=input.dtype).view(*output_shape)
|
||||
def _run_aiter(
|
||||
self,
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert self.act_quant_group_shape == GroupShape(1, 128)
|
||||
q_input, x_scale = aiter_per1x128_quant(
|
||||
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
|
||||
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
|
||||
q_input, weight, x_scale, weight_scale, self.weight_group_shape,
|
||||
input_2d.dtype)
|
||||
|
||||
def _run_triton(
|
||||
self,
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert self.input_quant_op is not None
|
||||
q_input, x_scale = self.input_quant_op(input_2d)
|
||||
return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
|
||||
q_input, weight, x_scale, weight_scale, self.weight_group_shape,
|
||||
input_2d.dtype)
|
||||
|
||||
def apply_w8a8_block_fp8_linear_fake(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
block_size: list[int],
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
) -> torch.Tensor:
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
return torch.empty(output_shape, dtype=input.dtype, device=input.device)
|
||||
|
||||
|
||||
if not current_platform.is_cpu():
|
||||
direct_register_custom_op(
|
||||
op_name="apply_w8a8_block_fp8_linear",
|
||||
op_func=apply_w8a8_block_fp8_linear,
|
||||
mutates_args=[],
|
||||
fake_impl=apply_w8a8_block_fp8_linear_fake,
|
||||
)
|
||||
def _dispatch_w8a8_blockscale_op(
|
||||
self,
|
||||
use_cutlass: bool,
|
||||
use_aiter_and_is_supported: bool,
|
||||
) -> tuple[Callable[[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
], torch.Tensor], Optional[QuantFP8]]:
|
||||
if use_cutlass:
|
||||
return self._run_cutlass, (QuantFP8(False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=True,
|
||||
use_ue8m0=False))
|
||||
if use_aiter_and_is_supported:
|
||||
return self._run_aiter, None
|
||||
return self._run_triton, (QuantFP8(False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=False))
|
||||
|
||||
|
||||
def input_to_float8(
|
||||
@ -465,7 +538,7 @@ def per_token_group_quant_fp8(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _w8a8_block_fp8_matmul(
|
||||
def _w8a8_triton_block_scaled_mm(
|
||||
# Pointers to inputs and output
|
||||
A,
|
||||
B,
|
||||
@ -590,7 +663,7 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
|
||||
return None
|
||||
|
||||
|
||||
def w8a8_block_fp8_matmul(
|
||||
def w8a8_triton_block_scaled_mm(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
@ -650,7 +723,7 @@ def w8a8_block_fp8_matmul(
|
||||
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
|
||||
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
|
||||
|
||||
_w8a8_block_fp8_matmul[grid](
|
||||
_w8a8_triton_block_scaled_mm[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
@ -997,20 +1070,7 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
|
||||
layer.weight_scale.data.T.contiguous(), requires_grad=False)
|
||||
|
||||
|
||||
def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
cutlass_block_fp8_supported: bool,
|
||||
use_aiter_and_is_supported: bool) -> torch.Tensor:
|
||||
"""Apply block-wise FP8 linear operation."""
|
||||
assert layer.weight_block_size is not None
|
||||
|
||||
return torch.ops.vllm.apply_w8a8_block_fp8_linear(
|
||||
input=input,
|
||||
weight=layer.weight,
|
||||
block_size=layer.weight_block_size,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=use_aiter_and_is_supported,
|
||||
)
|
||||
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
|
||||
assert x.dim() == 3
|
||||
b, m, n = x.shape
|
||||
return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
|
||||
|
||||
@ -178,10 +178,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
bias: torch.Tensor) -> torch.Tensor:
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
|
||||
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0 and bias is None:
|
||||
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx() and \
|
||||
qinput.shape[0] == 1 and \
|
||||
qinput.shape[1] % 16 == 0 and \
|
||||
((bias is None) or (bias.dtype == out_dtype)) :
|
||||
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
|
||||
current_platform.get_cu_count())
|
||||
current_platform.get_cu_count(), bias)
|
||||
else:
|
||||
output = torch._scaled_mm(qinput,
|
||||
weight,
|
||||
|
||||
@ -100,7 +100,7 @@ def rocm_unquantized_gemm_impl(
|
||||
k = weight.shape[1]
|
||||
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
|
||||
x.dtype in [torch.float16, torch.bfloat16] \
|
||||
and k % 8 == 0 and bias is None)
|
||||
and k % 8 == 0)
|
||||
|
||||
if use_skinny is not True:
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
@ -111,9 +111,9 @@ def rocm_unquantized_gemm_impl(
|
||||
cu_count = current_platform.get_cu_count()
|
||||
|
||||
if m > 8 and 0 < n <= 4:
|
||||
out = ops.wvSplitK(weight, x_view, cu_count)
|
||||
out = ops.wvSplitK(weight, x_view, cu_count, bias)
|
||||
return out.view(*x.shape[:-1], weight.shape[0])
|
||||
elif m % 4 == 0 and n == 1 and k <= 8192:
|
||||
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
|
||||
out = ops.LLMM1(weight, x_view, 4)
|
||||
return out.view(*x.shape[:-1], weight.shape[0])
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
@ -266,24 +266,24 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
|
||||
if structured_outputs_config.reasoning_parser == "":
|
||||
structured_outputs_config.reasoning_parser = "openai_gptoss"
|
||||
|
||||
# Increase the max capture size from 512 to 1024 for performance.
|
||||
# Increase the max capture size from 512 to 992 for performance.
|
||||
# NOTE(woosuk): This will increase the number of CUDA graphs
|
||||
# from 67 to 83.
|
||||
# from 67 to 81.
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
if len(scheduler_config.cuda_graph_sizes) == 1:
|
||||
max_capture_size = scheduler_config.cuda_graph_sizes[0]
|
||||
# FIXME(woosuk): When using full cuda graph with FA3, the max
|
||||
# supported size is 992.
|
||||
if max_capture_size < 1024:
|
||||
if max_capture_size < 992:
|
||||
cuda_graph_sizes = [1, 2, 4]
|
||||
# Step size 8 for small batch sizes
|
||||
cuda_graph_sizes += [i for i in range(8, 256, 8)]
|
||||
# Step size 16 for larger batch sizes
|
||||
cuda_graph_sizes += [i for i in range(256, 1025, 16)]
|
||||
cuda_graph_sizes += [i for i in range(256, 993, 16)]
|
||||
scheduler_config.cuda_graph_sizes = cuda_graph_sizes
|
||||
logger.info(
|
||||
"Overriding max cuda graph capture size to "
|
||||
"%d for performance.", 1024)
|
||||
"%d for performance.", 992)
|
||||
|
||||
|
||||
class MambaModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@ -134,6 +134,11 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
|
||||
nn.Module.__init__(self)
|
||||
self.config = vllm_config. \
|
||||
speculative_config.draft_model_config.hf_config
|
||||
# Ensure draft_vocab_size is set
|
||||
# default to the base vocab size when absent
|
||||
if getattr(self.config, "draft_vocab_size", None) is None:
|
||||
base_vocab_size = getattr(self.config, "vocab_size", None)
|
||||
self.config.draft_vocab_size = base_vocab_size
|
||||
target_layer_num = vllm_config.model_config.get_num_layers(
|
||||
vllm_config.parallel_config)
|
||||
self.model = LlamaModel(vllm_config=vllm_config,
|
||||
|
||||
@ -203,6 +203,11 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||
nn.Module.__init__(self)
|
||||
self.config = vllm_config. \
|
||||
speculative_config.draft_model_config.hf_config
|
||||
# Ensure draft_vocab_size is set
|
||||
# default to the base vocab size when absent
|
||||
if getattr(self.config, "draft_vocab_size", None) is None:
|
||||
base_vocab_size = getattr(self.config, "vocab_size", None)
|
||||
self.config.draft_vocab_size = base_vocab_size
|
||||
target_layer_num = vllm_config.model_config.get_num_layers(
|
||||
vllm_config.parallel_config)
|
||||
|
||||
|
||||
@ -53,9 +53,9 @@ def _extract_data_from_fused_moe_module(
|
||||
"""
|
||||
assert isinstance(m, FusedMoE)
|
||||
w13 = m.w13_weight
|
||||
w13_s = m.w13_weight_scale_inv
|
||||
w13_s = getattr(m, "w13_weight_scale_inv", m.w13_weight_scale)
|
||||
w2 = m.w2_weight
|
||||
w2_s = m.w2_weight_scale_inv
|
||||
w2_s = getattr(m, "w2_weight_scale_inv", m.w2_weight_scale)
|
||||
num_topk = m.top_k
|
||||
|
||||
assert isinstance(w13, torch.Tensor)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Sampling parameters for text generation."""
|
||||
import copy
|
||||
import warnings
|
||||
from dataclasses import field
|
||||
from enum import Enum, IntEnum
|
||||
from functools import cached_property
|
||||
@ -59,6 +60,19 @@ class StructuredOutputsParams:
|
||||
f"but multiple are specified: {self.__dict__}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class GuidedDecodingParams(StructuredOutputsParams):
|
||||
|
||||
def __post_init__(self):
|
||||
warnings.warn(
|
||||
"GuidedDecodingParams is deprecated. This will be removed in "
|
||||
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
|
||||
"StructuredOutputsParams instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
return super().__post_init__()
|
||||
|
||||
|
||||
class RequestOutputKind(Enum):
|
||||
# Return entire output so far in every RequestOutput
|
||||
CUMULATIVE = 0
|
||||
@ -179,6 +193,8 @@ class SamplingParams(
|
||||
# Fields used to construct logits processors
|
||||
structured_outputs: Optional[StructuredOutputsParams] = None
|
||||
"""Parameters for configuring structured outputs."""
|
||||
guided_decoding: Optional[GuidedDecodingParams] = None
|
||||
"""Deprecated alias for structured_outputs."""
|
||||
logit_bias: Optional[dict[int, float]] = None
|
||||
"""If provided, the engine will construct a logits processor that applies
|
||||
these logit biases."""
|
||||
@ -227,6 +243,7 @@ class SamplingParams(
|
||||
ge=-1)]] = None,
|
||||
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
|
||||
structured_outputs: Optional[StructuredOutputsParams] = None,
|
||||
guided_decoding: Optional[GuidedDecodingParams] = None,
|
||||
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
|
||||
allowed_token_ids: Optional[list[int]] = None,
|
||||
extra_args: Optional[dict[str, Any]] = None,
|
||||
@ -238,6 +255,15 @@ class SamplingParams(
|
||||
int(token): min(100.0, max(-100.0, bias))
|
||||
for token, bias in logit_bias.items()
|
||||
}
|
||||
if guided_decoding is not None:
|
||||
warnings.warn(
|
||||
"guided_decoding is deprecated. This will be removed in "
|
||||
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
|
||||
"structured_outputs instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
structured_outputs = guided_decoding
|
||||
guided_decoding = None
|
||||
|
||||
return SamplingParams(
|
||||
n=1 if n is None else n,
|
||||
@ -334,6 +360,16 @@ class SamplingParams(
|
||||
# eos_token_id is added to this by the engine
|
||||
self._all_stop_token_ids.update(self.stop_token_ids)
|
||||
|
||||
if self.guided_decoding is not None:
|
||||
warnings.warn(
|
||||
"guided_decoding is deprecated. This will be removed in "
|
||||
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
|
||||
"structured_outputs instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
self.structured_outputs = self.guided_decoding
|
||||
self.guided_decoding = None
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
if not isinstance(self.n, int):
|
||||
raise ValueError(f"n must be an int, but is of "
|
||||
|
||||
@ -1383,6 +1383,38 @@ def find_nccl_library() -> str:
|
||||
return so_file
|
||||
|
||||
|
||||
def find_nccl_include_paths() -> Optional[list[str]]:
|
||||
"""
|
||||
We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH`
|
||||
environment variable, or we find the library file brought by
|
||||
nvidia-nccl-cuXX. load_inline by default uses
|
||||
torch.utils.cpp_extension.include_paths
|
||||
"""
|
||||
paths: list[str] = []
|
||||
inc = envs.VLLM_NCCL_INCLUDE_PATH
|
||||
if inc and os.path.isdir(inc):
|
||||
paths.append(inc)
|
||||
|
||||
try:
|
||||
import importlib.util
|
||||
spec = importlib.util.find_spec("nvidia.nccl")
|
||||
if spec and getattr(spec, "submodule_search_locations", None):
|
||||
for loc in spec.submodule_search_locations:
|
||||
inc_dir = os.path.join(loc, "include")
|
||||
if os.path.exists(os.path.join(inc_dir, "nccl.h")):
|
||||
paths.append(inc_dir)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
seen = set()
|
||||
out: list[str] = []
|
||||
for p in paths:
|
||||
if p and p not in seen:
|
||||
out.append(p)
|
||||
seen.add(p)
|
||||
return out or None
|
||||
|
||||
|
||||
prev_set_stream = torch.cuda.set_stream
|
||||
|
||||
_current_stream_tls = threading.local()
|
||||
|
||||
@ -9,7 +9,7 @@ from __future__ import annotations
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
from typing import Any, Callable, NoReturn
|
||||
from typing import Any, Callable, NoReturn, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -172,9 +172,13 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
|
||||
return 1 - sim
|
||||
|
||||
|
||||
def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
|
||||
weight: torch.Tensor):
|
||||
return (is_deep_gemm_supported() and output_dtype == torch.bfloat16
|
||||
def should_use_deepgemm_for_fp8_linear(
|
||||
output_dtype: torch.dtype,
|
||||
weight: torch.Tensor,
|
||||
supports_deep_gemm: Optional[bool] = None):
|
||||
if supports_deep_gemm is None:
|
||||
supports_deep_gemm = is_deep_gemm_supported()
|
||||
return (supports_deep_gemm and output_dtype == torch.bfloat16
|
||||
and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
|
||||
|
||||
|
||||
|
||||
@ -8,6 +8,8 @@ import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash)
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
@ -291,7 +293,13 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
ops.reshape_and_cache_flash(
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
# triton kernel does not support uint8 kv_cache
|
||||
# (because some explicit casts (e.g. float8_e4m3fnuz)
|
||||
# are not supported)
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
@ -303,8 +311,9 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
if key_cache.dtype != self.fp8_dtype:
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
num_tokens, num_heads, head_size = query.shape
|
||||
assert layer._q_scale_float == 1.0, \
|
||||
"A non 1.0 q_scale is not currently supported."
|
||||
|
||||
@ -107,19 +107,57 @@ def _make_metadata_with_slice(
|
||||
the requests included in ubatch_slice
|
||||
"""
|
||||
|
||||
assert not ubatch_slice.is_empty(), (
|
||||
f"Ubatch slice {ubatch_slice} is empty")
|
||||
|
||||
request_slice = ubatch_slice.request_slice
|
||||
token_slice = ubatch_slice.token_slice
|
||||
|
||||
start_locs = attn_metadata.query_start_loc_cpu
|
||||
first_req = request_slice.start
|
||||
first_tok = token_slice.start
|
||||
last_req = request_slice.stop - 1
|
||||
last_tok = token_slice.stop - 1
|
||||
|
||||
assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], \
|
||||
"Token slice start outside of first request"
|
||||
assert start_locs[last_req] <= last_tok < start_locs[last_req+1], \
|
||||
"Token slice end outside of last request"
|
||||
|
||||
# If the "middle" request has tokens in both ubatches, we have to split it.
|
||||
# If ubatch_slice is the first ubatch then we will be splitting the last
|
||||
# request. If it's the second microbatch, then we will be splitting the
|
||||
# first request
|
||||
splits_first_request = first_tok > start_locs[first_req]
|
||||
splits_last_request = last_tok < start_locs[last_req + 1] - 1
|
||||
|
||||
query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
|
||||
query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc,
|
||||
request_slice)
|
||||
|
||||
assert len(query_start_loc) >= 2, (
|
||||
f"query_start_loc must have at least 2 elements, "
|
||||
f"got {len(query_start_loc)}")
|
||||
query_start_loc_cpu = slice_query_start_locs(
|
||||
attn_metadata.query_start_loc_cpu, request_slice)
|
||||
|
||||
if splits_first_request:
|
||||
tokens_skipped = first_tok - start_locs[first_req]
|
||||
query_start_loc[1:] -= tokens_skipped
|
||||
query_start_loc_cpu[1:] -= tokens_skipped
|
||||
seq_lens = attn_metadata.seq_lens[request_slice]
|
||||
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
|
||||
|
||||
if splits_last_request:
|
||||
tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop
|
||||
query_start_loc[-1] -= tokens_skipped
|
||||
query_start_loc_cpu[-1] -= tokens_skipped
|
||||
|
||||
# Make sure we don't modify the seq_lens tensors
|
||||
# (not cudagraph compatible)
|
||||
seq_lens = seq_lens.clone()
|
||||
seq_lens_cpu = seq_lens_cpu.clone()
|
||||
seq_lens[-1] -= tokens_skipped
|
||||
seq_lens_cpu[-1] -= tokens_skipped
|
||||
|
||||
max_seq_len = int(seq_lens_cpu.max())
|
||||
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
|
||||
request_slice]
|
||||
@ -167,6 +205,7 @@ def split_attn_metadata(
|
||||
for ubatch_slice in ubatch_slices:
|
||||
results.append(
|
||||
_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@ -696,7 +735,6 @@ def split_decodes_and_prefills(
|
||||
return num_reqs, 0, num_tokens, 0
|
||||
|
||||
first_prefill = is_prefill.int().argmax(dim=-1).item()
|
||||
assert torch.all(query_lens[first_prefill:] > decode_threshold)
|
||||
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
|
||||
num_decodes = first_prefill
|
||||
num_prefills = num_reqs - num_decodes
|
||||
|
||||
@ -9,6 +9,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadataBuilder
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import (CompilationLevel, VllmConfig,
|
||||
get_layers_from_vllm_config)
|
||||
@ -30,7 +31,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -78,6 +78,8 @@ class EagleProposer:
|
||||
self.is_multimodal_model = vllm_config.model_config \
|
||||
.is_multimodal_model
|
||||
|
||||
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
|
||||
|
||||
self.use_cuda_graph = (self.vllm_config.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE and
|
||||
not self.vllm_config.model_config.enforce_eager)
|
||||
@ -118,7 +120,7 @@ class EagleProposer:
|
||||
with_numpy=True)
|
||||
|
||||
# Determine allowed attention backends once during initialization.
|
||||
self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
|
||||
self.allowed_attn_types: tuple[type, ...]
|
||||
if current_platform.is_rocm():
|
||||
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
|
||||
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
|
||||
@ -191,11 +193,12 @@ class EagleProposer:
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
# FIXME: need to consider multiple kv_cache_groups
|
||||
ubatch_id = dbo_current_ubatch_id()
|
||||
attn_metadata_builder = \
|
||||
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting(
|
||||
# Select the correct attention metadata builders for EAGLE layers.
|
||||
# Get the attention metadata builders once and reuse for later.
|
||||
builder = (self._get_attention_metadata_builder()
|
||||
if self.attn_metadata_builder is None else
|
||||
self.attn_metadata_builder)
|
||||
attn_metadata = builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0)
|
||||
|
||||
# At this moment, we assume all eagle layers belong to the same KV
|
||||
@ -329,11 +332,9 @@ class EagleProposer:
|
||||
exceeds_max_model_len, PADDING_SLOT_ID)
|
||||
|
||||
# Rebuild attention metadata
|
||||
attn_metadata_builder = \
|
||||
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
|
||||
attn_metadata = attn_metadata_builder\
|
||||
.build_for_drafting(common_attn_metadata=common_attn_metadata,
|
||||
draft_index=token_index + 1)
|
||||
attn_metadata = builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=token_index + 1)
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
@ -538,9 +539,8 @@ class EagleProposer:
|
||||
hidden_states: torch.Tensor,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> list[torch.Tensor]:
|
||||
ubatch_id = dbo_current_ubatch_id()
|
||||
tree_attn_metadata_builder = \
|
||||
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
|
||||
self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
assert isinstance(tree_attn_metadata_builder,
|
||||
TreeAttentionMetadataBuilder)
|
||||
|
||||
@ -854,10 +854,24 @@ class EagleProposer:
|
||||
# share lm_head with the target model if needed
|
||||
# some model definition do not define lm_head explicitly
|
||||
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
|
||||
if self.vllm_config.speculative_config.method != "eagle3" and \
|
||||
hasattr(target_language_model, "lm_head"):
|
||||
logger.info("Loading EAGLE LM head weights from the target model.")
|
||||
self.model.lm_head = target_language_model.lm_head
|
||||
if self.vllm_config.speculative_config.method != "eagle3":
|
||||
if hasattr(target_language_model, "lm_head"):
|
||||
logger.info(
|
||||
"Loading EAGLE LM head weights from the target model.")
|
||||
self.model.lm_head = target_language_model.lm_head
|
||||
else:
|
||||
if (hasattr(self.model, "lm_head")
|
||||
and hasattr(target_language_model, "lm_head")
|
||||
and self.model.lm_head.weight.shape
|
||||
== target_language_model.lm_head.weight.shape):
|
||||
logger.info("Assuming the EAGLE head shares the same lm_head"
|
||||
" with the target model.")
|
||||
del self.model.lm_head
|
||||
self.model.lm_head = target_language_model.lm_head
|
||||
else:
|
||||
logger.info(
|
||||
"The EAGLE head's lm_head will be loaded separately"
|
||||
" from the target model.")
|
||||
|
||||
@torch.inference_mode()
|
||||
def dummy_run(
|
||||
@ -880,6 +894,31 @@ class EagleProposer:
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
def _get_attention_metadata_builder(
|
||||
self) -> list[AttentionMetadataBuilder]:
|
||||
"""Find and return the attention metadata builders for EAGLE layers.
|
||||
|
||||
Returns:
|
||||
The metadata builders for EAGLE layers.
|
||||
|
||||
Raises:
|
||||
AssertionError: If no metadata builders are found for EAGLE layers.
|
||||
"""
|
||||
builder = None
|
||||
chosen_layer = self.attn_layer_names[0]
|
||||
|
||||
for kv_cache_group in self.runner.attn_groups:
|
||||
for attn_group in kv_cache_group:
|
||||
if chosen_layer in attn_group.layer_names:
|
||||
builder = attn_group.get_metadata_builder()
|
||||
break
|
||||
if builder is not None:
|
||||
break
|
||||
|
||||
assert builder is not None, (
|
||||
"Failed to find attention metadata builder for EAGLE layers.")
|
||||
return builder
|
||||
|
||||
def validate_same_kv_cache_group(self,
|
||||
kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
|
||||
@ -96,7 +96,8 @@ from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
||||
KVConnectorModelRunnerMixin)
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split
|
||||
from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds,
|
||||
ubatch_split)
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
|
||||
from vllm.v1.worker.utils import is_residual_scattered_for_sp
|
||||
|
||||
@ -1032,7 +1033,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_tokens_padded = num_tokens_unpadded + self.get_local_padding(
|
||||
num_tokens_unpadded)
|
||||
ubatch_slices, num_tokens_after_padding = \
|
||||
ubatch_split(max_num_scheduled_tokens,
|
||||
ubatch_split(num_scheduled_tokens,
|
||||
num_tokens_unpadded,
|
||||
num_tokens_padded,
|
||||
self.vllm_config)
|
||||
@ -1176,9 +1177,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
encoder_seq_lens=encoder_seq_lens,
|
||||
)
|
||||
|
||||
if self.speculative_config and \
|
||||
spec_decode_common_attn_metadata is None:
|
||||
spec_decode_common_attn_metadata = common_attn_metadata
|
||||
if (self.speculative_config
|
||||
and spec_decode_common_attn_metadata is None):
|
||||
if isinstance(self.drafter, EagleProposer):
|
||||
if (self.drafter.attn_layer_names[0]
|
||||
in kv_cache_group_spec.layer_names):
|
||||
spec_decode_common_attn_metadata = common_attn_metadata
|
||||
else:
|
||||
spec_decode_common_attn_metadata = common_attn_metadata
|
||||
|
||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||||
# Prepare for cascade attention if enabled & beneficial.
|
||||
@ -1206,7 +1212,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
ubatch_slices, common_attn_metadata)
|
||||
for ubid, common_attn_metadata in enumerate(
|
||||
common_attn_metadata_list):
|
||||
assert common_attn_metadata.max_query_len == 1
|
||||
attn_metadata_i = (attn_group.get_metadata_builder(
|
||||
ubatch_id=ubid).build(
|
||||
common_prefix_len=common_prefix_len,
|
||||
@ -2182,9 +2187,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
) = self._preprocess(scheduler_output, intermediate_tensors,
|
||||
ubatch_slices, num_tokens_after_padding)
|
||||
|
||||
if ubatch_slices is not None:
|
||||
num_input_tokens = num_input_tokens // 2
|
||||
|
||||
uniform_decode = (max_query_len
|
||||
== self.uniform_decode_query_len) and (
|
||||
num_scheduled_tokens
|
||||
@ -2194,6 +2196,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
cudagraph_runtime_mode, batch_descriptor = \
|
||||
self.cudagraph_dispatcher.dispatch(batch_descriptor)
|
||||
|
||||
# This is currently to get around the assert in the DPMetadata
|
||||
# where it wants `num_tokens_across_dp` to align with `num_tokens`
|
||||
if ubatch_slices is not None:
|
||||
num_input_tokens = ubatch_slices[0].num_tokens
|
||||
|
||||
# Run the model.
|
||||
# Use persistent buffers for CUDA graphs.
|
||||
with (set_forward_context(
|
||||
@ -2360,7 +2367,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
sampling_metadata: SamplingMetadata,
|
||||
hidden_states: torch.Tensor,
|
||||
sample_hidden_states: torch.Tensor,
|
||||
aux_hidden_states: Optional[torch.Tensor],
|
||||
aux_hidden_states: Optional[list[torch.Tensor]],
|
||||
spec_decode_metadata: Optional[SpecDecodeMetadata],
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> Union[list[list[int]], torch.Tensor]:
|
||||
@ -2380,6 +2387,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
else:
|
||||
indices = []
|
||||
offset = 0
|
||||
assert spec_decode_metadata is not None
|
||||
for num_draft, tokens in zip(
|
||||
spec_decode_metadata.num_draft_tokens,
|
||||
sampled_token_ids):
|
||||
@ -2430,6 +2438,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# TODO(woosuk): Support M-RoPE.
|
||||
target_positions = self.positions.gpu[:num_scheduled_tokens]
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
assert aux_hidden_states is not None
|
||||
target_hidden_states = torch.cat(
|
||||
[h[:num_scheduled_tokens] for h in aux_hidden_states],
|
||||
dim=-1)
|
||||
@ -2455,6 +2464,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# TODO(woosuk): Support M-RoPE.
|
||||
target_positions = self.positions.gpu[token_indices]
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
assert aux_hidden_states is not None
|
||||
target_hidden_states = torch.cat(
|
||||
[h[token_indices] for h in aux_hidden_states], dim=-1)
|
||||
else:
|
||||
@ -2821,7 +2831,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
force_attention: bool = False,
|
||||
uniform_decode: bool = False,
|
||||
allow_microbatching: bool = False,
|
||||
allow_microbatching: bool = True,
|
||||
skip_eplb: bool = False,
|
||||
is_profile: bool = False,
|
||||
create_mixed_batch: bool = False,
|
||||
@ -2847,32 +2857,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
(1 token) and prefill (multiple tokens) requests.
|
||||
remove_lora: If False, dummy LoRAs are not destroyed after the run
|
||||
"""
|
||||
ubatch_enabled = self.parallel_config.enable_dbo
|
||||
num_tokens_across_dp = None
|
||||
num_pad = 0
|
||||
should_ubatch = False
|
||||
if ubatch_enabled:
|
||||
should_ubatch = num_tokens >= \
|
||||
self.parallel_config.dbo_decode_token_threshold and \
|
||||
allow_microbatching
|
||||
|
||||
(should_ubatch, num_tokens_across_dp) = get_dp_padding_ubatch(
|
||||
num_tokens, num_tokens, should_ubatch, self.vllm_config)
|
||||
|
||||
# Currently the dummy run should only be ubatching during
|
||||
# cuda graph capture, meaning all DP ranks should already
|
||||
# have the same batch size
|
||||
if num_tokens_across_dp is not None:
|
||||
assert int(num_tokens_across_dp[0]) == num_tokens // 2
|
||||
|
||||
assert cudagraph_runtime_mode in {
|
||||
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
|
||||
}
|
||||
|
||||
if not should_ubatch:
|
||||
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
|
||||
num_tokens += num_pad
|
||||
|
||||
# If cudagraph_mode.decode_mode() == FULL and
|
||||
# cudagraph_mode.separate_routine(). This means that we are using
|
||||
# different graphs and/or modes for mixed prefill-decode batches vs.
|
||||
@ -2888,10 +2876,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# for GQA/MQA.
|
||||
max_query_len = self.uniform_decode_query_len if uniform_decode else \
|
||||
num_tokens
|
||||
if allow_microbatching:
|
||||
assert self.uniform_decode_query_len == 1
|
||||
assert uniform_decode is True
|
||||
assert max_query_len == 1
|
||||
|
||||
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
||||
# for dummy run with LoRA so that the num_reqs collectively
|
||||
@ -2916,7 +2900,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
assert not create_mixed_batch
|
||||
num_reqs = cdiv(num_tokens, max_query_len)
|
||||
assert num_reqs <= max_num_reqs, \
|
||||
"Do not capture num_reqs > max_num_reqs for uniform batch"
|
||||
f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
|
||||
f"{max_num_reqs} for uniform batch. Num tokens: " \
|
||||
f"{num_tokens}, max_query_len: {max_query_len}"
|
||||
num_scheduled_tokens_list = [max_query_len] * num_reqs
|
||||
if num_tokens % max_query_len != 0:
|
||||
num_scheduled_tokens_list[-1] = num_tokens % max_query_len
|
||||
@ -2930,20 +2916,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
assert len(num_scheduled_tokens_list) == num_reqs
|
||||
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
||||
dtype=np.int32)
|
||||
total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
|
||||
|
||||
ubatch_slices = None
|
||||
num_tokens_after_padding = None
|
||||
|
||||
# We currently only microbatch if the number of tokens is
|
||||
# over a certain threshold.
|
||||
if should_ubatch:
|
||||
# We only support decode-only cudagraphs
|
||||
assert num_reqs == num_tokens
|
||||
assert num_tokens % 2 == 0
|
||||
ubatch_slices = [
|
||||
UBatchSlice(slice(0, num_reqs // 2), slice(0,
|
||||
num_tokens // 2)),
|
||||
UBatchSlice(slice(num_reqs // 2, num_reqs),
|
||||
slice(num_tokens // 2, num_tokens))
|
||||
]
|
||||
if self.parallel_config.enable_dbo and allow_microbatching:
|
||||
ubatch_slices, num_tokens_after_padding = ubatch_split(
|
||||
num_scheduled_tokens,
|
||||
total_num_scheduled_tokens,
|
||||
total_num_scheduled_tokens,
|
||||
self.vllm_config,
|
||||
)
|
||||
|
||||
# If we failed to microbatch, currently need to resynchronize
|
||||
# TODO(lucas,sage): we should be able to avoid this second sync by
|
||||
# refactoring `get_dp_padding_ubatch` and `get_dp_padding` into
|
||||
# a single `coordinate_batch_across_dp` function.
|
||||
if num_tokens_after_padding is None:
|
||||
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
|
||||
num_tokens_after_padding = num_tokens + num_pad
|
||||
else:
|
||||
num_tokens_across_dp = num_tokens_after_padding
|
||||
num_tokens_after_padding = int(num_tokens_after_padding[0].item())
|
||||
|
||||
attn_metadata: Optional[PerLayerAttnMetadata] = None
|
||||
|
||||
@ -2960,12 +2957,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# TODO(luka) better system for describing dummy batches
|
||||
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
|
||||
else:
|
||||
# Make sure max_model_len is used at the graph capture time.
|
||||
seq_lens = self.max_model_len
|
||||
seq_lens = max_query_len
|
||||
self.seq_lens.np[:num_reqs] = seq_lens
|
||||
self.seq_lens.np[num_reqs:] = 0
|
||||
self.seq_lens.copy_to_gpu()
|
||||
|
||||
cum_num_tokens, _ = self._get_cumsum_and_arange(
|
||||
num_scheduled_tokens)
|
||||
self.query_start_loc.np[1:num_reqs + 1] = cum_num_tokens
|
||||
self.query_start_loc.copy_to_gpu()
|
||||
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups):
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
@ -3060,7 +3061,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
with self.maybe_randomize_inputs(input_ids), set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens=num_tokens_after_padding,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
@ -3395,56 +3396,51 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
desc="Capturing CUDA graphs ({}, {})".format(
|
||||
"decode" if uniform_decode else "mixed prefill-decode",
|
||||
cudagraph_runtime_mode.name))
|
||||
enable_dbo = self.parallel_config.enable_dbo
|
||||
# DBO Only supports running Full cudagraphs with uniform
|
||||
# decode lengths
|
||||
if enable_dbo and uniform_decode:
|
||||
for num_tokens in compilation_cases:
|
||||
# If the number of tokens is greater than the microbatching
|
||||
# threshold, don't generate a microbatched cudagraph
|
||||
if (num_tokens
|
||||
< self.parallel_config.dbo_decode_token_threshold):
|
||||
continue
|
||||
|
||||
# Warmup
|
||||
# We skip EPLB here since we don't want to record dummy metrics
|
||||
for num_tokens in compilation_cases:
|
||||
# We currently only capture ubatched graphs when its a FULL
|
||||
# cudagraph and for uniform decode batches.
|
||||
capture_ubatched_graph = self.parallel_config.enable_dbo \
|
||||
and cudagraph_runtime_mode == CUDAGraphMode.FULL \
|
||||
and uniform_decode \
|
||||
and check_ubatch_thresholds(
|
||||
config=self.vllm_config.parallel_config,
|
||||
num_tokens=num_tokens,
|
||||
uniform_decode=uniform_decode,
|
||||
)
|
||||
|
||||
# Currently we capture both microbatched and non-microbatched
|
||||
# graphs when capture_ubatched_graph is True, this is because
|
||||
# occasionally we will be forced out of microbatching due to other
|
||||
# DP ranks not microbatching (usually caused by an empty second
|
||||
# microbatch; once we resolve this, we can remove the
|
||||
# non-microbatched graph capture).
|
||||
allow_microbatching_options = [True, False] if \
|
||||
capture_ubatched_graph else [False]
|
||||
for allow_microbatching in allow_microbatching_options:
|
||||
for _ in range(
|
||||
self.compilation_config.cudagraph_num_of_warmups):
|
||||
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
|
||||
# But be careful, warm up with `NONE`is orthogonal to
|
||||
# if we want to warm up attention or not. This is
|
||||
# different from the case where `FULL` implies capture
|
||||
# attention while `PIECEWISE` implies no attention.
|
||||
force_attention = (
|
||||
cudagraph_runtime_mode == CUDAGraphMode.FULL)
|
||||
self._dummy_run(num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
force_attention=force_attention,
|
||||
uniform_decode=True,
|
||||
allow_microbatching=True,
|
||||
skip_eplb=True)
|
||||
|
||||
# Graph Capture
|
||||
uniform_decode=uniform_decode,
|
||||
allow_microbatching=allow_microbatching,
|
||||
skip_eplb=True,
|
||||
remove_lora=False)
|
||||
self._dummy_run(num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.FULL,
|
||||
uniform_decode=True,
|
||||
allow_microbatching=True,
|
||||
skip_eplb=True)
|
||||
# We skip EPLB here since we don't want to record dummy metrics
|
||||
for num_tokens in compilation_cases:
|
||||
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
|
||||
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
|
||||
# But be careful, warm up with `NONE`is orthogonal to
|
||||
# if we want to warm up attention or not. This is
|
||||
# different from the case where `FULL` implies capture
|
||||
# attention while `PIECEWISE` implies no attention.
|
||||
force_attention = (
|
||||
cudagraph_runtime_mode == CUDAGraphMode.FULL)
|
||||
self._dummy_run(num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
force_attention=force_attention,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
uniform_decode=uniform_decode,
|
||||
allow_microbatching=allow_microbatching,
|
||||
skip_eplb=True,
|
||||
remove_lora=False)
|
||||
self._dummy_run(num_tokens,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
uniform_decode=uniform_decode,
|
||||
skip_eplb=True,
|
||||
remove_lora=False)
|
||||
self.maybe_remove_all_loras(self.lora_config)
|
||||
|
||||
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
@ -3500,24 +3496,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
attn_groups: list[AttentionGroup] = []
|
||||
for (attn_backend,
|
||||
kv_cache_spec), layer_names in attn_backends_map.items():
|
||||
attn_metadata_builders = []
|
||||
attn_metadata_builders.append(attn_backend.get_builder_cls()(
|
||||
kv_cache_spec,
|
||||
attn_group = AttentionGroup.create_with_metadata_builders(
|
||||
attn_backend,
|
||||
layer_names,
|
||||
kv_cache_spec,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
))
|
||||
if self.parallel_config.enable_dbo:
|
||||
attn_metadata_builders.append(
|
||||
attn_backend.get_builder_cls()(
|
||||
kv_cache_spec,
|
||||
layer_names,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
))
|
||||
attn_group = AttentionGroup(attn_backend,
|
||||
attn_metadata_builders,
|
||||
layer_names, kv_cache_spec)
|
||||
num_metadata_builders=1
|
||||
if not self.parallel_config.enable_dbo else 2,
|
||||
)
|
||||
|
||||
attn_groups.append(attn_group)
|
||||
return attn_groups
|
||||
|
||||
@ -3562,6 +3550,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
CUDAGraphMode.FULL_DECODE_ONLY
|
||||
logger.warning(msg)
|
||||
|
||||
# check that if we are doing decode full-cudagraphs it is supported
|
||||
if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
|
||||
and min_cg_support == AttentionCGSupport.NEVER):
|
||||
msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
|
||||
f"with {min_cg_builder_name} backend (support: "
|
||||
f"{min_cg_support})")
|
||||
if (self.compilation_config.level == CompilationLevel.PIECEWISE and
|
||||
(self.compilation_config.splitting_ops_contain_attention()
|
||||
or self.compilation_config.use_inductor_graph_partition)):
|
||||
msg += "; setting cudagraph_mode=PIECEWISE because "\
|
||||
"attention is compiled piecewise"
|
||||
cudagraph_mode = self.compilation_config.cudagraph_mode = \
|
||||
CUDAGraphMode.PIECEWISE
|
||||
else:
|
||||
msg += "; setting cudagraph_mode=NONE because "\
|
||||
"attention is not compiled piecewise"
|
||||
cudagraph_mode = self.compilation_config.cudagraph_mode = \
|
||||
CUDAGraphMode.NONE
|
||||
logger.warning(msg)
|
||||
|
||||
# check that if we are doing spec-decode + decode full-cudagraphs it is
|
||||
# supported
|
||||
if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
|
||||
|
||||
@ -1,25 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
set_graph_pool_id)
|
||||
from vllm.forward_context import (create_forward_context, get_forward_context,
|
||||
override_forward_context)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@dataclass
|
||||
class UbatchMetadata:
|
||||
context: UBatchContext
|
||||
input_ids: torch.Tensor
|
||||
@ -29,13 +34,55 @@ class UbatchMetadata:
|
||||
num_tokens: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@dataclass
|
||||
class CUDAGraphMetaData:
|
||||
cudagraph: torch.cuda.CUDAGraph
|
||||
ubatch_metadata: UbatchMetadata
|
||||
outputs: Optional[Any] = None
|
||||
|
||||
|
||||
class SMControlContextManager:
|
||||
|
||||
def __init__(self, comm_sms: int, set_comm_sms: Callable[[int], None],
|
||||
set_compute_sms: Callable[[int], None]):
|
||||
"""
|
||||
Context manager for controlling SM (Streaming Multiprocessor)
|
||||
allocation. Upon entering the context, it sets the number of SMs
|
||||
allocated for communication and computation to comm_sms and
|
||||
total_sms - comm_sms respectively. Upon exiting, it restores the
|
||||
allocation to use all available SMs (i.e. total_sms).
|
||||
|
||||
Args:
|
||||
comm_sms (int): The number of SMs to allocate for communication.
|
||||
(The remainder will be used for computation.)
|
||||
set_comm_sms (Callable[[int], None]):
|
||||
A function that sets the number of SMs for communication.
|
||||
set_compute_sms (Callable[[int], None]):
|
||||
A function that sets the number of SMs for computation.
|
||||
"""
|
||||
|
||||
assert current_platform.is_cuda(), \
|
||||
"SM control is currently only supported on CUDA"
|
||||
|
||||
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||
total_sms = props.multi_processor_count
|
||||
|
||||
assert comm_sms < total_sms
|
||||
self.total_sms = total_sms
|
||||
self.compute_sms = total_sms - comm_sms
|
||||
self.comm_sms = comm_sms
|
||||
self.set_comm_sms = set_comm_sms
|
||||
self.set_compute_sms = set_compute_sms
|
||||
|
||||
def __enter__(self):
|
||||
self.set_comm_sms(self.comm_sms)
|
||||
self.set_compute_sms(self.compute_sms)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.set_comm_sms(self.total_sms)
|
||||
self.set_compute_sms(self.total_sms)
|
||||
|
||||
|
||||
class UBatchWrapper:
|
||||
|
||||
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
|
||||
@ -56,6 +103,35 @@ class UBatchWrapper:
|
||||
runnable, vllm_config, runtime_mode=runtime_mode)
|
||||
self.graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
self.sm_control = self._create_sm_control_context(vllm_config)
|
||||
|
||||
@staticmethod
|
||||
def _create_sm_control_context(vllm_config: VllmConfig):
|
||||
comm_sms = envs.VLLM_DBO_COMM_SMS
|
||||
|
||||
set_comm_sms = lambda sms: None
|
||||
if vllm_config.parallel_config.enable_expert_parallel:
|
||||
# Currently only DeepEP highthroughput supports SM control so this
|
||||
# only affects that case.
|
||||
all2all_manager = get_ep_group(
|
||||
).device_communicator.all2all_manager
|
||||
|
||||
if all2all_manager.max_sms_used() is not None:
|
||||
comm_sms = min(comm_sms, all2all_manager.max_sms_used())
|
||||
|
||||
if comm_sms > 0:
|
||||
set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)
|
||||
|
||||
# TODO(lucas): support other kernels besides DeepGEMM
|
||||
set_compute_sms = lambda sms: None
|
||||
if has_deep_gemm() and comm_sms > 0:
|
||||
import deep_gemm as dg
|
||||
set_compute_sms = lambda sms: dg.set_num_sms(sms)
|
||||
|
||||
return SMControlContextManager(comm_sms=comm_sms,
|
||||
set_comm_sms=set_comm_sms,
|
||||
set_compute_sms=set_compute_sms)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
@ -132,6 +208,10 @@ class UBatchWrapper:
|
||||
cudagraph=torch.cuda.CUDAGraph(),
|
||||
ubatch_metadata=ubatch_metadata,
|
||||
)
|
||||
if self.graph_pool is not None:
|
||||
set_graph_pool_id(self.graph_pool)
|
||||
else:
|
||||
set_graph_pool_id(current_platform.graph_pool_handle())
|
||||
with torch.cuda.graph(cudagraph_metadata.cudagraph,
|
||||
stream=compute_stream,
|
||||
pool=self.graph_pool):
|
||||
@ -282,8 +362,8 @@ class UBatchWrapper:
|
||||
dp_metadata=dp_metadata,
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE)
|
||||
|
||||
return self._capture_ubatches(ubatch_metadata, self.model)
|
||||
with self.sm_control:
|
||||
return self._capture_ubatches(ubatch_metadata, self.model)
|
||||
elif num_tokens in self.cudagraphs:
|
||||
cudagraph_metadata = self.cudagraphs[num_tokens]
|
||||
cudagraph_metadata.cudagraph.replay()
|
||||
@ -300,4 +380,5 @@ class UBatchWrapper:
|
||||
dp_metadata=dp_metadata,
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE)
|
||||
return self._run_ubatches(ubatch_metadata, self.model)
|
||||
with self.sm_control:
|
||||
return self._run_ubatches(ubatch_metadata, self.model)
|
||||
|
||||
@ -386,11 +386,13 @@ class Worker(WorkerBase):
|
||||
f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
|
||||
f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
|
||||
f"config with `--kv-cache-memory="
|
||||
f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
|
||||
f"requested memory, or `--kv-cache-memory="
|
||||
f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
|
||||
f"{kv_cache_memory_bytes_to_requested_limit}` "
|
||||
f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
|
||||
f"into requested memory, or `--kv-cache-memory="
|
||||
f"{kv_cache_memory_bytes_to_gpu_limit}` "
|
||||
f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
|
||||
f"utilize gpu memory. Current kv cache memory in use is "
|
||||
f"{int(self.available_kv_cache_memory_bytes)} bytes.")
|
||||
f"{GiB(self.available_kv_cache_memory_bytes)} GiB.")
|
||||
|
||||
logger.debug(msg)
|
||||
|
||||
|
||||
@ -3,9 +3,10 @@
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.forward_context import DPMetadata
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import round_up
|
||||
@ -29,6 +30,16 @@ def should_ubatch_with_num_tokens(
|
||||
dp_size, dp_rank)
|
||||
|
||||
|
||||
def check_ubatch_thresholds(config: ParallelConfig, num_tokens: int,
|
||||
uniform_decode: bool) -> bool:
|
||||
if not config.enable_dbo:
|
||||
return False
|
||||
if uniform_decode:
|
||||
return num_tokens >= config.dbo_decode_token_threshold
|
||||
else:
|
||||
return num_tokens >= config.dbo_prefill_token_threshold
|
||||
|
||||
|
||||
def get_dp_padding_ubatch(
|
||||
num_tokens_unpadded: int, num_tokens_padded: int,
|
||||
should_attempt_ubatching: bool,
|
||||
@ -95,9 +106,37 @@ def get_dp_padding_ubatch(
|
||||
dtype=torch.int32)
|
||||
return should_ubatch, num_tokens_after_padding
|
||||
|
||||
def create_ubatch_slices(num_scheduled_tokens: np.ndarray, split_point: int) \
|
||||
-> UBatchSlices:
|
||||
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
|
||||
# in cu_num_tokens directly (i.e. query_start_loc)
|
||||
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
|
||||
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
|
||||
|
||||
first_ubatch_token_slice = slice(0, split_point)
|
||||
second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])
|
||||
|
||||
# Determine request slices using exclusive stop semantics
|
||||
# First ubatch includes requests whose tokens overlap [0, split_point)
|
||||
first_ubatch_req_stop = int(
|
||||
np.searchsorted(cu_num_tokens, split_point, side="left"))
|
||||
first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
|
||||
|
||||
# Second ubatch starts at the request that contains the split_point
|
||||
# or the request starting exactly at split_point (if on boundary)
|
||||
second_ubatch_req_start = int(
|
||||
np.searchsorted(cu_num_tokens, split_point, side="right") - 1)
|
||||
second_ubatch_req_slice = slice(second_ubatch_req_start,
|
||||
len(cu_num_tokens) - 1)
|
||||
|
||||
return [
|
||||
UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
|
||||
UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice)
|
||||
]
|
||||
|
||||
|
||||
def ubatch_split(
|
||||
max_num_scheduled_tokens: int,
|
||||
num_scheduled_tokens_per_request: np.ndarray,
|
||||
num_tokens_unpadded: int,
|
||||
num_tokens_padded: int,
|
||||
vllm_config: VllmConfig,
|
||||
@ -122,17 +161,20 @@ def ubatch_split(
|
||||
return (None, None)
|
||||
|
||||
# Check preconditions for microbatching
|
||||
should_attempt_ubatching = \
|
||||
parallel_config.enable_dbo and \
|
||||
num_tokens_unpadded >= \
|
||||
parallel_config.dbo_decode_token_threshold \
|
||||
and max_num_scheduled_tokens == 1
|
||||
should_attempt_ubatching = check_ubatch_thresholds(
|
||||
parallel_config,
|
||||
num_tokens_unpadded,
|
||||
vllm_config,
|
||||
)
|
||||
|
||||
# Don't microbatch unless every other DP worker is also microbatching
|
||||
num_tokens_after_padding = None
|
||||
(should_ubatch, num_tokens_after_padding) = get_dp_padding_ubatch(
|
||||
num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching,
|
||||
vllm_config)
|
||||
should_ubatch, num_tokens_after_padding = get_dp_padding_ubatch(
|
||||
num_tokens_unpadded,
|
||||
num_tokens_padded,
|
||||
should_attempt_ubatching,
|
||||
vllm_config,
|
||||
)
|
||||
|
||||
if not should_ubatch:
|
||||
return (None, None)
|
||||
|
||||
@ -141,15 +183,9 @@ def ubatch_split(
|
||||
# to the second ubatch in pad_out_ubatch_slice after attention
|
||||
# metadata creation
|
||||
assert num_tokens_after_padding is not None
|
||||
total_num_tokens_per_ubatch = int(num_tokens_after_padding[0].item())
|
||||
padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
|
||||
padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
|
||||
num_tokens_unpadded)
|
||||
token_split_point = int(num_tokens_after_padding[0].item())
|
||||
|
||||
# Note there's an assumption here that there's 1 token per request
|
||||
ubatch_slices = [
|
||||
UBatchSlice(padded_first_ubatch_slice, padded_first_ubatch_slice),
|
||||
UBatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice)
|
||||
]
|
||||
ubatch_slices = create_ubatch_slices(num_scheduled_tokens_per_request,
|
||||
token_split_point)
|
||||
|
||||
return (ubatch_slices, num_tokens_after_padding)
|
||||
|
||||
@ -10,6 +10,14 @@ class UBatchSlice:
|
||||
request_slice: slice
|
||||
token_slice: slice
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return self.request_slice.start == self.request_slice.stop \
|
||||
or self.token_slice.start == self.token_slice.stop
|
||||
|
||||
@property
|
||||
def num_tokens(self) -> int:
|
||||
return self.token_slice.stop - self.token_slice.start
|
||||
|
||||
|
||||
UBatchSlices: TypeAlias = list[UBatchSlice]
|
||||
|
||||
|
||||
@ -51,8 +51,8 @@ class UBatchContext:
|
||||
self.cpu_wait_event.wait()
|
||||
self.cpu_wait_event.clear()
|
||||
self._restore_context()
|
||||
# Assume we start on the compute stream
|
||||
assert current_stream() == self.compute_stream
|
||||
# Assume we want to start on the compute stream
|
||||
self.update_stream(self.compute_stream)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
@ -62,17 +62,15 @@ class UBatchContext:
|
||||
self.maybe_run_recv_hook()
|
||||
self.cpu_signal_event.set()
|
||||
self.cpu_wait_event.clear()
|
||||
self.current_stream = self.compute_stream
|
||||
torch.cuda.set_stream(self.current_stream)
|
||||
return False
|
||||
|
||||
def _restore_context(self):
|
||||
forward_context._forward_context = self.forward_context
|
||||
torch.cuda.set_stream(self.current_stream)
|
||||
|
||||
def update_stream(self, stream):
|
||||
self.current_stream = stream
|
||||
torch.cuda.set_stream(self.current_stream)
|
||||
if current_stream() != self.current_stream:
|
||||
torch.cuda.set_stream(self.current_stream)
|
||||
|
||||
def _signal_comm_done(self):
|
||||
self.gpu_comm_done_event.record(self.comm_stream)
|
||||
@ -99,9 +97,20 @@ class UBatchContext:
|
||||
self.cpu_wait_event.clear()
|
||||
self._restore_context()
|
||||
|
||||
def switch_to_comm(self):
|
||||
self.update_stream(self.comm_stream)
|
||||
|
||||
def switch_to_compute(self):
|
||||
self.update_stream(self.compute_stream)
|
||||
|
||||
def switch_to_comm_sync(self):
|
||||
self._signal_compute_done()
|
||||
self.update_stream(self.comm_stream)
|
||||
self._wait_compute_done()
|
||||
|
||||
def switch_to_compute_sync(self):
|
||||
self._signal_comm_done()
|
||||
self.update_stream(self.compute_stream)
|
||||
self._wait_comm_done()
|
||||
|
||||
def maybe_run_recv_hook(self):
|
||||
@ -112,8 +121,7 @@ class UBatchContext:
|
||||
def yield_(self):
|
||||
self.current_stream = current_stream()
|
||||
self._cpu_yield()
|
||||
if self.current_stream != current_stream():
|
||||
self.update_stream(self.current_stream)
|
||||
self.update_stream(self.current_stream)
|
||||
|
||||
def yield_and_switch_from_compute_to_comm(self):
|
||||
assert current_stream() == self.compute_stream
|
||||
@ -153,15 +161,20 @@ def _register_ubatch_function(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
dbo_maybe_run_recv_hook = _register_ubatch_function(
|
||||
UBatchContext.maybe_run_recv_hook)
|
||||
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
|
||||
dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function(
|
||||
UBatchContext.yield_and_switch_from_compute_to_comm)
|
||||
dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function(
|
||||
UBatchContext.yield_and_switch_from_comm_to_compute)
|
||||
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
|
||||
dbo_maybe_run_recv_hook = _register_ubatch_function(
|
||||
UBatchContext.maybe_run_recv_hook)
|
||||
dbo_switch_to_comm = _register_ubatch_function(UBatchContext.switch_to_comm)
|
||||
dbo_switch_to_compute = _register_ubatch_function(
|
||||
UBatchContext.switch_to_compute)
|
||||
dbo_switch_to_comm_sync = _register_ubatch_function(
|
||||
UBatchContext.switch_to_comm_sync)
|
||||
dbo_switch_to_compute_sync = _register_ubatch_function(
|
||||
UBatchContext.switch_to_compute_sync)
|
||||
|
||||
|
||||
def dbo_register_recv_hook(recv_hook):
|
||||
|
||||
@ -130,15 +130,32 @@ class MultiModalBudget:
|
||||
@dataclass
|
||||
class AttentionGroup:
|
||||
backend: type[AttentionBackend]
|
||||
# When ubatching is enabled we will have a metadata builder for each ubatch
|
||||
# so that if they use internal persistant buffers for cudagraphs, and they
|
||||
# won't have to worry about conflicting with the other ubatches.
|
||||
metadata_builders: list[AttentionMetadataBuilder]
|
||||
layer_names: list[str]
|
||||
kv_cache_spec: KVCacheSpec
|
||||
|
||||
@staticmethod
|
||||
def create_with_metadata_builders(
|
||||
backend: type[AttentionBackend],
|
||||
layer_names: list[str],
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
num_metadata_builders: int = 1,
|
||||
) -> 'AttentionGroup':
|
||||
metadata_builders = [
|
||||
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config,
|
||||
device)
|
||||
for _ in range(num_metadata_builders)
|
||||
]
|
||||
return AttentionGroup(backend, metadata_builders, layer_names,
|
||||
kv_cache_spec)
|
||||
|
||||
def get_metadata_builder(self,
|
||||
ubatch_id: Optional[int] = None
|
||||
) -> AttentionMetadataBuilder:
|
||||
if ubatch_id is None:
|
||||
return self.metadata_builders[0]
|
||||
ubatch_id: int = 0) -> AttentionMetadataBuilder:
|
||||
assert len(self.metadata_builders) > ubatch_id
|
||||
return self.metadata_builders[ubatch_id]
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user