Merge branch 'main' into Add_support_for_openpangu_promoe_v2

Signed-off-by: yt0428 <51468697+yt0428@users.noreply.github.com>
This commit is contained in:
yt0428 2025-11-26 11:41:50 +08:00 committed by GitHub
commit 28169a6fce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
77 changed files with 1685 additions and 441 deletions

View File

@ -903,11 +903,12 @@ steps:
- label: Transformers Nightly Models Test - label: Transformers Nightly Models Test
working_dir: "/vllm-workspace/" working_dir: "/vllm-workspace/"
optional: true optional: true
soft_fail: true
commands: commands:
- pip install --upgrade git+https://github.com/huggingface/transformers - pip install --upgrade git+https://github.com/huggingface/transformers
- pytest -v -s tests/models/test_initialization.py -k 'not (Ultravox or Phi4Multimodal or MiniCPMO or Lfm2Moe or RobertaForSequenceClassification or Ovis2_5 or DeepseekOCR or KimiVL)' - pytest -v -s tests/models/test_initialization.py
- pytest -v -s tests/models/test_transformers.py - pytest -v -s tests/models/test_transformers.py
# - pytest -v -s tests/models/multimodal/processing/ - pytest -v -s tests/models/multimodal/processing/
- pytest -v -s tests/models/multimodal/test_mapping.py - pytest -v -s tests/models/multimodal/test_mapping.py
- python3 examples/offline_inference/basic/chat.py - python3 examples/offline_inference/basic/chat.py
- python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl

View File

@ -13,7 +13,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
- name: Set up Python - name: Set up Python
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0

View File

@ -12,7 +12,7 @@ jobs:
timeout-minutes: 30 timeout-minutes: 30
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v6
- uses: astral-sh/setup-uv@v7 - uses: astral-sh/setup-uv@v7
with: with:

View File

@ -16,7 +16,7 @@ jobs:
pre-commit: pre-commit:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
- uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
with: with:
python-version: "3.12" python-version: "3.12"

View File

@ -604,12 +604,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(SRCS set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu" "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu") "csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${SRCS}" SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}") CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}") list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1") list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else() else()
message(STATUS "Not building NVFP4 as no compatible archs were found.") message(STATUS "Not building NVFP4 as no compatible archs were found.")

View File

@ -22,6 +22,7 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h> #include <c10/cuda/CUDAStream.h>
#include "cutlass_extensions/common.hpp"
#include "cute/tensor.hpp" #include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h" #include "cutlass/tensor_ref.h"
@ -173,7 +174,7 @@ void run_get_group_gemm_starts(
} }
template <typename OutType> template <typename OutType>
void run_fp4_blockwise_scaled_group_mm( void run_fp4_blockwise_scaled_group_mm_sm100(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes, const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
@ -343,17 +344,225 @@ void run_fp4_blockwise_scaled_group_mm(
auto can_implement_status = gemm_op.can_implement(args); auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM"); "Failed to implement GEMM: status=", (int)can_implement_status);
// Run the GEMM // Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr()); auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size, " num_experts=", num_experts,
" M=", M, " N=", N, " K=", K);
status = gemm_op.run(args, workspace.data_ptr(), stream); status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
} }
// Grouped (one GEMM problem per expert) block-scaled NVFP4 matmul specialized
// for SM120 (compute capability 12.x). Counterpart of the templated SM100
// path, but built against cutlass::arch::Sm120 and with the output element
// type fixed to bf16 (see NOTE below).
//
// output          : result tensor, written per expert group (bf16).
// a, b            : packed FP4 (e2m1) operands; b is column-major per expert.
// a_blockscale,
// b_blockscales   : ue4m3 block scale factors for a and b.
// alphas          : per-expert output scales (beta is fixed to 0 below).
// problem_sizes   : per-expert (M, N, K) triples, consumed directly as the
//                   CUTLASS group problem shapes.
// expert_offsets,
// sf_offsets      : per-expert row / scale-factor offsets consumed by
//                   run_get_group_gemm_starts to derive group base pointers.
// M, N, K         : overall dimensions (pointer setup and error messages).
void run_fp4_blockwise_scaled_group_mm_sm120(
    torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
    const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
    const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
    int N, int K) {
  using ProblemShape =
      cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
  using ElementType = cutlass::float_e2m1_t;
  using ElementSFType = cutlass::float_ue4m3_t;
  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;

  // NOTE: For SM120 it seems templating the output type is not supported and
  // we need to hardcode the output type to bfloat16
  using ElementC = cutlass::bfloat16_t;
  using ElementD = ElementC;
  using ElementAccumulator = float;

  // Layout definitions
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::RowMajor;
  using LayoutD = LayoutC;

  // Alignment constraints (in elements). A/B are 4-bit, so 32 elements equal
  // 128 bits; C/D alignments are likewise derived as 128 bits / element width.
  static constexpr int AlignmentA = 32;
  static constexpr int AlignmentB = 32;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  // Architecture definitions
  using ArchTag = cutlass::arch::Sm120;
  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
  using ClusterShape = Shape<_1, _1, _1>;
  using MmaTileShape = Shape<_128, _128, _128>;
  using FusionOperation = cutlass::epilogue::fusion::LinearCombination<
      ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, MmaTileShape, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
          ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
          LayoutD*, AlignmentD,
          cutlass::epilogue::collective::EpilogueScheduleAuto,
          FusionOperation>::CollectiveOp;
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB,
          LayoutB*, AlignmentB, ElementAccumulator, MmaTileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
                                           CollectiveEpilogue>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
  using StrideC = typename Gemm::GemmKernel::InternalStrideC;
  using StrideD = typename Gemm::GemmKernel::InternalStrideD;
  using LayoutSFA =
      typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
  using LayoutSFB =
      typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
  using ScaleConfig =
      typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
  using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;

  int num_experts = static_cast<int>(expert_offsets.size(0));
  // Per-expert pointer / stride / layout arrays, stored as int64 device
  // tensors and reinterpreted below when building the kernel arguments.
  auto options_int =
      torch::TensorOptions().dtype(torch::kInt64).device(a.device());
  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
  // 5 int64 slots per expert to hold the CUTLASS scale-factor layout objects
  // (populated by run_get_group_gemm_starts).
  torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
  torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
  torch::Tensor c_strides1 =
      torch::full({num_experts}, output.stride(0), options_int);
  // NOTE(review): a/b strides are doubled — presumably because two e2m1
  // values are packed per byte in the torch tensor, so the FP4 element
  // stride is twice the packed stride. Confirm against the SM100 path.
  torch::Tensor a_strides1 =
      torch::full({num_experts}, a.stride(0) * 2, options_int);
  torch::Tensor b_strides1 =
      torch::full({num_experts}, b.stride(1) * 2, options_int);
  run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
      a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
      layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
      expert_offsets, sf_offsets, problem_sizes, M, N, K);

  // Create an instance of the GEMM
  Gemm gemm_op;

  // Initialize problem_sizes_as_shapes correctly
  UnderlyingProblemShape* problem_sizes_as_shapes =
      static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());

  // Set the Scheduler info
  cutlass::KernelHardwareInfo hw_info;
  using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
  typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
  scheduler.raster_order = RasterOrderOptions::AlongM;
  hw_info.device_id = a.get_device();
  // Cache the SM count per device to avoid repeated driver queries.
  // NOTE(review): this function-local static map is not synchronized —
  // assumes calls are serialized (single engine thread). Confirm before
  // invoking from multiple threads.
  static std::unordered_map<int, int> cached_sm_counts;
  if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
    cached_sm_counts[hw_info.device_id] =
        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
            hw_info.device_id);
  }
  // min(..., INT_MAX) is effectively a no-op clamp kept from the SM100 code.
  hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);

  // Mainloop Arguments
  typename GemmKernel::MainloopArguments mainloop_args{
      static_cast<const ElementType**>(a_ptrs.data_ptr()),
      static_cast<StrideA*>(a_strides1.data_ptr()),
      static_cast<const ElementType**>(b_ptrs.data_ptr()),
      static_cast<StrideB*>(b_strides1.data_ptr()),
      static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
      reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
      static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
      reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};

  // Epilogue Arguments. The C operand pointer array is null because beta is
  // set to 0 below (output = alpha * accumulator only).
  typename GemmKernel::EpilogueArguments epilogue_args{
      {},  // epilogue.thread
      nullptr,
      static_cast<StrideC*>(c_strides1.data_ptr()),
      static_cast<ElementD**>(out_ptrs.data_ptr()),
      static_cast<StrideC*>(c_strides1.data_ptr())};
  auto& fusion_args = epilogue_args.thread;
  fusion_args.alpha_ptr_array =
      reinterpret_cast<float**>(alpha_ptrs.data_ptr());
  // Zero strides over M/N, stride 1 over the group index: one alpha per
  // expert group.
  fusion_args.dAlpha = {_0{}, _0{}, 1};
  fusion_args.beta = 0.0f;

  // Gemm Arguments
  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGrouped,
      {num_experts, problem_sizes_as_shapes, nullptr},
      mainloop_args,
      epilogue_args,
      hw_info,
      scheduler};

  size_t workspace_size = Gemm::get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());

  auto can_implement_status = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
              "Failed to implement GEMM: status=", (int)can_implement_status);

  // Run the GEMM
  auto status = gemm_op.initialize(args, workspace.data_ptr());
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "Failed to initialize GEMM: status=", (int)status,
              " workspace_size=", workspace_size, " num_experts=", num_experts,
              " M=", M, " N=", N, " K=", K);
  status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
if (version_num >= 120 && version_num < 130) {
run_fp4_blockwise_scaled_group_mm_sm120(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
return;
}
#endif
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
if (version_num >= 100 && version_num < 120) {
run_fp4_blockwise_scaled_group_mm_sm100<OutType>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 100 or 120");
}
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
#endif #endif
@ -374,7 +583,8 @@ void cutlass_fp4_group_mm(
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes, const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) { const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 #if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
// Input validation // Input validation
CHECK_INPUT(a, FLOAT4_E2M1X2, "a"); CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
CHECK_INPUT(b, FLOAT4_E2M1X2, "b"); CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
@ -408,6 +618,14 @@ void cutlass_fp4_group_mm(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K); expert_offsets, sf_offsets, M, N, K);
} else { } else {
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
int32_t version_num = get_sm_version_num();
if (version_num >= 120 && version_num < 130) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
output.scalar_type());
}
#endif
run_fp4_blockwise_scaled_group_mm<cutlass::half_t>( run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K); expert_offsets, sf_offsets, M, N, K);
@ -416,8 +634,8 @@ void cutlass_fp4_group_mm(
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(
false, false,
"No compiled cutlass_fp4_group_mm kernel, vLLM must " "No compiled cutlass_fp4_group_mm kernel, vLLM must "
"be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA " "be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
"12.8 or above."); "and CUDA 12.8 or above.");
#endif #endif
} }

View File

@ -307,7 +307,7 @@ constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int; constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte; constexpr auto UINT8 = at::ScalarType::Byte;
void scaled_fp4_experts_quant_sm100a( void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts, torch::Tensor const& input_offset_by_experts,

View File

@ -24,8 +24,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input_sf); torch::Tensor const& input_sf);
#endif #endif
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
void scaled_fp4_experts_quant_sm100a( (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts, torch::Tensor const& input_offset_by_experts,
@ -54,8 +55,9 @@ void scaled_fp4_experts_quant(
torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts, torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) { torch::Tensor const& output_scale_offset_by_experts) {
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
return scaled_fp4_experts_quant_sm100a( (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts, output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts); output_scale_offset_by_experts);
#endif #endif

View File

@ -67,9 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
std::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
#endif #endif
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \ #if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \ (defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120 (defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
void get_cutlass_moe_mm_data_caller( void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
@ -284,8 +284,9 @@ void get_cutlass_moe_mm_data(
// This function currently gets compiled only if we have a valid cutlass moe // This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for. // mm to run it for.
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation, problem_sizes2, input_permutation,
output_permutation, num_experts, n, k, output_permutation, num_experts, n, k,
@ -296,7 +297,7 @@ void get_cutlass_moe_mm_data(
false, false,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for " "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: ", "CUDA device capability: ",
version_num, ". Required capability: 90 or 100"); version_num, ". Required capability: 90, 100, or 120");
} }
void get_cutlass_moe_mm_problem_sizes( void get_cutlass_moe_mm_problem_sizes(
@ -304,8 +305,9 @@ void get_cutlass_moe_mm_problem_sizes(
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) { const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1, get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
problem_sizes2, num_experts, n, k, problem_sizes2, num_experts, n, k,
blockscale_offsets); blockscale_offsets);
@ -315,7 +317,7 @@ void get_cutlass_moe_mm_problem_sizes(
false, false,
"No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm " "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
"kernel for CUDA device capability: ", "kernel for CUDA device capability: ",
version_num, ". Required capability: 90 or 100"); version_num, ". Required capability: 90, 100, or 120");
} }
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
@ -328,8 +330,9 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
// This function currently gets compiled only if we have a valid cutlass moe // This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for. // mm to run it for.
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1, get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens, problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k); num_local_experts, padded_m, n, k);
@ -339,7 +342,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
false, false,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel " "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: ", "for CUDA device capability: ",
version_num, ". Required capability: 90 or 100"); version_num, ". Required capability: 90, 100, or 120");
} }
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,

View File

@ -60,7 +60,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
- [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod] - [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod]
- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod] - [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod]
- [`CompressedTensorsW4A4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4MoeMethod] - [`CompressedTensorsW4A4Nvfp4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoeMethod]
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod] - [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod]
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod] - [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod]
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod] - [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]

View File

@ -133,7 +133,7 @@ def main(args):
tensor_parallel_size=args.tp, tensor_parallel_size=args.tp,
enable_chunked_prefill=args.enable_chunked_prefill, enable_chunked_prefill=args.enable_chunked_prefill,
enforce_eager=args.enforce_eager, enforce_eager=args.enforce_eager,
gpu_memory_utilization=0.8, gpu_memory_utilization=0.9,
speculative_config=speculative_config, speculative_config=speculative_config,
disable_log_stats=False, disable_log_stats=False,
max_model_len=args.max_model_len, max_model_len=args.max_model_len,

View File

@ -55,7 +55,7 @@ class SillyModel(nn.Module):
def _run_simple_model( def _run_simple_model(
splitting_ops, splitting_ops,
use_inductor_graph_partition, use_inductor_graph_partition,
use_inductor, backend,
expected_num_piecewise_graphs_seen, expected_num_piecewise_graphs_seen,
expected_num_piecewise_capturable_graphs_seen, expected_num_piecewise_capturable_graphs_seen,
expected_num_backend_compilations, expected_num_backend_compilations,
@ -64,7 +64,7 @@ def _run_simple_model(
vllm_config = VllmConfig( vllm_config = VllmConfig(
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, mode=CompilationMode.VLLM_COMPILE,
use_inductor=use_inductor, backend=backend,
splitting_ops=splitting_ops, splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition, use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_copy_inputs=True, cudagraph_copy_inputs=True,
@ -124,14 +124,14 @@ def _run_simple_model(
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0])) assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
@pytest.mark.parametrize("use_inductor", [True, False]) @pytest.mark.parametrize("backend", ["inductor", "eager"])
@torch.inference_mode() @torch.inference_mode()
@create_new_process_for_each_test("spawn") @create_new_process_for_each_test("spawn")
def test_simple_piecewise_compile(use_inductor): def test_simple_piecewise_compile(backend):
_run_simple_model( _run_simple_model(
splitting_ops=["silly::attention"], splitting_ops=["silly::attention"],
use_inductor_graph_partition=False, use_inductor_graph_partition=False,
use_inductor=use_inductor, backend=backend,
# 2 * num_layers + 1 # 2 * num_layers + 1
expected_num_piecewise_graphs_seen=5, expected_num_piecewise_graphs_seen=5,
# 1 + num_layers # 1 + num_layers
@ -155,7 +155,7 @@ def test_simple_inductor_graph_partition(monkeypatch):
_run_simple_model( _run_simple_model(
splitting_ops=["silly::attention"], splitting_ops=["silly::attention"],
use_inductor_graph_partition=True, use_inductor_graph_partition=True,
use_inductor=True, backend="inductor",
# Since not splitting at fx graph level # Since not splitting at fx graph level
expected_num_piecewise_graphs_seen=1, expected_num_piecewise_graphs_seen=1,
# Since not splitting at fx graph level # Since not splitting at fx graph level

View File

@ -249,14 +249,13 @@ def test_compilation_config():
args = parser.parse_args( args = parser.parse_args(
[ [
"-O", "-O",
'{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], "backend": "eager"}',
'"use_inductor": false}',
] ]
) )
assert ( assert (
args.compilation_config.mode == 3 args.compilation_config.mode == 3
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
and not args.compilation_config.use_inductor and args.compilation_config.backend == "eager"
) )
# set to string form of a dict # set to string form of a dict
@ -264,13 +263,13 @@ def test_compilation_config():
[ [
"--compilation-config=" "--compilation-config="
'{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
'"use_inductor": true}', '"backend": "inductor"}',
] ]
) )
assert ( assert (
args.compilation_config.mode == 3 args.compilation_config.mode == 3
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
and args.compilation_config.use_inductor and args.compilation_config.backend == "inductor"
) )
@ -278,8 +277,9 @@ def test_prefix_cache_default():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([]) args = parser.parse_args([])
# should be None by default (depends on model).
engine_args = EngineArgs.from_cli_args(args=args) engine_args = EngineArgs.from_cli_args(args=args)
assert engine_args.enable_prefix_caching, "prefix caching should default to on." assert engine_args.enable_prefix_caching is None
# with flag to turn it on. # with flag to turn it on.
args = parser.parse_args(["--enable-prefix-caching"]) args = parser.parse_args(["--enable-prefix-caching"])

View File

@ -2,6 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.response_reasoning_item import ( from openai.types.responses.response_reasoning_item import (
Content, Content,
ResponseReasoningItem, ResponseReasoningItem,
@ -76,6 +79,18 @@ class TestResponsesUtils:
== 'Hmm, the user has just started with a simple "Hello,"' == 'Hmm, the user has just started with a simple "Hello,"'
) )
tool_call_output = ResponseFunctionToolCallOutputItem(
id="temp_id",
type="function_call_output",
call_id="temp",
output="1234",
status="completed",
)
formatted_item = construct_chat_message_with_tool_call(tool_call_output)
assert formatted_item["role"] == "tool"
assert formatted_item["content"] == "1234"
assert formatted_item["tool_call_id"] == "temp"
item = ResponseReasoningItem( item = ResponseReasoningItem(
id="lol", id="lol",
summary=[], summary=[],

View File

@ -0,0 +1,240 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
from vllm.model_executor.model_loader.weight_utils import download_gguf
class TestGGUFDownload:
"""Test GGUF model downloading functionality."""
@patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
def test_download_gguf_single_file(self, mock_download):
"""Test downloading a single GGUF file."""
# Setup mock
mock_folder = "/tmp/mock_cache"
mock_download.return_value = mock_folder
# Mock glob to return a single file
with patch("glob.glob") as mock_glob:
mock_glob.side_effect = lambda pattern, **kwargs: (
[f"{mock_folder}/model-IQ1_S.gguf"] if "IQ1_S" in pattern else []
)
result = download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")
# Verify download_weights_from_hf was called with correct patterns
mock_download.assert_called_once_with(
model_name_or_path="unsloth/Qwen3-0.6B-GGUF",
cache_dir=None,
allow_patterns=[
"*-IQ1_S.gguf",
"*-IQ1_S-*.gguf",
"*/*-IQ1_S.gguf",
"*/*-IQ1_S-*.gguf",
],
revision=None,
ignore_patterns=None,
)
# Verify result is the file path, not folder
assert result == f"{mock_folder}/model-IQ1_S.gguf"
@patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
def test_download_gguf_sharded_files(self, mock_download):
"""Test downloading sharded GGUF files."""
mock_folder = "/tmp/mock_cache"
mock_download.return_value = mock_folder
# Mock glob to return sharded files
with patch("glob.glob") as mock_glob:
mock_glob.side_effect = lambda pattern, **kwargs: (
[
f"{mock_folder}/model-Q2_K-00001-of-00002.gguf",
f"{mock_folder}/model-Q2_K-00002-of-00002.gguf",
]
if "Q2_K" in pattern
else []
)
result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K")
# Should return the first file after sorting
assert result == f"{mock_folder}/model-Q2_K-00001-of-00002.gguf"
@patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
def test_download_gguf_subdir(self, mock_download):
"""Test downloading GGUF files from subdirectory."""
mock_folder = "/tmp/mock_cache"
mock_download.return_value = mock_folder
with patch("glob.glob") as mock_glob:
mock_glob.side_effect = lambda pattern, **kwargs: (
[f"{mock_folder}/Q2_K/model-Q2_K.gguf"]
if "Q2_K" in pattern or "**/*.gguf" in pattern
else []
)
result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K")
assert result == f"{mock_folder}/Q2_K/model-Q2_K.gguf"
@patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
@patch("glob.glob", return_value=[])
def test_download_gguf_no_files_found(self, mock_glob, mock_download):
"""Test error when no GGUF files are found."""
mock_folder = "/tmp/mock_cache"
mock_download.return_value = mock_folder
with pytest.raises(ValueError, match="Downloaded GGUF files not found"):
download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")
class TestGGUFModelLoader:
"""Test GGUFModelLoader class methods."""
@patch("os.path.isfile", return_value=True)
def test_prepare_weights_local_file(self, mock_isfile):
"""Test _prepare_weights with local file."""
load_config = LoadConfig(load_format="gguf")
loader = GGUFModelLoader(load_config)
# Create a simple mock ModelConfig with only the model attribute
model_config = MagicMock()
model_config.model = "/path/to/model.gguf"
result = loader._prepare_weights(model_config)
assert result == "/path/to/model.gguf"
mock_isfile.assert_called_once_with("/path/to/model.gguf")
@patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download")
@patch("os.path.isfile", return_value=False)
def test_prepare_weights_https_url(self, mock_isfile, mock_hf_download):
"""Test _prepare_weights with HTTPS URL."""
load_config = LoadConfig(load_format="gguf")
loader = GGUFModelLoader(load_config)
mock_hf_download.return_value = "/downloaded/model.gguf"
# Create a simple mock ModelConfig with only the model attribute
model_config = MagicMock()
model_config.model = "https://huggingface.co/model.gguf"
result = loader._prepare_weights(model_config)
assert result == "/downloaded/model.gguf"
mock_hf_download.assert_called_once_with(
url="https://huggingface.co/model.gguf"
)
@patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download")
@patch("os.path.isfile", return_value=False)
def test_prepare_weights_repo_filename(self, mock_isfile, mock_hf_download):
"""Test _prepare_weights with repo_id/filename.gguf format."""
load_config = LoadConfig(load_format="gguf")
loader = GGUFModelLoader(load_config)
mock_hf_download.return_value = "/downloaded/model.gguf"
# Create a simple mock ModelConfig with only the model attribute
model_config = MagicMock()
model_config.model = "unsloth/Qwen3-0.6B-GGUF/model.gguf"
result = loader._prepare_weights(model_config)
assert result == "/downloaded/model.gguf"
mock_hf_download.assert_called_once_with(
repo_id="unsloth/Qwen3-0.6B-GGUF", filename="model.gguf"
)
@patch("vllm.config.model.get_hf_image_processor_config", return_value=None)
@patch("vllm.transformers_utils.config.file_or_path_exists", return_value=True)
@patch("vllm.config.model.get_config")
@patch("vllm.config.model.is_gguf", return_value=True)
@patch("vllm.model_executor.model_loader.gguf_loader.download_gguf")
@patch("os.path.isfile", return_value=False)
def test_prepare_weights_repo_quant_type(
self,
mock_isfile,
mock_download_gguf,
mock_is_gguf,
mock_get_config,
mock_file_exists,
mock_get_image_config,
):
"""Test _prepare_weights with repo_id:quant_type format."""
mock_hf_config = MagicMock()
mock_hf_config.architectures = ["Qwen3ForCausalLM"]
class MockTextConfig:
max_position_embeddings = 4096
sliding_window = None
model_type = "qwen3"
num_attention_heads = 32
mock_text_config = MockTextConfig()
mock_hf_config.get_text_config.return_value = mock_text_config
mock_hf_config.dtype = "bfloat16"
mock_get_config.return_value = mock_hf_config
load_config = LoadConfig(load_format="gguf")
loader = GGUFModelLoader(load_config)
mock_download_gguf.return_value = "/downloaded/model-IQ1_S.gguf"
model_config = ModelConfig(
model="unsloth/Qwen3-0.6B-GGUF:IQ1_S", tokenizer="Qwen/Qwen3-0.6B"
)
result = loader._prepare_weights(model_config)
# The actual result will be the downloaded file path from mock
assert result == "/downloaded/model-IQ1_S.gguf"
mock_download_gguf.assert_called_once_with(
"unsloth/Qwen3-0.6B-GGUF",
"IQ1_S",
cache_dir=None,
revision=None,
ignore_patterns=["original/**/*"],
)
@patch("vllm.config.model.get_hf_image_processor_config", return_value=None)
@patch("vllm.config.model.get_config")
@patch("vllm.config.model.is_gguf", return_value=False)
@patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False)
@patch("os.path.isfile", return_value=False)
def test_prepare_weights_invalid_format(
self,
mock_isfile,
mock_check_gguf,
mock_is_gguf,
mock_get_config,
mock_get_image_config,
):
"""Test _prepare_weights with invalid format."""
mock_hf_config = MagicMock()
mock_hf_config.architectures = ["Qwen3ForCausalLM"]
class MockTextConfig:
max_position_embeddings = 4096
sliding_window = None
model_type = "qwen3"
num_attention_heads = 32
mock_text_config = MockTextConfig()
mock_hf_config.get_text_config.return_value = mock_text_config
mock_hf_config.dtype = "bfloat16"
mock_get_config.return_value = mock_hf_config
load_config = LoadConfig(load_format="gguf")
loader = GGUFModelLoader(load_config)
# Create ModelConfig with a valid repo_id to avoid validation errors
# Then test _prepare_weights with invalid format
model_config = ModelConfig(model="unsloth/Qwen3-0.6B")
# Manually set model to invalid format after creation
model_config.model = "invalid-format"
with pytest.raises(ValueError, match="Unrecognised GGUF reference"):
loader._prepare_weights(model_config)

View File

@ -1,11 +1,17 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
from unittest.mock import patch
import pytest
from vllm.transformers_utils.utils import ( from vllm.transformers_utils.utils import (
is_cloud_storage, is_cloud_storage,
is_gcs, is_gcs,
is_gguf,
is_remote_gguf,
is_s3, is_s3,
split_remote_gguf,
) )
@ -28,3 +34,143 @@ def test_is_cloud_storage():
assert is_cloud_storage("s3://model-path/path-to-model") assert is_cloud_storage("s3://model-path/path-to-model")
assert not is_cloud_storage("/unix/local/path") assert not is_cloud_storage("/unix/local/path")
assert not is_cloud_storage("nfs://nfs-fqdn.local") assert not is_cloud_storage("nfs://nfs-fqdn.local")
class TestIsRemoteGGUF:
"""Test is_remote_gguf utility function."""
def test_is_remote_gguf_with_colon_and_slash(self):
"""Test is_remote_gguf with repo_id:quant_type format."""
# Valid quant types
assert is_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
assert is_remote_gguf("user/repo:Q2_K")
assert is_remote_gguf("repo/model:Q4_K")
assert is_remote_gguf("repo/model:Q8_0")
# Invalid quant types should return False
assert not is_remote_gguf("repo/model:quant")
assert not is_remote_gguf("repo/model:INVALID")
assert not is_remote_gguf("repo/model:invalid_type")
def test_is_remote_gguf_without_colon(self):
"""Test is_remote_gguf without colon."""
assert not is_remote_gguf("repo/model")
assert not is_remote_gguf("unsloth/Qwen3-0.6B-GGUF")
def test_is_remote_gguf_without_slash(self):
"""Test is_remote_gguf without slash."""
assert not is_remote_gguf("model.gguf")
# Even with valid quant_type, no slash means not remote GGUF
assert not is_remote_gguf("model:IQ1_S")
assert not is_remote_gguf("model:quant")
def test_is_remote_gguf_local_path(self):
"""Test is_remote_gguf with local file path."""
assert not is_remote_gguf("/path/to/model.gguf")
assert not is_remote_gguf("./model.gguf")
def test_is_remote_gguf_with_path_object(self):
"""Test is_remote_gguf with Path object."""
assert is_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
assert not is_remote_gguf(Path("repo/model"))
def test_is_remote_gguf_with_http_https(self):
"""Test is_remote_gguf with HTTP/HTTPS URLs."""
# HTTP/HTTPS URLs should return False even with valid quant_type
assert not is_remote_gguf("http://example.com/repo/model:IQ1_S")
assert not is_remote_gguf("https://huggingface.co/repo/model:Q2_K")
assert not is_remote_gguf("http://repo/model:Q4_K")
assert not is_remote_gguf("https://repo/model:Q8_0")
def test_is_remote_gguf_with_cloud_storage(self):
"""Test is_remote_gguf with cloud storage paths."""
# Cloud storage paths should return False even with valid quant_type
assert not is_remote_gguf("s3://bucket/repo/model:IQ1_S")
assert not is_remote_gguf("gs://bucket/repo/model:Q2_K")
assert not is_remote_gguf("s3://repo/model:Q4_K")
assert not is_remote_gguf("gs://repo/model:Q8_0")
class TestSplitRemoteGGUF:
"""Test split_remote_gguf utility function."""
def test_split_remote_gguf_valid(self):
"""Test split_remote_gguf with valid repo_id:quant_type format."""
repo_id, quant_type = split_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
assert repo_id == "unsloth/Qwen3-0.6B-GGUF"
assert quant_type == "IQ1_S"
repo_id, quant_type = split_remote_gguf("repo/model:Q2_K")
assert repo_id == "repo/model"
assert quant_type == "Q2_K"
def test_split_remote_gguf_with_path_object(self):
"""Test split_remote_gguf with Path object."""
repo_id, quant_type = split_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
assert repo_id == "unsloth/Qwen3-0.6B-GGUF"
assert quant_type == "IQ1_S"
def test_split_remote_gguf_invalid(self):
"""Test split_remote_gguf with invalid format."""
# Invalid format (no colon) - is_remote_gguf returns False
with pytest.raises(ValueError, match="Wrong GGUF model"):
split_remote_gguf("repo/model")
# Invalid quant type - is_remote_gguf returns False
with pytest.raises(ValueError, match="Wrong GGUF model"):
split_remote_gguf("repo/model:INVALID_TYPE")
# HTTP URL - is_remote_gguf returns False
with pytest.raises(ValueError, match="Wrong GGUF model"):
split_remote_gguf("http://repo/model:IQ1_S")
# Cloud storage - is_remote_gguf returns False
with pytest.raises(ValueError, match="Wrong GGUF model"):
split_remote_gguf("s3://bucket/repo/model:Q2_K")
class TestIsGGUF:
"""Test is_gguf utility function."""
@patch("vllm.transformers_utils.utils.check_gguf_file", return_value=True)
def test_is_gguf_with_local_file(self, mock_check_gguf):
"""Test is_gguf with local GGUF file."""
assert is_gguf("/path/to/model.gguf")
assert is_gguf("./model.gguf")
def test_is_gguf_with_remote_gguf(self):
"""Test is_gguf with remote GGUF format."""
# Valid remote GGUF format (repo_id:quant_type with valid quant_type)
assert is_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
assert is_gguf("repo/model:Q2_K")
assert is_gguf("repo/model:Q4_K")
# Invalid quant_type should return False
assert not is_gguf("repo/model:quant")
assert not is_gguf("repo/model:INVALID")
@patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False)
def test_is_gguf_false(self, mock_check_gguf):
"""Test is_gguf returns False for non-GGUF models."""
assert not is_gguf("unsloth/Qwen3-0.6B")
assert not is_gguf("repo/model")
assert not is_gguf("model")
def test_is_gguf_edge_cases(self):
"""Test is_gguf with edge cases."""
# Empty string
assert not is_gguf("")
# Only colon, no slash (even with valid quant_type)
assert not is_gguf("model:IQ1_S")
# Only slash, no colon
assert not is_gguf("repo/model")
# HTTP/HTTPS URLs
assert not is_gguf("http://repo/model:IQ1_S")
assert not is_gguf("https://repo/model:Q2_K")
# Cloud storage
assert not is_gguf("s3://bucket/repo/model:IQ1_S")
assert not is_gguf("gs://bucket/repo/model:Q2_K")

View File

@ -166,7 +166,7 @@ def test_dict_args(parser):
"--hf-overrides.key2.key4", "--hf-overrides.key2.key4",
"val3", "val3",
# Test compile config and compilation mode # Test compile config and compilation mode
"-O.use_inductor=true", "-O.use_inductor_graph_partition=true",
"-O.backend", "-O.backend",
"custom", "custom",
"-O1", "-O1",
@ -219,7 +219,7 @@ def test_dict_args(parser):
} }
assert parsed_args.compilation_config == { assert parsed_args.compilation_config == {
"mode": 1, "mode": 1,
"use_inductor": True, "use_inductor_graph_partition": True,
"backend": "custom", "backend": "custom",
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"], "custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
} }

View File

@ -1248,7 +1248,9 @@ def test_allocate_with_lookahead():
) )
# Test case 1: Requires additional lookahead tokens # Test case 1: Requires additional lookahead tokens
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) kv_cache_manager = KVCacheManager(
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
)
blocks = kv_cache_manager.allocate_slots( blocks = kv_cache_manager.allocate_slots(
request, request,
num_new_tokens=3, num_new_tokens=3,
@ -1257,7 +1259,9 @@ def test_allocate_with_lookahead():
assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks
# Test case 2: With precomputed blocks # Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) kv_cache_manager = KVCacheManager(
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
)
# required_blocks = ceil((3 + 2) /4) = 2 # required_blocks = ceil((3 + 2) /4) = 2
blocks = kv_cache_manager.allocate_slots( blocks = kv_cache_manager.allocate_slots(
request, request,
@ -1268,7 +1272,9 @@ def test_allocate_with_lookahead():
# Test case 3: With precomputed blocks # Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2 # required_blocks = ceil((3 + 4) / 4) = 2
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) kv_cache_manager = KVCacheManager(
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
)
blocks = kv_cache_manager.allocate_slots( blocks = kv_cache_manager.allocate_slots(
request, request,
num_new_tokens=3, num_new_tokens=3,
@ -1495,7 +1501,8 @@ def test_get_kv_cache_config_one_worker():
), ),
], ],
) )
# different hidden size
# different hidden size but same type, use UniformTypeKVCacheSpecs
kv_cache_specs_hybrid = { kv_cache_specs_hybrid = {
"layer_1": new_kv_cache_spec(head_size=128), "layer_1": new_kv_cache_spec(head_size=128),
"layer_2": new_kv_cache_spec(head_size=64), "layer_2": new_kv_cache_spec(head_size=64),
@ -1519,6 +1526,40 @@ def test_get_kv_cache_config_one_worker():
], ],
) )
# Different hidden size and different type, align by different block size
kv_cache_specs_hybrid = {
"layer_1": new_kv_cache_spec(head_size=64),
"layer_2": new_sliding_window_spec(head_size=32),
}
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 32]
)[0]
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(
size=mem_per_block_per_layer * 32, shared_by=["layer_1", "layer_2"]
),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1"], new_kv_cache_spec(head_size=64)),
KVCacheGroupSpec(
["layer_2"], new_sliding_window_spec(head_size=32, block_size=32)
),
],
)
# different hidden size that cannot be aligned by using different block size
kv_cache_specs_hybrid = {
"layer_1": new_kv_cache_spec(head_size=64),
"layer_2": new_sliding_window_spec(head_size=96),
}
with pytest.raises(NotImplementedError):
get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32]
)[0]
# Test num_gpu_blocks_override # Test num_gpu_blocks_override
vllm_config.cache_config.num_gpu_blocks_override = 16 vllm_config.cache_config.num_gpu_blocks_override = 16
kv_cache_config_override_blocks = get_kv_cache_configs( kv_cache_config_override_blocks = get_kv_cache_configs(

View File

@ -134,6 +134,7 @@ def test_prefill(hash_fn):
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# Complete 3 blocks (48 tokens) # Complete 3 blocks (48 tokens)
@ -256,6 +257,7 @@ def test_prefill_hybrid_model():
make_kv_cache_config_hybrid_model(block_size, 21), make_kv_cache_config_hybrid_model(block_size, 21),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
hash_fn = sha256 hash_fn = sha256
@ -416,6 +418,7 @@ def test_prefill_plp():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# the default hash function is sha256 # the default hash function is sha256
hash_fn = sha256 hash_fn = sha256
@ -523,6 +526,7 @@ def test_decode():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# Complete 3 blocks (48 tokens) # Complete 3 blocks (48 tokens)
@ -585,6 +589,7 @@ def test_evict():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
last_token_id = 5 * 16 + 7 last_token_id = 5 * 16 + 7
@ -643,6 +648,7 @@ def test_hash_block_correct_reuse():
make_kv_cache_config(16, 2), make_kv_cache_config(16, 2),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# Allocate 1 block and cache it. # Allocate 1 block and cache it.
@ -683,6 +689,7 @@ def test_computed_blocks_not_evicted():
make_kv_cache_config(block_size, 3), make_kv_cache_config(block_size, 3),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# Allocate a block and cache it. # Allocate a block and cache it.
@ -741,6 +748,7 @@ def test_basic_prefix_caching_disabled():
make_kv_cache_config(block_size, 5), make_kv_cache_config(block_size, 5),
max_model_len=8192, max_model_len=8192,
enable_caching=False, enable_caching=False,
hash_block_size=block_size,
) )
req1 = make_request( req1 = make_request(
@ -790,6 +798,7 @@ def test_cache_blocks(hash_fn):
block_pool = BlockPool( block_pool = BlockPool(
num_gpu_blocks=5, num_gpu_blocks=5,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# Req: # Req:
# Block 0: [0, 1, 2, 3] # Block 0: [0, 1, 2, 3]
@ -833,7 +842,9 @@ def test_cache_blocks_multi_group():
This tests that blocks are cached correctly for different kv cache groups. This tests that blocks are cached correctly for different kv cache groups.
""" """
block_size = 4 block_size = 4
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True) block_pool = BlockPool(
num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size
)
# Req: # Req:
# Block 0/4: [0, 1, 2, 3] # Block 0/4: [0, 1, 2, 3]
@ -921,6 +932,7 @@ def test_mm_prefix_caching():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# Common prompt tokens (T is text tokens and P is image placeholder tokens) # Common prompt tokens (T is text tokens and P is image placeholder tokens)
@ -1020,6 +1032,7 @@ def test_cache_key_salting():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# 3 complete blocks and an incomplete block with 11 tokens. # 3 complete blocks and an incomplete block with 11 tokens.
@ -1101,6 +1114,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# Complete 3 blocks (48 tokens) # Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... | # | Common-0 | Common-1 | Common-2 | ... |
@ -1173,6 +1187,7 @@ def test_reset_prefix_cache():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
full_block_token_ids = [i for i in range(3) for _ in range(16)] full_block_token_ids = [i for i in range(3) for _ in range(16)]
@ -1213,6 +1228,7 @@ def test_prefix_cache_stats_disabled():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
log_stats=False, # Disable logging stats log_stats=False, # Disable logging stats
) )
assert manager.prefix_cache_stats is None assert manager.prefix_cache_stats is None
@ -1232,7 +1248,7 @@ def test_prefix_cache_stats_disabled():
def test_maybe_evict_cached_block(): def test_maybe_evict_cached_block():
pool = BlockPool(num_gpu_blocks=4, enable_caching=True) pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16)
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000) block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000) block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000) block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
@ -1293,6 +1309,7 @@ def test_kv_cache_events(blocks_to_cache: int):
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
enable_kv_cache_events=True, enable_kv_cache_events=True,
hash_block_size=block_size,
) )
num_tokens = block_size * blocks_to_cache num_tokens = block_size * blocks_to_cache
@ -1351,6 +1368,7 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
enable_kv_cache_events=True, enable_kv_cache_events=True,
hash_block_size=block_size,
) )
# Test with LoRA request # Test with LoRA request
@ -1405,6 +1423,7 @@ def test_eagle_enabled_removes_last_block():
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
use_eagle=True, use_eagle=True,
hash_block_size=block_size,
) )
# Request with 3 full blocks (48 tokens) # Request with 3 full blocks (48 tokens)
@ -1437,6 +1456,7 @@ def test_eagle_with_partial_blocks():
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
use_eagle=True, use_eagle=True,
hash_block_size=block_size,
) )
# 2 full blocks + 5 tokens (non-divisible length) # 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5) token_ids = [0] * (2 * block_size + 5)
@ -1476,6 +1496,7 @@ def test_eagle_with_sliding_window():
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
use_eagle=True, use_eagle=True,
hash_block_size=block_size,
) )
# 2 full blocks + 5 tokens (non-divisible length) # 2 full blocks + 5 tokens (non-divisible length)
@ -1522,6 +1543,76 @@ def test_eagle_with_sliding_window():
assert num_tokens == 0 assert num_tokens == 0
def test_different_block_size():
block_size = 16
# full attention and sliding window attention layers have the same page size:
# (32 tokens/block * float16 token, vs. 16 tokens/block * float32 token)
kv_cache_config = KVCacheConfig(
num_blocks=100,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(block_size * 2, 1, 1, torch.float16),
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
block_size,
1,
1,
torch.float32,
sliding_window=2 * block_size,
),
),
],
)
manager = KVCacheManager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# 10 blocks of 16 tokens each. Token ids are not strictly aligned for each block.
common_token_ids = [i for i in range(10) for _ in range(block_size)]
req0 = make_request("0", common_token_ids, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert not computed_blocks.blocks[1]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(
req0, 7 * block_size, len(computed_blocks.blocks[0]) * 16, computed_blocks
)
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11])
req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(computed_blocks.blocks[0]) == 3
assert len(computed_blocks.blocks[1]) == 6
assert num_computed_tokens == 6 * 16
req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks.blocks[0]) == 3
assert len(computed_blocks.blocks[1]) == 6
assert num_computed_tokens == 6 * 16
# Evict some blocks to make sliding window cache hit length 5*16
# But should return 4 * 16 because full attention cache hit length must be
# a multiple of 32
manager.block_pool.cached_block_hash_to_block.pop(
make_block_hash_with_group_id(req1.block_hashes[6], 1), 11
)
manager.block_pool.cached_block_hash_to_block.pop(
make_block_hash_with_group_id(req1.block_hashes[5], 1), 10
)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(computed_blocks.blocks[0]) == 2
assert len(computed_blocks.blocks[1]) == 4
assert num_computed_tokens == 4 * 16
def test_block_lookup_cache_single_block_per_key(): def test_block_lookup_cache_single_block_per_key():
cache = BlockHashToBlockMap() cache = BlockHashToBlockMap()
key0 = BlockHashWithGroupId(b"hash0") key0 = BlockHashWithGroupId(b"hash0")

View File

@ -41,7 +41,9 @@ def test_chunked_local_attention_possible_cached_prefix():
attention_chunk_size=4, attention_chunk_size=4,
) )
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_chunked_local_attention_manager( manager = get_chunked_local_attention_manager(
chunked_local_attention_spec, block_pool chunked_local_attention_spec, block_pool
) )
@ -70,6 +72,7 @@ def test_chunked_local_attention_possible_cached_prefix():
block_pool=block_pool, block_pool=block_pool,
kv_cache_spec=chunked_local_attention_spec, kv_cache_spec=chunked_local_attention_spec,
use_eagle=False, use_eagle=False,
alignment_tokens=block_size,
)[0] )[0]
assert len(computed_blocks) == expect_length assert len(computed_blocks) == expect_length
@ -111,7 +114,9 @@ def test_sliding_window_possible_cached_prefix():
sliding_window=4, sliding_window=4,
) )
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_sliding_window_manager(sliding_window_spec, block_pool) manager = get_sliding_window_manager(sliding_window_spec, block_pool)
def run_one_case(block_is_cached, expect_length): def run_one_case(block_is_cached, expect_length):
@ -138,6 +143,7 @@ def test_sliding_window_possible_cached_prefix():
block_pool=block_pool, block_pool=block_pool,
kv_cache_spec=sliding_window_spec, kv_cache_spec=sliding_window_spec,
use_eagle=False, use_eagle=False,
alignment_tokens=block_size,
)[0] )[0]
assert len(computed_blocks) == expect_length assert len(computed_blocks) == expect_length
@ -178,7 +184,7 @@ def test_chunked_local_attention_remove_skipped_blocks():
attention_chunk_size=4, attention_chunk_size=4,
) )
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
manager = get_chunked_local_attention_manager(attention_spec, block_pool) manager = get_chunked_local_attention_manager(attention_spec, block_pool)
@ -239,7 +245,7 @@ def test_sliding_window_remove_skipped_blocks():
sliding_window=4, sliding_window=4,
) )
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
manager = get_sliding_window_manager(sliding_window_spec, block_pool) manager = get_sliding_window_manager(sliding_window_spec, block_pool)
@ -316,7 +322,9 @@ def test_get_num_blocks_to_allocate():
sliding_window=4, # Placeholder value, not related to test result sliding_window=4, # Placeholder value, not related to test result
) )
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_sliding_window_manager(sliding_window_spec, block_pool) manager = get_sliding_window_manager(sliding_window_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [ cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
@ -341,7 +349,9 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
attention_chunk_size=4, # Placeholder value, not related to test result attention_chunk_size=4, # Placeholder value, not related to test result
) )
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_chunked_local_attention_manager(attention_spec, block_pool) manager = get_chunked_local_attention_manager(attention_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [ cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [

View File

@ -87,6 +87,11 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
# Set small draft model len to force doesn't-fit-in-drafter case. # Set small draft model len to force doesn't-fit-in-drafter case.
spec_config_short = spec_config | {"max_model_len": 50} spec_config_short = spec_config | {"max_model_len": 50}
test_sampling_params = [
dict(),
dict(logprobs=2),
]
# test_preemption, executor, async_scheduling, # test_preemption, executor, async_scheduling,
# spec_config, test_prefill_chunking # spec_config, test_prefill_chunking
test_configs = [ test_configs = [
@ -103,7 +108,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(True, "uni", True, spec_config_short, True), (True, "uni", True, spec_config_short, True),
] ]
run_tests(monkeypatch, MTP_MODEL, test_configs, [{}]) run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
@dynamo_config.patch(cache_size_limit=16) @dynamo_config.patch(cache_size_limit=16)

View File

@ -11,6 +11,7 @@ import pprint
import time import time
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy
from functools import partial from functools import partial
from typing import Any from typing import Any
@ -429,7 +430,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self.vllm_backend.compiler_manager.compile( self.vllm_backend.compiler_manager.compile(
submod, submod,
args, args,
self.compilation_config.inductor_compile_config, self.vllm_backend.inductor_config,
self.compilation_config, self.compilation_config,
graph_index=index, graph_index=index,
num_graphs=len(self.compile_submod_names), num_graphs=len(self.compile_submod_names),
@ -531,6 +532,9 @@ class VllmBackend:
sym_tensor_indices: list[int] sym_tensor_indices: list[int]
input_buffers: list[torch.Tensor] input_buffers: list[torch.Tensor]
compiler_manager: CompilerManager compiler_manager: CompilerManager
# Copy of CompilationConfig.inductor_compile_config +
# an entry for PostGradPassManager
inductor_config: dict[str, Any]
def __init__( def __init__(
self, self,
@ -561,25 +565,30 @@ class VllmBackend:
self.compilation_config self.compilation_config
) )
# Deepcopy the inductor config to detach the post-grad custom pass
# from CompilationConfig.
# We want to avoid PostGradPassManager in CompilationConfig because
# in future we need PostGradPassManager.uuid() to be executed
# only at compile time.
self.inductor_config = deepcopy(self.compilation_config.inductor_compile_config)
# `torch.compile` is JIT compiled, so we don't need to # `torch.compile` is JIT compiled, so we don't need to
# do anything here # do anything here
def configure_post_pass(self): def configure_post_pass(self):
config = self.compilation_config
self.pass_manager.configure(self.vllm_config) self.pass_manager.configure(self.vllm_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass # Post-grad custom passes are run using the post_grad_custom_post_pass
# hook. If a pass for that hook exists, add it to the pass manager. # hook. If a pass for that hook exists, add it to the pass manager.
inductor_config = config.inductor_compile_config if self.pass_key in self.inductor_config:
if self.pass_key in inductor_config: if isinstance(self.inductor_config[self.pass_key], PostGradPassManager):
if isinstance(inductor_config[self.pass_key], PostGradPassManager): raise ValueError(
# PassManager already added to config, make sure it's correct "PostGradPassManager can not be kept in CompilationConfig."
assert inductor_config[self.pass_key].uuid() == self.pass_manager.uuid() )
else: else:
# Config should automatically wrap all inductor passes # Config should automatically wrap all inductor passes
assert isinstance(inductor_config[self.pass_key], InductorPass) assert isinstance(self.inductor_config[self.pass_key], InductorPass)
self.pass_manager.add(inductor_config[self.pass_key]) self.pass_manager.add(self.inductor_config[self.pass_key])
inductor_config[self.pass_key] = self.pass_manager self.inductor_config[self.pass_key] = self.pass_manager
def __call__( def __call__(
self, graph: fx.GraphModule, example_inputs self, graph: fx.GraphModule, example_inputs
@ -638,9 +647,7 @@ class VllmBackend:
self.compilation_config.local_cache_dir = local_cache_dir self.compilation_config.local_cache_dir = local_cache_dir
# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE. # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
disable_cache = not is_compile_cache_enabled( disable_cache = not is_compile_cache_enabled(self.inductor_config)
self.compilation_config.inductor_compile_config
)
if disable_cache: if disable_cache:
logger.info_once("vLLM's torch.compile cache is disabled.", scope="local") logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import inspect import inspect
import os import os
import pickle import pickle
@ -14,6 +13,7 @@ import vllm.envs as envs
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.utils import hash_factors from vllm.config.utils import hash_factors
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
try: try:
from torch._dynamo.aot_compile import SerializableCallable from torch._dynamo.aot_compile import SerializableCallable
@ -160,7 +160,7 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
# e.g. exec(). We can't actually check these. # e.g. exec(). We can't actually check these.
continue continue
hash_content.append(content) hash_content.append(content)
return hashlib.md5( return safe_hash(
"\n".join(hash_content).encode(), usedforsecurity=False "\n".join(hash_content).encode(), usedforsecurity=False
).hexdigest() ).hexdigest()

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib import contextlib
import copy import copy
import hashlib
import os import os
from collections.abc import Callable from collections.abc import Callable
from contextlib import ExitStack from contextlib import ExitStack
@ -16,6 +15,7 @@ import torch.fx as fx
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils.hashing import safe_hash
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
@ -197,9 +197,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
def compute_hash(self, vllm_config: VllmConfig) -> str: def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors() factors = get_inductor_factors()
hash_str = hashlib.md5( hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
str(factors).encode(), usedforsecurity=False :10
).hexdigest()[:10] ]
return hash_str return hash_str
def initialize_cache( def initialize_cache(
@ -286,9 +286,9 @@ class InductorAdaptor(CompilerInterface):
def compute_hash(self, vllm_config: VllmConfig) -> str: def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors() factors = get_inductor_factors()
hash_str = hashlib.md5( hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
str(factors).encode(), usedforsecurity=False :10
).hexdigest()[:10] ]
return hash_str return hash_str
def initialize_cache( def initialize_cache(

View File

@ -107,7 +107,7 @@ class PiecewiseBackend:
entry.runnable = self.vllm_backend.compiler_manager.compile( entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph, self.graph,
args, args,
self.compilation_config.inductor_compile_config, self.vllm_backend.inductor_config,
self.compilation_config, self.compilation_config,
graph_index=self.piecewise_compile_index, graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles, num_graphs=self.total_piecewise_compiles,

View File

@ -144,7 +144,7 @@ class CacheConfig:
kv_offloading_backend: KVOffloadingBackend | None = None kv_offloading_backend: KVOffloadingBackend | None = None
"""The backend to use for KV cache offloading. Supported backends include """The backend to use for KV cache offloading. Supported backends include
'native' (vLLM native CPU offloading), 'lmcache' This option must be used 'native' (vLLM native CPU offloading), 'lmcache' This option must be used
together with kv_offloading_size.""" together with kv_offloading_size."""
def compute_hash(self) -> str: def compute_hash(self) -> str:
@ -167,8 +167,6 @@ class CacheConfig:
"num_gpu_blocks_override", "num_gpu_blocks_override",
"enable_prefix_caching", "enable_prefix_caching",
"prefix_caching_hash_algo", "prefix_caching_hash_algo",
# `cpu_offload_gb` does not use `torch.compile` yet.
"cpu_offload_gb",
"cpu_kvcache_space_bytes", "cpu_kvcache_space_bytes",
"mamba_page_size_padded", "mamba_page_size_padded",
# Post-init/derived counters # Post-init/derived counters

View File

@ -264,7 +264,6 @@ class CompilationConfig:
- [`cudagraph_copy_inputs`] - [`cudagraph_copy_inputs`]
[vllm.config.CompilationConfig.cudagraph_copy_inputs] [vllm.config.CompilationConfig.cudagraph_copy_inputs]
- Inductor compilation: - Inductor compilation:
- [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
- [`inductor_compile_config`] - [`inductor_compile_config`]
[vllm.config.CompilationConfig.inductor_compile_config] [vllm.config.CompilationConfig.inductor_compile_config]
@ -348,7 +347,7 @@ class CompilationConfig:
- 'none,+op1,+op2' to enable only op1 and op2 - 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor and By default, all custom ops are enabled when running without Inductor and
disabled when running with Inductor: mode>=VLLM_COMPILE and use_inductor=True. disabled when running with Inductor: mode>=VLLM_COMPILE and backend="inductor".
Inductor generates (fused) Triton kernels for disabled custom ops.""" Inductor generates (fused) Triton kernels for disabled custom ops."""
splitting_ops: list[str] | None = None splitting_ops: list[str] | None = None
"""A list of ops to exclude from cudagraphs, used in piecewise compilation. """A list of ops to exclude from cudagraphs, used in piecewise compilation.
@ -374,24 +373,6 @@ class CompilationConfig:
Disabled by default until more models are supported/tested to work.""" Disabled by default until more models are supported/tested to work."""
# Inductor capture # Inductor capture
use_inductor: bool | None = None
"""
Whether to use inductor compilation.
This flag is deprecated and will be removed in the next release 0.12.0.
Please use the 'backend' option instead.
- False: inductor compilation is not used. graph runs in eager
(custom_ops enabled by default).
- True: inductor compilation is used (custom_ops disabled by default).
One graph for symbolic shape and one graph per size in compile_sizes
are compiled using configurations in inductor_compile_config.
This setting is ignored if mode<VLLM_COMPILE.
For future compatibility:
If use_inductor is True, backend="inductor" otherwise backend="eager".
"""
compile_sizes: list[int | str] | None = None compile_sizes: list[int | str] | None = None
"""Sizes to compile for inductor. In addition """Sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to to integers, it also supports "cudagraph_capture_sizes" to
@ -759,14 +740,6 @@ class CompilationConfig:
f"Invalid backend for piecewise compilation: {self.backend}" f"Invalid backend for piecewise compilation: {self.backend}"
) )
if self.use_inductor is not None:
logger.warning_once(
"The 'use_inductor' flag is deprecated and will be "
"removed in the next release (v0.12.0). "
"Please use the 'backend' option instead.",
)
self.backend = "inductor" if self.use_inductor else "eager"
if self.backend == "": if self.backend == "":
self.backend = current_platform.get_compile_backend() self.backend = current_platform.get_compile_backend()

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from dataclasses import field from dataclasses import field
from typing import Any, Literal from typing import Any, Literal
@ -10,6 +9,7 @@ from pydantic import ConfigDict, SkipValidation
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"] Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
@ -45,7 +45,7 @@ class DeviceConfig:
# the device/platform information will be summarized # the device/platform information will be summarized
# by torch/vllm automatically. # by torch/vllm automatically.
factors: list[Any] = [] factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str
def __post_init__(self): def __post_init__(self):

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import uuid import uuid
from dataclasses import field from dataclasses import field
from typing import Any, Literal, get_args from typing import Any, Literal, get_args
@ -9,6 +8,7 @@ from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
KVProducer = Literal["kv_producer", "kv_both"] KVProducer = Literal["kv_producer", "kv_both"]
KVConsumer = Literal["kv_consumer", "kv_both"] KVConsumer = Literal["kv_consumer", "kv_both"]
@ -79,7 +79,7 @@ class KVTransferConfig:
# no factors to consider. # no factors to consider.
# this config will not affect the computation graph. # this config will not affect the computation graph.
factors: list[Any] = [] factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str
def __post_init__(self) -> None: def __post_init__(self) -> None:

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from pydantic import Field, field_validator from pydantic import Field, field_validator
@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.model_loader import LoadFormats from vllm.model_executor.model_loader import LoadFormats
@ -104,7 +104,7 @@ class LoadConfig:
# no factors to consider. # no factors to consider.
# this config will not affect the computation graph. # this config will not affect the computation graph.
factors: list[Any] = [] factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str
@field_validator("load_format", mode="after") @field_validator("load_format", mode="after")

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
import torch import torch
@ -11,6 +10,7 @@ from typing_extensions import Self
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
@ -74,7 +74,7 @@ class LoRAConfig:
factors.append(self.fully_sharded_loras) factors.append(self.fully_sharded_loras)
factors.append(self.lora_dtype) factors.append(self.lora_dtype)
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str
@model_validator(mode="after") @model_validator(mode="after")

View File

@ -39,7 +39,12 @@ from vllm.transformers_utils.gguf_utils import (
maybe_patch_hf_config_from_gguf, maybe_patch_hf_config_from_gguf,
) )
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import check_gguf_file, maybe_model_redirect from vllm.transformers_utils.utils import (
is_gguf,
is_remote_gguf,
maybe_model_redirect,
split_remote_gguf,
)
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype from vllm.utils.torch_utils import common_broadcastable_dtype
@ -294,9 +299,6 @@ class ModelConfig:
pooler_config: PoolerConfig | None = None pooler_config: PoolerConfig | None = None
"""Pooler config which controls the behaviour of output pooling in pooling """Pooler config which controls the behaviour of output pooling in pooling
models.""" models."""
override_pooler_config: dict | PoolerConfig | None = None
"""[DEPRECATED] Use `pooler_config` instead. This field will be removed in
v0.12.0 or v1.0.0, whichever is sooner."""
# Multimodal config and init vars # Multimodal config and init vars
multimodal_config: MultiModalConfig | None = None multimodal_config: MultiModalConfig | None = None
@ -343,7 +345,6 @@ class ModelConfig:
"logprobs_mode", "logprobs_mode",
"disable_cascade_attn", "disable_cascade_attn",
"skip_tokenizer_init", "skip_tokenizer_init",
"enable_prompt_embeds",
"served_model_name", "served_model_name",
"config_format", "config_format",
"hf_token", "hf_token",
@ -354,7 +355,6 @@ class ModelConfig:
"logits_processors", "logits_processors",
"io_processor_plugin", "io_processor_plugin",
"pooler_config", "pooler_config",
"override_pooler_config",
"multimodal_config", "multimodal_config",
"limit_mm_per_prompt", "limit_mm_per_prompt",
"media_io_kwargs", "media_io_kwargs",
@ -440,7 +440,8 @@ class ModelConfig:
self.model = maybe_model_redirect(self.model) self.model = maybe_model_redirect(self.model)
# The tokenizer is consistent with the model by default. # The tokenizer is consistent with the model by default.
if self.tokenizer is None: if self.tokenizer is None:
if check_gguf_file(self.model): # Check if this is a GGUF model (either local file or remote GGUF)
if is_gguf(self.model):
raise ValueError( raise ValueError(
"Using a tokenizer is mandatory when loading a GGUF model. " "Using a tokenizer is mandatory when loading a GGUF model. "
"Please specify the tokenizer path or name using the " "Please specify the tokenizer path or name using the "
@ -642,18 +643,6 @@ class ModelConfig:
# Init pooler config if needed # Init pooler config if needed
if self.runner_type == "pooling": if self.runner_type == "pooling":
if self.override_pooler_config is not None:
logger.warning_once(
"`override_pooler_config` is deprecated and will be "
"removed in v0.12.0 or v1.0.0, whichever is sooner. "
"Please use `pooler_config` instead."
)
if isinstance(self.override_pooler_config, dict):
self.pooler_config = PoolerConfig(**self.override_pooler_config)
else:
self.pooler_config = self.override_pooler_config
if self.pooler_config is None: if self.pooler_config is None:
self.pooler_config = PoolerConfig() self.pooler_config = PoolerConfig()
@ -832,7 +821,10 @@ class ModelConfig:
self.tokenizer = object_storage_tokenizer.dir self.tokenizer = object_storage_tokenizer.dir
def _get_encoder_config(self): def _get_encoder_config(self):
return get_sentence_transformer_tokenizer_config(self.model, self.revision) model = self.model
if is_remote_gguf(model):
model, _ = split_remote_gguf(model)
return get_sentence_transformer_tokenizer_config(model, self.revision)
def _verify_tokenizer_mode(self) -> None: def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower()) tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from collections.abc import Mapping from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Literal, TypeAlias from typing import TYPE_CHECKING, Any, Literal, TypeAlias
@ -9,6 +8,7 @@ from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
@ -216,7 +216,7 @@ class MultiModalConfig:
if self.mm_encoder_attn_backend is not None if self.mm_encoder_attn_backend is not None
else None else None
] ]
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str
def get_limit_per_prompt(self, modality: str) -> int: def get_limit_per_prompt(self, modality: str) -> int:

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from functools import cached_property from functools import cached_property
from typing import Any, Literal, cast from typing import Any, Literal, cast
@ -11,6 +10,7 @@ from pydantic.dataclasses import dataclass
from vllm import version from vllm import version
from vllm.config.utils import config from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
DetailedTraceModules = Literal["model", "worker", "all"] DetailedTraceModules = Literal["model", "worker", "all"]
@ -78,7 +78,7 @@ class ObservabilityConfig:
# no factors to consider. # no factors to consider.
# this config will not affect the computation graph. # this config will not affect the computation graph.
factors: list[Any] = [] factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str
@field_validator("show_hidden_metrics_for_version") @field_validator("show_hidden_metrics_for_version")

View File

@ -593,9 +593,10 @@ class ParallelConfig:
"max_parallel_loading_workers is currently " "max_parallel_loading_workers is currently "
"not supported and will be ignored." "not supported and will be ignored."
) )
if self.distributed_executor_backend != "mp" and self.nnodes > 1: if self.distributed_executor_backend not in ("mp", "uni") and self.nnodes > 1:
raise ValueError( raise ValueError(
"nnodes > 1 can only be set when distributed exectuor backend is mp." "nnodes > 1 can only be set when distributed executor "
"backend is mp or uni."
) )
@property @property

View File

@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import Any from typing import Any
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
logger = init_logger(__name__) logger = init_logger(__name__)
@ -102,7 +102,7 @@ class PoolerConfig:
# no factors to consider. # no factors to consider.
# this config will not affect the computation graph. # this config will not affect the computation graph.
factors: list[Any] = [] factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from collections.abc import Callable from collections.abc import Callable
from dataclasses import InitVar from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
@ -12,6 +11,7 @@ from typing_extensions import Self, deprecated
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING: if TYPE_CHECKING:
@ -178,7 +178,7 @@ class SchedulerConfig:
# no factors to consider. # no factors to consider.
# this config will not affect the computation graph. # this config will not affect the computation graph.
factors: list[Any] = [] factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str
@field_validator("scheduler_cls", "async_scheduling", mode="wrap") @field_validator("scheduler_cls", "async_scheduling", mode="wrap")

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast import ast
import hashlib
from typing import TYPE_CHECKING, Any, Literal, get_args from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator from pydantic import Field, SkipValidation, model_validator
@ -13,6 +12,7 @@ from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import LazyLoader, has_arctic_inference from vllm.utils.import_utils import LazyLoader, has_arctic_inference
if TYPE_CHECKING: if TYPE_CHECKING:
@ -162,7 +162,7 @@ class SpeculativeConfig:
# Eagle3 affects the computation graph because it returns intermediate # Eagle3 affects the computation graph because it returns intermediate
# hidden states in addition to the final hidden state. # hidden states in addition to the final hidden state.
factors.append(self.method == "eagle3") factors.append(self.method == "eagle3")
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str
@staticmethod @staticmethod

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import Any, Literal from typing import Any, Literal
from pydantic import model_validator from pydantic import model_validator
@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass
from typing_extensions import Self from typing_extensions import Self
from vllm.config.utils import config from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
StructuredOutputsBackend = Literal[ StructuredOutputsBackend = Literal[
"auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer" "auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer"
@ -58,7 +58,7 @@ class StructuredOutputsConfig:
# no factors to consider. # no factors to consider.
# this config will not affect the computation graph. # this config will not affect the computation graph.
factors: list[Any] = [] factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str
@model_validator(mode="after") @model_validator(mode="after")

View File

@ -3,7 +3,6 @@
import copy import copy
import getpass import getpass
import hashlib
import json import json
import os import os
import tempfile import tempfile
@ -25,6 +24,7 @@ from vllm.config.speculative import EagleModelTypes
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.hashing import safe_hash
from .cache import CacheConfig from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
@ -193,7 +193,7 @@ class VllmConfig:
vllm_factors.append("None") vllm_factors.append("None")
if self.additional_config: if self.additional_config:
if isinstance(additional_config := self.additional_config, dict): if isinstance(additional_config := self.additional_config, dict):
additional_config_hash = hashlib.md5( additional_config_hash = safe_hash(
json.dumps(additional_config, sort_keys=True).encode(), json.dumps(additional_config, sort_keys=True).encode(),
usedforsecurity=False, usedforsecurity=False,
).hexdigest() ).hexdigest()
@ -204,9 +204,9 @@ class VllmConfig:
vllm_factors.append("None") vllm_factors.append("None")
factors.append(vllm_factors) factors.append(vllm_factors)
hash_str = hashlib.md5( hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
str(factors).encode(), usedforsecurity=False :10
).hexdigest()[:10] ]
return hash_str return hash_str
def pad_for_cudagraph(self, batch_size: int) -> int: def pad_for_cudagraph(self, batch_size: int) -> int:

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
@ -15,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorRole, KVConnectorRole,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
@ -423,7 +423,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
if mm_hashes: if mm_hashes:
mm_str = "-".join(mm_hashes) mm_str = "-".join(mm_hashes)
token_bytes += mm_str.encode("utf-8") token_bytes += mm_str.encode("utf-8")
input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest() input_ids_hash = safe_hash(token_bytes, usedforsecurity=False).hexdigest()
foldername = os.path.join(self._storage_path, input_ids_hash) foldername = os.path.join(self._storage_path, input_ids_hash)
if create_folder: if create_folder:

View File

@ -51,6 +51,7 @@ from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.network_utils import get_distributed_init_method from vllm.utils.network_utils import get_distributed_init_method
from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
direct_register_custom_op, direct_register_custom_op,
supports_custom_op, supports_custom_op,
@ -329,7 +330,8 @@ class GroupCoordinator:
) )
# a group with `gloo` backend, to allow direct coordination between # a group with `gloo` backend, to allow direct coordination between
# processes through the CPU. # processes through the CPU.
cpu_group = torch.distributed.new_group(ranks, backend="gloo") with suppress_stdout():
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks: if self.rank in ranks:
self.ranks = ranks self.ranks = ranks
self.world_size = len(ranks) self.world_size = len(ranks)

View File

@ -30,6 +30,7 @@ from torch.distributed.rendezvous import rendezvous
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.network_utils import get_tcp_uri from vllm.utils.network_utils import get_tcp_uri
from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__) logger = init_logger(__name__)
@ -427,33 +428,34 @@ def init_gloo_process_group(
Stateless init ProcessGroup with gloo backend compatible with Stateless init ProcessGroup with gloo backend compatible with
different torch versions. different torch versions.
""" """
if is_torch_equal_or_newer("2.6"): with suppress_stdout():
pg = ProcessGroup( if is_torch_equal_or_newer("2.6"):
prefix_store, pg = ProcessGroup(
group_rank, prefix_store,
group_size, group_rank,
) group_size,
else: )
options = ProcessGroup.Options(backend="gloo") else:
pg = ProcessGroup( options = ProcessGroup.Options(backend="gloo")
prefix_store, pg = ProcessGroup(
group_rank, prefix_store,
group_size, group_rank,
options, group_size,
) options,
from torch.distributed.distributed_c10d import ProcessGroupGloo )
from torch.distributed.distributed_c10d import ProcessGroupGloo
backend_class = ProcessGroupGloo( backend_class = ProcessGroupGloo(
prefix_store, group_rank, group_size, timeout=timeout prefix_store, group_rank, group_size, timeout=timeout
) )
backend_type = ProcessGroup.BackendType.GLOO backend_type = ProcessGroup.BackendType.GLOO
device = torch.device("cpu") device = torch.device("cpu")
if is_torch_equal_or_newer("2.6"): if is_torch_equal_or_newer("2.6"):
# _set_default_backend is supported in torch >= 2.6 # _set_default_backend is supported in torch >= 2.6
pg._set_default_backend(backend_type) pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group() backend_class._set_sequence_number_for_group()
pg._register_backend(device, backend_type, backend_class) pg._register_backend(device, backend_type, backend_class)
return pg return pg

View File

@ -29,7 +29,7 @@ import regex as re
import torch import torch
from pydantic import TypeAdapter, ValidationError from pydantic import TypeAdapter, ValidationError
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from typing_extensions import TypeIs, deprecated from typing_extensions import TypeIs
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
@ -86,7 +86,7 @@ from vllm.transformers_utils.config import (
is_interleaved, is_interleaved,
maybe_override_with_speculators, maybe_override_with_speculators,
) )
from vllm.transformers_utils.utils import check_gguf_file, is_cloud_storage from vllm.transformers_utils.utils import is_cloud_storage, is_gguf
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip from vllm.utils.network_utils import get_ip
@ -520,9 +520,6 @@ class EngineArgs:
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
pooler_config: PoolerConfig | None = ModelConfig.pooler_config pooler_config: PoolerConfig | None = ModelConfig.pooler_config
override_pooler_config: dict | PoolerConfig | None = (
ModelConfig.override_pooler_config
)
compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config") compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
worker_cls: str = ParallelConfig.worker_cls worker_cls: str = ParallelConfig.worker_cls
worker_extension_cls: str = ParallelConfig.worker_extension_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls
@ -659,11 +656,6 @@ class EngineArgs:
) )
model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"]) model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"]) model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
model_group.add_argument(
"--override-pooler-config",
**model_kwargs["override_pooler_config"],
deprecated=True,
)
model_group.add_argument( model_group.add_argument(
"--logits-processor-pattern", **model_kwargs["logits_processor_pattern"] "--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
) )
@ -880,7 +872,11 @@ class EngineArgs:
"--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"] "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
) )
cache_group.add_argument( cache_group.add_argument(
"--enable-prefix-caching", **cache_kwargs["enable_prefix_caching"] "--enable-prefix-caching",
**{
**cache_kwargs["enable_prefix_caching"],
"default": None,
},
) )
cache_group.add_argument( cache_group.add_argument(
"--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"] "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
@ -1144,8 +1140,8 @@ class EngineArgs:
return engine_args return engine_args
def create_model_config(self) -> ModelConfig: def create_model_config(self) -> ModelConfig:
# gguf file needs a specific model loader and doesn't use hf_repo # gguf file needs a specific model loader
if check_gguf_file(self.model): if is_gguf(self.model):
self.quantization = self.load_format = "gguf" self.quantization = self.load_format = "gguf"
# NOTE(woosuk): In V1, we use separate processes for workers (unless # NOTE(woosuk): In V1, we use separate processes for workers (unless
@ -1239,7 +1235,6 @@ class EngineArgs:
mm_encoder_tp_mode=self.mm_encoder_tp_mode, mm_encoder_tp_mode=self.mm_encoder_tp_mode,
mm_encoder_attn_backend=self.mm_encoder_attn_backend, mm_encoder_attn_backend=self.mm_encoder_attn_backend,
pooler_config=self.pooler_config, pooler_config=self.pooler_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern, logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config, generation_config=self.generation_config,
override_generation_config=self.override_generation_config, override_generation_config=self.override_generation_config,
@ -1812,9 +1807,11 @@ class EngineArgs:
if model_config.runner_type != "pooling": if model_config.runner_type != "pooling":
default_chunked_prefill = True default_chunked_prefill = True
# Disable prefix caching default for hybrid models # Disable prefix caching default for hybrid models and mamba-only
# since the feature is still experimental. # models since the feature is still experimental.
default_prefix_caching = not model_config.is_hybrid default_prefix_caching = not (
model_config.is_hybrid or model_config.is_attention_free
)
else: else:
assert model_config.pooler_config is not None assert model_config.pooler_config is not None
@ -2041,24 +2038,6 @@ class AsyncEngineArgs(EngineArgs):
enable_log_requests: bool = False enable_log_requests: bool = False
@property
@deprecated(
"`disable_log_requests` is deprecated and has been replaced with "
"`enable_log_requests`. This will be removed in v0.12.0. Please use "
"`enable_log_requests` instead."
)
def disable_log_requests(self) -> bool:
return not self.enable_log_requests
@disable_log_requests.setter
@deprecated(
"`disable_log_requests` is deprecated and has been replaced with "
"`enable_log_requests`. This will be removed in v0.12.0. Please use "
"`enable_log_requests` instead."
)
def disable_log_requests(self, value: bool):
self.enable_log_requests = not value
@staticmethod @staticmethod
def add_cli_args( def add_cli_args(
parser: FlexibleArgumentParser, async_args_only: bool = False parser: FlexibleArgumentParser, async_args_only: bool = False

View File

@ -174,9 +174,6 @@ class LLM:
For example, for Phi-3-Vision: `{"num_crops": 4}`. For example, for Phi-3-Vision: `{"num_crops": 4}`.
pooler_config: Initialize non-default pooling config for the pooling pooler_config: Initialize non-default pooling config for the pooling
model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`. model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
argument is deprecated and will be removed in v0.12.0 or v1.0.0,
whichever is sooner.
compilation_config: Either an integer or a dictionary. If it is an compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the mode of compilation optimization. If it integer, it is used as the mode of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration. is a dictionary, it can specify the full compilation configuration.
@ -214,7 +211,6 @@ class LLM:
hf_overrides: HfOverrides | None = None, hf_overrides: HfOverrides | None = None,
mm_processor_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None,
pooler_config: PoolerConfig | None = None, pooler_config: PoolerConfig | None = None,
override_pooler_config: PoolerConfig | None = None,
structured_outputs_config: dict[str, Any] structured_outputs_config: dict[str, Any]
| StructuredOutputsConfig | StructuredOutputsConfig
| None = None, | None = None,
@ -330,7 +326,6 @@ class LLM:
hf_overrides=hf_overrides, hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
pooler_config=pooler_config, pooler_config=pooler_config,
override_pooler_config=override_pooler_config,
structured_outputs_config=structured_outputs_instance, structured_outputs_config=structured_outputs_instance,
compilation_config=compilation_config_instance, compilation_config=compilation_config_instance,
logits_processors=logits_processors, logits_processors=logits_processors,

View File

@ -29,7 +29,6 @@ from openai.types.responses import (
ResponseOutputItemAddedEvent, ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent, ResponseOutputItemDoneEvent,
ResponsePrompt, ResponsePrompt,
ResponseReasoningItem,
ResponseReasoningTextDeltaEvent, ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent, ResponseReasoningTextDoneEvent,
ResponseStatus, ResponseStatus,
@ -304,9 +303,7 @@ def get_logits_processors(
return None return None
ResponseInputOutputItem: TypeAlias = ( ResponseInputOutputItem: TypeAlias = ResponseInputItemParam | ResponseOutputItem
ResponseInputItemParam | ResponseReasoningItem | ResponseFunctionToolCall
)
class ResponsesRequest(OpenAIBaseModel): class ResponsesRequest(OpenAIBaseModel):

View File

@ -10,6 +10,9 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
Function as FunctionCallTool, Function as FunctionCallTool,
) )
from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.response_output_message import ResponseOutputMessage from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_reasoning_item import ResponseReasoningItem from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool from openai.types.responses.tool import Tool
@ -94,6 +97,12 @@ def construct_chat_message_with_tool_call(
"role": "assistant", "role": "assistant",
"reasoning": reasoning_content, "reasoning": reasoning_content,
} }
elif isinstance(item, ResponseFunctionToolCallOutputItem):
return ChatCompletionToolMessageParam(
role="tool",
content=item.output,
tool_call_id=item.call_id,
)
elif item.get("type") == "function_call_output": elif item.get("type") == "function_call_output":
# Append the function call output as a tool message. # Append the function call output as a tool message.
return ChatCompletionToolMessageParam( return ChatCompletionToolMessageParam(

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@ -12,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP, TopKWeightAndReduceNoOP,
) )
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_triton_kernels from vllm.utils.import_utils import has_triton_kernels
@ -88,14 +90,17 @@ def triton_kernel_moe_forward(
gating_output, topk, sm_first=not renormalize gating_output, topk, sm_first=not renormalize
) )
output = torch.empty_like(hidden_states)
return triton_kernel_fused_experts( return triton_kernel_fused_experts(
None, output,
hidden_states, hidden_states,
w1, w1,
w2, w2,
routing_data, routing_data,
gather_idx, gather_idx,
scatter_idx, scatter_idx,
topk=topk,
activation=activation, activation=activation,
quant_config=quant_config, quant_config=quant_config,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
@ -113,6 +118,7 @@ def triton_kernel_fused_experts(
routing_data, # RoutingData routing_data, # RoutingData
gather_indx, # GatherIndx gather_indx, # GatherIndx
scatter_indx, # ScatterIndx scatter_indx, # ScatterIndx
topk: int,
activation: str = "silu", activation: str = "silu",
quant_config: FusedMoEQuantConfig | None = None, quant_config: FusedMoEQuantConfig | None = None,
swiglu_alpha: float = 1.702, swiglu_alpha: float = 1.702,
@ -120,6 +126,7 @@ def triton_kernel_fused_experts(
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
intermediate_cache: torch.Tensor | None = None,
a1q_scale: torch.Tensor | None = None, a1q_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
if quant_config is None: if quant_config is None:
@ -131,14 +138,30 @@ def triton_kernel_fused_experts(
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32 assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
# Shape check, only check non-mxfp4 # Shape check, only check non-mxfp4
assert hidden_states.ndim == 2
assert hidden_states.shape[-1] == w1.shape[-2] assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1] assert w2.shape[-1] == w1.shape[1]
batch_dim = 1
M, K = hidden_states.shape[-2:]
E, _, N = w1.shape E, _, N = w1.shape
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
if intermediate_cache is None:
intermediate_cache = torch.empty(
(batch_dim, M * topk, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
# Add batch_dim to output buffer because matmul_ogs expects 3D output
intermediate_cache = _resize_cache(
intermediate_cache, (batch_dim, M * topk, N // 2)
)
output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
act = FusedActivation( act = FusedActivation(
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
(swiglu_alpha, swiglu_limit), (swiglu_alpha, swiglu_limit),
@ -146,7 +169,7 @@ def triton_kernel_fused_experts(
) )
gammas = routing_data.gate_scal if routing_data else None gammas = routing_data.gate_scal if routing_data else None
intermediate_cache1 = matmul_ogs( matmul_ogs(
hidden_states, hidden_states,
w1, w1,
quant_config.w1_bias, quant_config.w1_bias,
@ -155,10 +178,11 @@ def triton_kernel_fused_experts(
precision_config=quant_config.w1_precision, precision_config=quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None, gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act, fused_activation=act,
y=intermediate_cache,
) )
intermediate_cache3 = matmul_ogs( matmul_ogs(
intermediate_cache1, intermediate_cache.view(M * topk, N // 2),
w2, w2,
quant_config.w2_bias, quant_config.w2_bias,
routing_data, routing_data,
@ -167,7 +191,8 @@ def triton_kernel_fused_experts(
gammas=None if apply_router_weight_on_input else gammas, gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor, y=output_tensor,
) )
return intermediate_cache3 output_tensor = output_tensor.view(M, K)
return output_tensor
def make_routing_data( def make_routing_data(
@ -221,6 +246,42 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return True return True
def moe_problem_size(
self,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""
Extract the MoE problem size from the given tensor arguments:
- a: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation
tensors is not obvious. It needs to be done this way specifically
due to subtle issues with particular kernels, e.g. the int4 kernels
divide the trailing dimension by two, so it's not "correct" to
extract N or K from the trailing dimension of w1 or w2. Similarly,
some kernels transpose the weights, so this needs to be kept in mind.
Note: This implementation covers most cases. However, if experts
require a specialized implementation, like MarlinExperts, they are free
to override this function.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, _, N = w1.size()
K = a1.size(-1)
assert a1.dim() == 2
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)
assert topk_ids.dim() == 2
topk = topk_ids.size(1)
return E, M, N, K, topk
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Weight application and reduction happens in the fused_experts kernel. # Weight application and reduction happens in the fused_experts kernel.
return TopKWeightAndReduceNoOP() return TopKWeightAndReduceNoOP()
@ -263,8 +324,8 @@ class OAITritonExperts(BaseOAITritonExperts):
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspace are allocated inside the kernel # workspace are allocated inside the kernel
workspace1 = (M, K) workspace1 = (0, 0)
workspace2 = (0, 0) workspace2 = (M * topk, N // 2)
output = (M, K) output = (M, K)
return (workspace1, workspace2, output) return (workspace1, workspace2, output)
@ -297,20 +358,21 @@ class OAITritonExperts(BaseOAITritonExperts):
topk_ids, topk_weights, local_num_experts topk_ids, topk_weights, local_num_experts
) )
experts_output = triton_kernel_fused_experts( topk = topk_ids.size(1)
None, triton_kernel_fused_experts(
output,
hidden_states, hidden_states,
w1, w1,
w2, w2,
routing_data, routing_data,
gather_indx, gather_indx,
scatter_indx, scatter_indx,
topk=topk,
activation=activation, activation=activation,
quant_config=self.quant_config, quant_config=self.quant_config,
apply_router_weight_on_input=False, apply_router_weight_on_input=False,
global_num_experts=local_num_experts, global_num_experts=local_num_experts,
expert_map=None, # applied already expert_map=None, # applied already
intermediate_cache=workspace2,
a1q_scale=a1q_scale, a1q_scale=a1q_scale,
) )
output.copy_(experts_output, non_blocking=True)

View File

@ -103,7 +103,7 @@ __all__ = [
"CompressedTensorsW8A8Int8MoEMethod", "CompressedTensorsW8A8Int8MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod", "CompressedTensorsWNA16MoEMethod",
"CompressedTensorsW4A4MoeMethod", "CompressedTensorsW4A4Nvfp4MoeMethod",
"CompressedTensorsW4A8Int8MoEMethod", "CompressedTensorsW4A8Int8MoEMethod",
] ]
@ -171,7 +171,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
quant_config, layer.moe_config quant_config, layer.moe_config
) )
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4MoeMethod(layer.moe_config) return CompressedTensorsW4A4Nvfp4MoeMethod(layer.moe_config)
elif ( elif (
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
@ -188,7 +188,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
) )
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support, detect_nvfp4_moe_support,
@ -205,8 +205,12 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
self.flashinfer_moe_backend = get_flashinfer_moe_backend() self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once( logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for CompressedTensorsW4A4MoeMethod." " for CompressedTensorsW4A4Nvfp4MoeMethod."
) )
elif self.use_marlin:
logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoeMethod.")
else:
logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoeMethod.")
def create_weights( def create_weights(
self, self,
@ -612,7 +616,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
assert expert_map is None, ( assert expert_map is None, (
"Expert Parallelism / expert_map " "Expert Parallelism / expert_map "
"is currently not supported for " "is currently not supported for "
"CompressedTensorsW4A4MoeMethod." "CompressedTensorsW4A4Nvfp4MoeMethod."
) )
assert self.moe_quant_config is not None assert self.moe_quant_config is not None

View File

@ -1132,6 +1132,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for ModelOptNvFp4FusedMoE." " for ModelOptNvFp4FusedMoE."
) )
elif self.use_marlin:
logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.")
else:
logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.")
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,

View File

@ -196,9 +196,10 @@ class Mxfp4Config(QuantizationConfig):
# TODO: Add support for MXFP4 Linear Method. # TODO: Add support for MXFP4 Linear Method.
# MXFP4 LinearMethod is available in AMD-Quark, refer to that implementation # MXFP4 LinearMethod is available in AMD-Quark, refer to that implementation
# if you are interested in enabling MXFP4 here. # if you are interested in enabling MXFP4 here.
logger.warning_once( logger.debug_once(
"MXFP4 linear layer is not implemented - falling back to " "MXFP4 linear layer is not implemented - falling back to "
"UnquantizedLinearMethod." "UnquantizedLinearMethod.",
scope="local",
) )
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
@ -208,9 +209,10 @@ class Mxfp4Config(QuantizationConfig):
return Mxfp4MoEMethod(layer.moe_config) return Mxfp4MoEMethod(layer.moe_config)
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
# TODO: Add support for MXFP4 Attention. # TODO: Add support for MXFP4 Attention.
logger.warning_once( logger.debug_once(
"MXFP4 attention layer is not implemented. " "MXFP4 attention layer is not implemented. "
"Skipping quantization for this layer." "Skipping quantization for this layer.",
scope="local",
) )
return None return None

View File

@ -18,6 +18,7 @@ from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, process_weights_after_loading,
) )
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
download_gguf,
get_gguf_extra_tensor_names, get_gguf_extra_tensor_names,
get_gguf_weight_type_map, get_gguf_weight_type_map,
gguf_quant_weights_iterator, gguf_quant_weights_iterator,
@ -43,7 +44,8 @@ class GGUFModelLoader(BaseModelLoader):
f"load format {load_config.load_format}" f"load format {load_config.load_format}"
) )
def _prepare_weights(self, model_name_or_path: str): def _prepare_weights(self, model_config: ModelConfig):
model_name_or_path = model_config.model
if os.path.isfile(model_name_or_path): if os.path.isfile(model_name_or_path):
return model_name_or_path return model_name_or_path
# for raw HTTPS link # for raw HTTPS link
@ -55,12 +57,23 @@ class GGUFModelLoader(BaseModelLoader):
if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"): if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
repo_id, filename = model_name_or_path.rsplit("/", 1) repo_id, filename = model_name_or_path.rsplit("/", 1)
return hf_hub_download(repo_id=repo_id, filename=filename) return hf_hub_download(repo_id=repo_id, filename=filename)
else: # repo_id:quant_type
raise ValueError( elif "/" in model_name_or_path and ":" in model_name_or_path:
f"Unrecognised GGUF reference: {model_name_or_path} " repo_id, quant_type = model_name_or_path.rsplit(":", 1)
"(expected local file, raw URL, or <repo_id>/<filename>.gguf)" return download_gguf(
repo_id,
quant_type,
cache_dir=self.load_config.download_dir,
revision=model_config.revision,
ignore_patterns=self.load_config.ignore_patterns,
) )
raise ValueError(
f"Unrecognised GGUF reference: {model_name_or_path} "
"(expected local file, raw URL, <repo_id>/<filename>.gguf, "
"or <repo_id>:<quant_type>)"
)
def _get_gguf_weights_map(self, model_config: ModelConfig): def _get_gguf_weights_map(self, model_config: ModelConfig):
""" """
GGUF uses this naming convention for their tensors from HF checkpoint: GGUF uses this naming convention for their tensors from HF checkpoint:
@ -244,7 +257,7 @@ class GGUFModelLoader(BaseModelLoader):
gguf_to_hf_name_map: dict[str, str], gguf_to_hf_name_map: dict[str, str],
) -> dict[str, str]: ) -> dict[str, str]:
weight_type_map = get_gguf_weight_type_map( weight_type_map = get_gguf_weight_type_map(
model_config.model, gguf_to_hf_name_map model_name_or_path, gguf_to_hf_name_map
) )
is_multimodal = hasattr(model_config.hf_config, "vision_config") is_multimodal = hasattr(model_config.hf_config, "vision_config")
if is_multimodal: if is_multimodal:
@ -290,10 +303,10 @@ class GGUFModelLoader(BaseModelLoader):
yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map) yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model) self._prepare_weights(model_config)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
local_model_path = self._prepare_weights(model_config.model) local_model_path = self._prepare_weights(model_config)
gguf_weights_map = self._get_gguf_weights_map(model_config) gguf_weights_map = self._get_gguf_weights_map(model_config)
model.load_weights( model.load_weights(
self._get_weights_iterator(model_config, local_model_path, gguf_weights_map) self._get_weights_iterator(model_config, local_model_path, gguf_weights_map)
@ -303,7 +316,7 @@ class GGUFModelLoader(BaseModelLoader):
self, vllm_config: VllmConfig, model_config: ModelConfig self, vllm_config: VllmConfig, model_config: ModelConfig
) -> nn.Module: ) -> nn.Module:
device_config = vllm_config.device_config device_config = vllm_config.device_config
local_model_path = self._prepare_weights(model_config.model) local_model_path = self._prepare_weights(model_config)
gguf_weights_map = self._get_gguf_weights_map(model_config) gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights # we can only know if tie word embeddings after mapping weights
if "lm_head.weight" in get_gguf_extra_tensor_names( if "lm_head.weight" in get_gguf_extra_tensor_names(

View File

@ -369,6 +369,52 @@ def get_sparse_attention_config(
return config return config
def download_gguf(
repo_id: str,
quant_type: str,
cache_dir: str | None = None,
revision: str | None = None,
ignore_patterns: str | list[str] | None = None,
) -> str:
# Use patterns that snapshot_download can handle directly
# Patterns to match:
# - *-{quant_type}.gguf (root)
# - *-{quant_type}-*.gguf (root sharded)
# - */*-{quant_type}.gguf (subdir)
# - */*-{quant_type}-*.gguf (subdir sharded)
allow_patterns = [
f"*-{quant_type}.gguf",
f"*-{quant_type}-*.gguf",
f"*/*-{quant_type}.gguf",
f"*/*-{quant_type}-*.gguf",
]
# Use download_weights_from_hf which handles caching and downloading
folder = download_weights_from_hf(
model_name_or_path=repo_id,
cache_dir=cache_dir,
allow_patterns=allow_patterns,
revision=revision,
ignore_patterns=ignore_patterns,
)
# Find the downloaded file(s) in the folder
local_files = []
for pattern in allow_patterns:
# Convert pattern to glob pattern for local filesystem
glob_pattern = os.path.join(folder, pattern)
local_files.extend(glob.glob(glob_pattern))
if not local_files:
raise ValueError(
f"Downloaded GGUF files not found in {folder} for quant_type {quant_type}"
)
# Sort to ensure consistent ordering (prefer non-sharded files)
local_files.sort(key=lambda x: (x.count("-"), x))
return local_files[0]
def download_weights_from_hf( def download_weights_from_hf(
model_name_or_path: str, model_name_or_path: str,
cache_dir: str | None, cache_dir: str | None,

View File

@ -233,7 +233,7 @@ class BaiChuanDecoderLayer(nn.Module):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
position_embedding=position_embedding, position_embedding=position_embedding,
rope_parameters=config.rope_parameters, rope_parameters=getattr(config, "rope_parameters", None),
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,

View File

@ -289,9 +289,6 @@ class MambaModelConfig(VerifyAndUpdateConfig):
model_config = vllm_config.model_config model_config = vllm_config.model_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len
if cache_config.enable_prefix_caching: if cache_config.enable_prefix_caching:
if model_config.supports_mamba_prefix_caching: if model_config.supports_mamba_prefix_caching:
logger.info( logger.info(
@ -299,6 +296,11 @@ class MambaModelConfig(VerifyAndUpdateConfig):
"Its support for Mamba layers is experimental. " "Its support for Mamba layers is experimental. "
"Please report any issues you may observe." "Please report any issues you may observe."
) )
# By default, mamba block size will be set to max_model_len (see
# below). When enabling prefix caching, we align mamba block size
# to the block size as the basic granularity for prefix caching.
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = cache_config.block_size
else: else:
logger.info( logger.info(
"Hybrid or mamba-based model detected without " "Hybrid or mamba-based model detected without "
@ -306,6 +308,9 @@ class MambaModelConfig(VerifyAndUpdateConfig):
) )
cache_config.enable_prefix_caching = False cache_config.enable_prefix_caching = False
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len
# TODO(tdoublep): remove once cascade attention is supported # TODO(tdoublep): remove once cascade attention is supported
logger.info( logger.info(
"Disabling cascade attention since it is not supported for hybrid models." "Disabling cascade attention since it is not supported for hybrid models."

View File

@ -100,7 +100,7 @@ class GPTJAttention(nn.Module):
self.head_size, self.head_size,
rotary_dim=config.rotary_dim, rotary_dim=config.rotary_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=getattr(config, "rope_parameters", None),
is_neox_style=False, is_neox_style=False,
) )
self.attn = Attention( self.attn = Attention(

View File

@ -239,7 +239,7 @@ class Grok1DecoderLayer(nn.Module):
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_parameters=config.rope_parameters, rope_parameters=getattr(config, "rope_parameters", None),
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",

View File

@ -262,7 +262,7 @@ class LlamaAttention(nn.Module):
self.head_dim, self.head_dim,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
max_position=self.max_position_embeddings, max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=getattr(config, "rope_parameters", None),
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor, partial_rotary_factor=self.partial_rotary_factor,
) )

View File

@ -5,7 +5,6 @@ Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it. `tests/models/registry.py` with example HuggingFace models for it.
""" """
import hashlib
import importlib import importlib
import json import json
import os import os
@ -32,6 +31,7 @@ from vllm.config import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logging_utils import logtime from vllm.logging_utils import logtime
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
from vllm.utils.hashing import safe_hash
from .interfaces import ( from .interfaces import (
has_inner_state, has_inner_state,
@ -655,7 +655,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
if model_path.exists(): if model_path.exists():
with open(model_path, "rb") as f: with open(model_path, "rb") as f:
module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest() module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()
mi = self._load_modelinfo_from_cache(module_hash) mi = self._load_modelinfo_from_cache(module_hash)
if mi is not None: if mi is not None:

View File

@ -407,9 +407,6 @@ class CudaPlatformBase(Platform):
# We have found some valid backends. Select the one with the # We have found some valid backends. Select the one with the
# highest priority. # highest priority.
logger.info(
"Valid backends: %s", [b[0].name for b in valid_backends_priorities]
)
sorted_indices = sorted( sorted_indices = sorted(
range(len(valid_backends_priorities)), range(len(valid_backends_priorities)),
key=lambda i: valid_backends_priorities[i][1], key=lambda i: valid_backends_priorities[i][1],
@ -417,8 +414,9 @@ class CudaPlatformBase(Platform):
selected_index = sorted_indices[0] selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0] selected_backend = valid_backends_priorities[selected_index][0]
logger.info( logger.info(
"Using %s backend.", "Using %s attention backend out of potential backends: %s",
selected_backend.name, selected_backend.name,
[b[0].name for b in valid_backends_priorities],
) )
return selected_backend.get_path() return selected_backend.get_path()

View File

@ -42,7 +42,10 @@ from vllm.logger import init_logger
from vllm.transformers_utils.config_parser_base import ConfigParserBase from vllm.transformers_utils.config_parser_base import ConfigParserBase
from vllm.transformers_utils.utils import ( from vllm.transformers_utils.utils import (
check_gguf_file, check_gguf_file,
is_gguf,
is_remote_gguf,
parse_safetensors_file_metadata, parse_safetensors_file_metadata,
split_remote_gguf,
) )
if envs.VLLM_USE_MODELSCOPE: if envs.VLLM_USE_MODELSCOPE:
@ -453,51 +456,55 @@ def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> No
def patch_rope_parameters(config: PretrainedConfig) -> None: def patch_rope_parameters(config: PretrainedConfig) -> None:
"""Provide backwards compatibility for RoPE.""" """Provide backwards compatibility for RoPE."""
# Retrieve rope_parameters differently based on Transformers version # Patch rope_parameters differently based on Transformers version
if Version(version("transformers")) >= Version("5.0.0.dev0"): if Version(version("transformers")) >= Version("5.0.0.dev0"):
from transformers.modeling_rope_utils import RopeParameters from transformers.modeling_rope_utils import (
rope_config_validation,
rope_parameters: RopeParameters | dict[str, RopeParameters] | None = getattr( standardize_rope_params,
config, "rope_parameters", None
) )
elif hasattr(config, "rope_parameters"):
# We are in Transformers v4 and rope_parameters # When Transformers v5 is installed, legacy rope_theta may be present
# has already been patched for this config # when using custom code models written for Transformers v4
return if (rope_theta := getattr(config, "rope_theta", None)) is not None:
standardize_rope_params(config, rope_theta=rope_theta)
rope_config_validation(config)
# Delete rope_theta to avoid confusion in downstream code
del config.rope_theta
else: else:
# Convert Transformers v4 rope_theta and rope_scaling into rope_parameters # When Transformers v4 is installed, legacy rope_scaling may be present
rope_theta: float | None = getattr(config, "rope_theta", None) if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
rope_scaling: dict | None = getattr(config, "rope_scaling", None) config.rope_parameters = rope_scaling
rope_parameters = rope_scaling # When Transformers v4 is installed, legacy rope_theta may be present
# Move rope_theta into rope_parameters if (rope_theta := getattr(config, "rope_theta", None)) is not None:
if rope_theta is not None: if not hasattr(config, "rope_parameters"):
rope_parameters = rope_parameters or {"rope_type": "default"} config.rope_parameters = {"rope_type": "default"}
rope_parameters["rope_theta"] = rope_theta config.rope_parameters["rope_theta"] = rope_theta
# Add original_max_position_embeddings if present
if rope_parameters and (
ompe := getattr(config, "original_max_position_embeddings", None)
):
rope_parameters["original_max_position_embeddings"] = ompe
# Write back to config
config.rope_parameters = rope_parameters
# No RoPE parameters to patch # No RoPE parameters to patch
if rope_parameters is None: if not hasattr(config, "rope_parameters"):
return return
# Add original_max_position_embeddings if present
if ompe := getattr(config, "original_max_position_embeddings", None):
config.rope_parameters["original_max_position_embeddings"] = ompe
# Handle nested rope_parameters in interleaved sliding attention models # Handle nested rope_parameters in interleaved sliding attention models
if set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES): if set(config.rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
for rope_parameters_layer_type in rope_parameters.values(): for rope_parameters_layer_type in config.rope_parameters.values():
patch_rope_parameters_dict(rope_parameters_layer_type) patch_rope_parameters_dict(rope_parameters_layer_type)
else: else:
patch_rope_parameters_dict(rope_parameters) patch_rope_parameters_dict(config.rope_parameters)
def patch_rope_parameters_dict(rope_parameters: dict[str, Any]) -> None: def patch_rope_parameters_dict(rope_parameters: dict[str, Any]) -> None:
if "rope_type" in rope_parameters and "type" in rope_parameters: if "rope_type" in rope_parameters and "type" in rope_parameters:
rope_type = rope_parameters["rope_type"] rope_type = rope_parameters["rope_type"]
rope_type_legacy = rope_parameters["type"] rope_type_legacy = rope_parameters["type"]
if rope_type != rope_type_legacy: if (rope_type_legacy == "su" and rope_type == "longrope") or (
rope_type_legacy == "mrope" and rope_type == "default"
):
pass # No action needed
elif rope_type != rope_type_legacy:
raise ValueError( raise ValueError(
f"Found conflicts between 'rope_type={rope_type}' (modern " f"Found conflicts between 'rope_type={rope_type}' (modern "
f"field) and 'type={rope_type_legacy}' (legacy field). " f"field) and 'type={rope_type_legacy}' (legacy field). "
@ -629,10 +636,12 @@ def maybe_override_with_speculators(
Returns: Returns:
Tuple of (resolved_model, resolved_tokenizer, speculative_config) Tuple of (resolved_model, resolved_tokenizer, speculative_config)
""" """
is_gguf = check_gguf_file(model) if check_gguf_file(model):
if is_gguf:
kwargs["gguf_file"] = Path(model).name kwargs["gguf_file"] = Path(model).name
gguf_model_repo = Path(model).parent gguf_model_repo = Path(model).parent
elif is_remote_gguf(model):
repo_id, _ = split_remote_gguf(model)
gguf_model_repo = Path(repo_id)
else: else:
gguf_model_repo = None gguf_model_repo = None
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
@ -678,10 +687,18 @@ def get_config(
) -> PretrainedConfig: ) -> PretrainedConfig:
# Separate model folder from file path for GGUF models # Separate model folder from file path for GGUF models
is_gguf = check_gguf_file(model) _is_gguf = is_gguf(model)
if is_gguf: _is_remote_gguf = is_remote_gguf(model)
kwargs["gguf_file"] = Path(model).name if _is_gguf:
model = Path(model).parent if check_gguf_file(model):
# Local GGUF file
kwargs["gguf_file"] = Path(model).name
model = Path(model).parent
elif _is_remote_gguf:
# Remote GGUF - extract repo_id from repo_id:quant_type format
# The actual GGUF file will be downloaded later by GGUFModelLoader
# Keep model as repo_id:quant_type for download, but use repo_id for config
model, _ = split_remote_gguf(model)
if config_format == "auto": if config_format == "auto":
try: try:
@ -689,10 +706,25 @@ def get_config(
# Transformers implementation. # Transformers implementation.
if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision): if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
config_format = "mistral" config_format = "mistral"
elif is_gguf or file_or_path_exists( elif (_is_gguf and not _is_remote_gguf) or file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision model, HF_CONFIG_NAME, revision=revision
): ):
config_format = "hf" config_format = "hf"
# Remote GGUF models must have config.json in repo,
# otherwise the config can't be parsed correctly.
# FIXME(Isotr0py): Support remote GGUF repos without config.json
elif _is_remote_gguf and not file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision
):
err_msg = (
"Could not find config.json for remote GGUF model repo. "
"To load remote GGUF model through `<repo_id>:<quant_type>`, "
"ensure your model has config.json (HF format) file. "
"Otherwise please specify --hf-config-path <original_repo> "
"in engine args to fetch config from unquantized hf model."
)
logger.error(err_msg)
raise ValueError(err_msg)
else: else:
raise ValueError( raise ValueError(
"Could not detect config format for no config file found. " "Could not detect config format for no config file found. "
@ -713,9 +745,6 @@ def get_config(
"'config.json'.\n" "'config.json'.\n"
" - For Mistral models: ensure the presence of a " " - For Mistral models: ensure the presence of a "
"'params.json'.\n" "'params.json'.\n"
"3. For GGUF: pass the local path of the GGUF checkpoint.\n"
" Loading GGUF from a remote repo directly is not yet "
"supported.\n"
).format(model=model) ).format(model=model)
raise ValueError(error_message) from e raise ValueError(error_message) from e
@ -729,7 +758,7 @@ def get_config(
**kwargs, **kwargs,
) )
# Special architecture mapping check for GGUF models # Special architecture mapping check for GGUF models
if is_gguf: if _is_gguf:
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
raise RuntimeError(f"Can't get gguf config for {config.model_type}.") raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
@ -889,6 +918,8 @@ def get_pooling_config(model: str, revision: str | None = "main") -> dict | None
A dictionary containing the pooling type and whether A dictionary containing the pooling type and whether
normalization is used, or None if no pooling configuration is found. normalization is used, or None if no pooling configuration is found.
""" """
if is_remote_gguf(model):
model, _ = split_remote_gguf(model)
modules_file_name = "modules.json" modules_file_name = "modules.json"
@ -1108,6 +1139,8 @@ def get_hf_image_processor_config(
# Separate model folder from file path for GGUF models # Separate model folder from file path for GGUF models
if check_gguf_file(model): if check_gguf_file(model):
model = Path(model).parent model = Path(model).parent
elif is_remote_gguf(model):
model, _ = split_remote_gguf(model)
return get_image_processor_config( return get_image_processor_config(
model, token=hf_token, revision=revision, **kwargs model, token=hf_token, revision=revision, **kwargs
) )

View File

@ -18,7 +18,7 @@ from transformers.processing_utils import ProcessorMixin
from transformers.video_processing_utils import BaseVideoProcessor from transformers.video_processing_utils import BaseVideoProcessor
from typing_extensions import TypeVar from typing_extensions import TypeVar
from vllm.transformers_utils.utils import check_gguf_file, convert_model_repo_to_path from vllm.transformers_utils.utils import convert_model_repo_to_path, is_gguf
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING: if TYPE_CHECKING:
@ -236,8 +236,8 @@ def cached_processor_from_config(
processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin, processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
**kwargs: Any, **kwargs: Any,
) -> _P: ) -> _P:
if check_gguf_file(model_config.model): if is_gguf(model_config.model):
assert not check_gguf_file(model_config.tokenizer), ( assert not is_gguf(model_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer " "For multimodal GGUF models, the original tokenizer "
"should be used to correctly load processor." "should be used to correctly load processor."
) )
@ -350,8 +350,8 @@ def cached_image_processor_from_config(
model_config: "ModelConfig", model_config: "ModelConfig",
**kwargs: Any, **kwargs: Any,
): ):
if check_gguf_file(model_config.model): if is_gguf(model_config.model):
assert not check_gguf_file(model_config.tokenizer), ( assert not is_gguf(model_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer " "For multimodal GGUF models, the original tokenizer "
"should be used to correctly load image processor." "should be used to correctly load image processor."
) )

View File

@ -20,7 +20,12 @@ from vllm.transformers_utils.config import (
list_filtered_repo_files, list_filtered_repo_files,
) )
from vllm.transformers_utils.tokenizers import MistralTokenizer from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import (
check_gguf_file,
is_gguf,
is_remote_gguf,
split_remote_gguf,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
@ -180,10 +185,12 @@ def get_tokenizer(
kwargs["truncation_side"] = "left" kwargs["truncation_side"] = "left"
# Separate model folder from file path for GGUF models # Separate model folder from file path for GGUF models
is_gguf = check_gguf_file(tokenizer_name) if is_gguf(tokenizer_name):
if is_gguf: if check_gguf_file(tokenizer_name):
kwargs["gguf_file"] = Path(tokenizer_name).name kwargs["gguf_file"] = Path(tokenizer_name).name
tokenizer_name = Path(tokenizer_name).parent tokenizer_name = Path(tokenizer_name).parent
elif is_remote_gguf(tokenizer_name):
tokenizer_name, _ = split_remote_gguf(tokenizer_name)
# if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format # if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format
# first to use official Mistral tokenizer if possible. # first to use official Mistral tokenizer if possible.

View File

@ -9,6 +9,8 @@ from os import PathLike
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from gguf import GGMLQuantizationType
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
@ -46,6 +48,57 @@ def check_gguf_file(model: str | PathLike) -> bool:
return False return False
@cache
def is_remote_gguf(model: str | Path) -> bool:
"""Check if the model is a remote GGUF model."""
model = str(model)
return (
(not is_cloud_storage(model))
and (not model.startswith(("http://", "https://")))
and ("/" in model and ":" in model)
and is_valid_gguf_quant_type(model.rsplit(":", 1)[1])
)
def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool:
"""Check if the quant type is a valid GGUF quant type."""
return getattr(GGMLQuantizationType, gguf_quant_type, None) is not None
def split_remote_gguf(model: str | Path) -> tuple[str, str]:
"""Split the model into repo_id and quant type."""
model = str(model)
if is_remote_gguf(model):
parts = model.rsplit(":", 1)
return (parts[0], parts[1])
raise ValueError(
"Wrong GGUF model or invalid GGUF quant type: %s.\n"
"- It should be in repo_id:quant_type format.\n"
"- Valid GGMLQuantizationType values: %s",
model,
GGMLQuantizationType._member_names_,
)
def is_gguf(model: str | Path) -> bool:
"""Check if the model is a GGUF model.
Args:
model: Model name, path, or Path object to check.
Returns:
True if the model is a GGUF model, False otherwise.
"""
model = str(model)
# Check if it's a local GGUF file
if check_gguf_file(model):
return True
# Check if it's a remote GGUF model (repo_id:quant_type format)
return is_remote_gguf(model)
def modelscope_list_repo_files( def modelscope_list_repo_files(
repo_id: str, repo_id: str,
revision: str | None = None, revision: str | None = None,

View File

@ -73,14 +73,6 @@ class FlexibleArgumentParser(ArgumentParser):
# Enable the deprecated kwarg for Python 3.12 and below # Enable the deprecated kwarg for Python 3.12 and below
def parse_known_args(self, args=None, namespace=None): def parse_known_args(self, args=None, namespace=None):
if args is not None and "--disable-log-requests" in args:
# Special case warning because the warning below won't trigger
# if -disable-log-requests because its value is default.
logger.warning_once(
"argument '--disable-log-requests' is deprecated and "
"replaced with '--enable-log-requests'. This will be "
"removed in v0.12.0."
)
namespace, args = super().parse_known_args(args, namespace) namespace, args = super().parse_known_args(args, namespace)
for action in FlexibleArgumentParser._deprecated: for action in FlexibleArgumentParser._deprecated:
if ( if (

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import hashlib import hashlib
import pickle import pickle
from _hashlib import HASH, UnsupportedDigestmodError
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any
@ -61,3 +62,20 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
return sha256_cbor return sha256_cbor
raise ValueError(f"Unsupported hash function: {hash_fn_name}") raise ValueError(f"Unsupported hash function: {hash_fn_name}")
def safe_hash(data: bytes, usedforsecurity: bool = True) -> HASH:
"""Hash for configs, defaulting to md5 but falling back to sha256
in FIPS constrained environments.
Args:
data: bytes
usedforsecurity: Whether the hash is used for security purposes
Returns:
Hash object
"""
try:
return hashlib.md5(data, usedforsecurity=usedforsecurity)
except (UnsupportedDigestmodError, ValueError):
return hashlib.sha256(data)

View File

@ -56,6 +56,39 @@ def set_env_var(key: str, value: str) -> Iterator[None]:
os.environ[key] = old os.environ[key] = old
@contextlib.contextmanager
def suppress_stdout():
"""
Suppress stdout from C libraries at the file descriptor level.
Only suppresses stdout, not stderr, to preserve error messages.
Suppression is disabled when VLLM_LOGGING_LEVEL is set to DEBUG.
Example:
with suppress_stdout():
# C library calls that would normally print to stdout
torch.distributed.new_group(ranks, backend="gloo")
"""
# Don't suppress if logging level is DEBUG
if envs.VLLM_LOGGING_LEVEL == "DEBUG":
yield
return
stdout_fd = sys.stdout.fileno()
stdout_dup = os.dup(stdout_fd)
devnull_fd = os.open(os.devnull, os.O_WRONLY)
try:
sys.stdout.flush()
os.dup2(devnull_fd, stdout_fd)
yield
finally:
sys.stdout.flush()
os.dup2(stdout_dup, stdout_fd)
os.close(stdout_dup)
os.close(devnull_fd)
# File path utilities # File path utilities

View File

@ -13,6 +13,8 @@ from vllm.distributed.kv_events import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import ( from vllm.v1.core.kv_cache_utils import (
BlockHash, BlockHash,
BlockHashList,
BlockHashListWithBlockSize,
BlockHashWithGroupId, BlockHashWithGroupId,
ExternalBlockHash, ExternalBlockHash,
FreeKVCacheBlockQueue, FreeKVCacheBlockQueue,
@ -133,6 +135,10 @@ class BlockPool:
Args: Args:
num_gpu_blocks: The number of blocks in the pool. num_gpu_blocks: The number of blocks in the pool.
enable_caching: Whether to enable prefix caching. enable_caching: Whether to enable prefix caching.
hash_block_size: The block size of which the block hashes are computed.
The actual block size usually equals hash_block_size, but in cases
where different KV cache groups have different block sizes, the
actual block size can be a multiple of hash_block_size.
enable_kv_cache_events: Whether to enable kv cache events. enable_kv_cache_events: Whether to enable kv cache events.
""" """
@ -140,11 +146,13 @@ class BlockPool:
self, self,
num_gpu_blocks: int, num_gpu_blocks: int,
enable_caching: bool, enable_caching: bool,
hash_block_size: int,
enable_kv_cache_events: bool = False, enable_kv_cache_events: bool = False,
): ):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching self.enable_caching = enable_caching
self.hash_block_size = hash_block_size
# All kv-cache blocks. # All kv-cache blocks.
self.blocks: list[KVCacheBlock] = [ self.blocks: list[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks) KVCacheBlock(idx) for idx in range(num_gpu_blocks)
@ -223,8 +231,20 @@ class BlockPool:
return return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks] new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(request.block_hashes) >= num_full_blocks assert len(request.block_hashes) >= num_full_blocks
new_block_hashes = request.block_hashes[num_cached_blocks:] if block_size == self.hash_block_size:
# Common case.
block_hashes: BlockHashList = request.block_hashes
else:
# block_size is a multiple of hash_block_size. This happens when
# different KV cache groups have different block sizes.
assert block_size % self.hash_block_size == 0
# Recalculate block_hashes at the granularity of block_size, using
# the original block_hashes (at the granularity of hash_block_size).
block_hashes = BlockHashListWithBlockSize(
request.block_hashes, self.hash_block_size, block_size
)
new_block_hashes = block_hashes[num_cached_blocks:]
new_hashes: list[ExternalBlockHash] | None = ( new_hashes: list[ExternalBlockHash] | None = (
[] if self.enable_kv_cache_events else None [] if self.enable_kv_cache_events else None
) )

View File

@ -2,15 +2,25 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from math import lcm
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.kv_cache_utils import (
BlockHash,
BlockHashList,
BlockHashListWithBlockSize,
KVCacheBlock,
)
from vllm.v1.core.single_type_kv_cache_manager import ( from vllm.v1.core.single_type_kv_cache_manager import (
CrossAttentionManager, CrossAttentionManager,
FullAttentionManager, FullAttentionManager,
get_manager_for_kv_cache_spec, get_manager_for_kv_cache_spec,
) )
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheSpec,
)
from vllm.v1.request import Request from vllm.v1.request import Request
@ -28,13 +38,17 @@ class KVCacheCoordinator(ABC):
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
dcp_world_size: int, dcp_world_size: int,
pcp_world_size: int, pcp_world_size: int,
hash_block_size: int,
): ):
self.kv_cache_config = kv_cache_config self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.enable_caching = enable_caching self.enable_caching = enable_caching
self.block_pool = BlockPool( self.block_pool = BlockPool(
kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events kv_cache_config.num_blocks,
enable_caching,
hash_block_size,
enable_kv_cache_events,
) )
# Needs special handling for find_longest_cache_hit if eagle is enabled # Needs special handling for find_longest_cache_hit if eagle is enabled
@ -213,6 +227,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
dcp_world_size: int, dcp_world_size: int,
pcp_world_size: int, pcp_world_size: int,
hash_block_size: int,
): ):
super().__init__( super().__init__(
kv_cache_config, kv_cache_config,
@ -222,6 +237,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
enable_kv_cache_events, enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
) )
self.num_single_type_manager = len(self.single_type_managers) self.num_single_type_manager = len(self.single_type_managers)
@ -255,6 +271,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
dcp_world_size: int, dcp_world_size: int,
pcp_world_size: int, pcp_world_size: int,
hash_block_size: int,
): ):
super().__init__( super().__init__(
kv_cache_config, kv_cache_config,
@ -264,6 +281,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events, enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
) )
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size self.block_size = self.kv_cache_spec.block_size
@ -273,6 +291,11 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
self.block_size *= dcp_world_size self.block_size *= dcp_world_size
if pcp_world_size > 1: if pcp_world_size > 1:
self.block_size *= pcp_world_size self.block_size *= pcp_world_size
# For models using only Mamba, block_size is set to max_model_len when
# prefix caching is disabled, and hash_block_size validation is skipped.
assert not enable_caching or (hash_block_size == self.block_size), (
"UnitaryKVCacheCoordinator assumes hash_block_size == block_size"
)
assert len(self.kv_cache_config.kv_cache_groups) == 1, ( assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group" "UnitaryKVCacheCoordinator assumes only one kv cache group"
) )
@ -289,6 +312,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
block_pool=self.block_pool, block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec, kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle, use_eagle=self.use_eagle,
alignment_tokens=self.block_size,
dcp_world_size=self.dcp_world_size, dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size, pcp_world_size=self.pcp_world_size,
) )
@ -313,6 +337,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
dcp_world_size: int, dcp_world_size: int,
pcp_world_size: int, pcp_world_size: int,
hash_block_size: int,
): ):
super().__init__( super().__init__(
kv_cache_config, kv_cache_config,
@ -322,7 +347,17 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events, enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
) )
# hash_block_size: the block size used to compute block hashes.
# The actual block size usually equals hash_block_size, but in cases where
# different KV cache groups have different block sizes, the actual block size
# can be a multiple of hash_block_size.
self.hash_block_size = hash_block_size
assert all(
g.kv_cache_spec.block_size % hash_block_size == 0
for g in kv_cache_config.kv_cache_groups
), "block_size must be divisible by hash_block_size"
assert dcp_world_size == 1, "DCP not support hybrid attn now." assert dcp_world_size == 1, "DCP not support hybrid attn now."
assert pcp_world_size == 1, "PCP not support hybrid attn now." assert pcp_world_size == 1, "PCP not support hybrid attn now."
self.verify_and_split_kv_cache_groups() self.verify_and_split_kv_cache_groups()
@ -373,14 +408,12 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
self.other_spec = other_spec self.other_spec = other_spec
self.full_attention_block_size = self.full_attention_spec.block_size self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size self.other_block_size = self.other_spec.block_size
# The LCM of the block sizes of full attention and other attention.
if self.enable_caching: # The cache hit length must be a multiple of the LCM of the block sizes
# this requirement is only needed for the prefix caching logic # to make sure the cache hit length is a multiple of the block size of
divisible = self.other_block_size % self.full_attention_block_size # each attention type. Requiring this because we don't support partial
assert divisible == 0, ( # block cache hit yet.
"KVCacheCoordinator assumes the block_size of full " self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size)
"attention layers is divisible by other layers now."
)
if max(self.full_attention_group_ids) < min(self.other_group_ids): if max(self.full_attention_group_ids) < min(self.other_group_ids):
self.full_attn_first = True self.full_attn_first = True
@ -414,25 +447,48 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
- The number of tokens of the longest cache hit. - The number of tokens of the longest cache hit.
""" """
# First, find the longest cache hit for full attention. # First, find the longest cache hit for full attention.
if self.full_attention_spec.block_size == self.hash_block_size:
# Common case.
full_attention_block_hashes: BlockHashList = block_hashes
else:
# block_size is a multiple of hash_block_size. This happens when different
# KV cache groups have different block sizes. In this case, we need to
# recalculate block_hashes at the granularity of block_size, using the
# original block_hashes (at the granularity of hash_block_size).
full_attention_block_hashes = BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, self.full_attention_spec.block_size
)
hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit( hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
block_hashes=block_hashes, block_hashes=full_attention_block_hashes,
max_length=max_cache_hit_length, max_length=max_cache_hit_length,
kv_cache_group_ids=self.full_attention_group_ids, kv_cache_group_ids=self.full_attention_group_ids,
block_pool=self.block_pool, block_pool=self.block_pool,
kv_cache_spec=self.full_attention_spec, kv_cache_spec=self.full_attention_spec,
use_eagle=self.use_eagle, use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
) )
hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size
# Next, find the cache hit for the other attention WITHIN # Next, find the cache hit for the other attention WITHIN
# the cache hit of full attention. # the cache hit of full attention.
if self.other_spec.block_size == self.hash_block_size:
# Common case.
other_block_hashes: BlockHashList = block_hashes
else:
# Similar to the full attention case, here we need to recalculate
# block_hashes at the granularity of block_size, using the original
# block_hashes (at the granularity of hash_block_size).
other_block_hashes = BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, self.other_spec.block_size
)
hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit( hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
block_hashes=block_hashes, block_hashes=other_block_hashes,
max_length=hit_length, max_length=hit_length,
kv_cache_group_ids=self.other_group_ids, kv_cache_group_ids=self.other_group_ids,
block_pool=self.block_pool, block_pool=self.block_pool,
kv_cache_spec=self.other_spec, kv_cache_spec=self.other_spec,
use_eagle=self.use_eagle, use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
) )
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
@ -466,6 +522,7 @@ def get_kv_cache_coordinator(
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
dcp_world_size: int, dcp_world_size: int,
pcp_world_size: int, pcp_world_size: int,
hash_block_size: int,
) -> KVCacheCoordinator: ) -> KVCacheCoordinator:
if not enable_caching: if not enable_caching:
return KVCacheCoordinatorNoPrefixCache( return KVCacheCoordinatorNoPrefixCache(
@ -473,8 +530,9 @@ def get_kv_cache_coordinator(
max_model_len, max_model_len,
use_eagle, use_eagle,
enable_kv_cache_events, enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size,
hash_block_size,
) )
if len(kv_cache_config.kv_cache_groups) == 1: if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator( return UnitaryKVCacheCoordinator(
@ -483,8 +541,9 @@ def get_kv_cache_coordinator(
use_eagle, use_eagle,
enable_caching, enable_caching,
enable_kv_cache_events, enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size,
hash_block_size,
) )
return HybridKVCacheCoordinator( return HybridKVCacheCoordinator(
kv_cache_config, kv_cache_config,
@ -492,6 +551,7 @@ def get_kv_cache_coordinator(
use_eagle, use_eagle,
enable_caching, enable_caching,
enable_kv_cache_events, enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size,
hash_block_size,
) )

View File

@ -95,6 +95,7 @@ class KVCacheManager:
self, self,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
max_model_len: int, max_model_len: int,
hash_block_size: int,
enable_caching: bool = True, enable_caching: bool = True,
use_eagle: bool = False, use_eagle: bool = False,
log_stats: bool = False, log_stats: bool = False,
@ -107,28 +108,11 @@ class KVCacheManager:
self.enable_caching = enable_caching self.enable_caching = enable_caching
self.use_eagle = use_eagle self.use_eagle = use_eagle
self.log_stats = log_stats self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats # FIXME: make prefix cache stats conditional on log_stats. We still need
# this comment because when the log stats is enabled there are still
# potential configs we could expose in the future.
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
self.block_size: int | None = None
if self.enable_caching:
assert (
len(
set(
g.kv_cache_spec.block_size
for g in kv_cache_config.kv_cache_groups
)
)
== 1
), "Only one block size is supported for now"
self.block_size = kv_cache_config.kv_cache_groups[
0
].kv_cache_spec.block_size
if dcp_world_size * pcp_world_size > 1:
assert len(kv_cache_config.kv_cache_groups) == 1
self.block_size *= dcp_world_size * pcp_world_size
self.coordinator = get_kv_cache_coordinator( self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
@ -137,6 +121,7 @@ class KVCacheManager:
enable_kv_cache_events=enable_kv_cache_events, enable_kv_cache_events=enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
) )
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool self.block_pool = self.coordinator.block_pool

View File

@ -5,9 +5,9 @@
import copy import copy
import os import os
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass from dataclasses import dataclass, replace
from typing import Any, NewType, TypeAlias from typing import Any, NewType, TypeAlias, overload
from vllm import envs from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
@ -825,11 +825,11 @@ def get_num_blocks(
return num_blocks return num_blocks
def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int: def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int:
""" """
Get the page size of the KV cache. Get the page size of the KV cache.
""" """
page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values()) page_sizes = {layer.page_size_bytes for layer in kv_cache_specs}
assert len(page_sizes) == 1 assert len(page_sizes) == 1
return page_sizes.pop() return page_sizes.pop()
@ -882,6 +882,46 @@ def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool
return len(page_sizes) == 1 return len(page_sizes) == 1
def unify_kv_cache_spec_page_size(
kv_cache_spec: dict[str, KVCacheSpec],
) -> dict[str, KVCacheSpec]:
"""
Unify the page size of the given KVCacheSpec. If the page size of all layers
are the same, return the original KVCacheSpec. If not same, unify the page
size by increasing the block size of layers with smaller page size. Raise
NotImplementedError if failed to unify the page size.
Args:
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
The updated KVCacheSpec with the same page_size_bytes.
"""
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
if len(page_sizes) <= 1:
# All layers have the same page size, no need to unify.
return kv_cache_spec
max_page_size = max(page_sizes)
new_kv_cache_spec = {}
for layer_name, layer_spec in kv_cache_spec.items():
if layer_spec.page_size_bytes == max_page_size:
new_kv_cache_spec[layer_name] = layer_spec
else:
layer_page_size = layer_spec.page_size_bytes
if max_page_size % layer_page_size != 0:
raise NotImplementedError(
"The page size of the layer is not divisible by the "
"maximum page size. Cannot unify by adjusting block_size."
)
ratio = max_page_size // layer_page_size
new_block_size = layer_spec.block_size * ratio
new_spec = replace(layer_spec, block_size=new_block_size)
assert new_spec.page_size_bytes == max_page_size
new_kv_cache_spec[layer_name] = new_spec
return new_kv_cache_spec
def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
# kv_cache_spec is an empty dict for attention free models # kv_cache_spec is an empty dict for attention free models
return not kv_cache_spec return not kv_cache_spec
@ -1010,7 +1050,6 @@ def _get_kv_cache_groups_uniform_page_size(
def get_kv_cache_config_from_groups( def get_kv_cache_config_from_groups(
vllm_config: VllmConfig, vllm_config: VllmConfig,
kv_cache_groups: list[KVCacheGroupSpec], kv_cache_groups: list[KVCacheGroupSpec],
kv_cache_specs: dict[str, KVCacheSpec],
available_memory: int, available_memory: int,
) -> KVCacheConfig: ) -> KVCacheConfig:
""" """
@ -1020,7 +1059,6 @@ def get_kv_cache_config_from_groups(
Args: Args:
vllm_config: The global VllmConfig vllm_config: The global VllmConfig
kv_cache_groups: The KV cache groups kv_cache_groups: The KV cache groups
kv_cache_specs: The KV cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes available_memory: Memory available for KV cache in bytes
Returns: Returns:
The generated KVCacheConfig The generated KVCacheConfig
@ -1064,7 +1102,9 @@ def get_kv_cache_config_from_groups(
# full.1, sw.2: share another Tensor with size=available_memory//2 # full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups) group_size = max(len(group.layer_names) for group in kv_cache_groups)
page_size = get_uniform_page_size(kv_cache_specs) page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups]
)
assert group_size > 0, "group_size must be greater than 0" assert group_size > 0, "group_size must be greater than 0"
num_blocks = get_num_blocks( num_blocks = get_num_blocks(
vllm_config, group_size, available_memory, page_size vllm_config, group_size, available_memory, page_size
@ -1166,7 +1206,8 @@ def get_kv_cache_groups(
# This returns an empty list to allow for the KVCacheManager to handle # This returns an empty list to allow for the KVCacheManager to handle
# attention free models. # attention free models.
return [] return []
elif is_kv_cache_spec_uniform(kv_cache_spec):
if is_kv_cache_spec_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for # KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for # most models. Allocate the same amount of memory for
# each layer. # each layer.
@ -1176,14 +1217,16 @@ def get_kv_cache_groups(
# full attention, or all layers are sliding window attention with the # full attention, or all layers are sliding window attention with the
# same window size). Put all layers into one group. # same window size). Put all layers into one group.
return _get_kv_cache_groups_uniform_type(uniform_spec) return _get_kv_cache_groups_uniform_type(uniform_spec)
elif is_kv_cache_page_size_uniform(kv_cache_spec):
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
raise NotImplementedError # As KVCacheManager can only allocate memory of one size, we need to unify
# the page size of the layers. For cases cannot be unified, this function
# will raise an error.
kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec)
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
def generate_scheduler_kv_cache_config( def generate_scheduler_kv_cache_config(
@ -1327,10 +1370,7 @@ def get_kv_cache_configs(
) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group." ) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group."
kv_cache_configs.append( kv_cache_configs.append(
get_kv_cache_config_from_groups( get_kv_cache_config_from_groups(
vllm_config, vllm_config, kv_cache_groups_one_worker, available_memory_one_worker
kv_cache_groups_one_worker,
kv_cache_spec_one_worker,
available_memory_one_worker,
) )
) )
@ -1353,3 +1393,79 @@ def get_kv_cache_configs(
_report_kv_cache_config(vllm_config, kv_cache_config) _report_kv_cache_config(vllm_config, kv_cache_config)
return kv_cache_configs return kv_cache_configs
class BlockHashListWithBlockSize:
"""
Convert block-hash granularity from `hash_block_size` to `target_block_size`.
Used when KV cache groups have different block sizes: `hash_block_size`
is the size used to compute the original `block_hashes`; `target_block_size`
is the group's actual block size.
Currently, only scaling up by an integer factor is supported (i.e.,
`target_block_size` is a multiple of `hash_block_size`). Conversion is
performed lazily on access for efficiency, by concatenating consecutive
hashes at `hash_block_size` to form each hash at `target_block_size`.
Example (`hash_block_size` = 16, `target_block_size` = 32):
concatenating two 16-size hashes yields one 32-size hash:
Block hashes with block_size 16:
| Token Range | 0-15 | 16-31 | 32-47 | 48-63 |
|-------------|------|-------|-------|-------|
| Hash | A | B | C | D |
Block hashes with block_size 32:
| Token Range | 0-31 | 32-63 |
|-------------|------|-------|
| Hash | AB | CD |
Args:
block_hashes: Block hashes to convert, computed at `hash_block_size`.
hash_block_size: Block size at which `block_hashes` were computed.
target_block_size: Desired block size; must be a multiple of `hash_block_size`.
"""
def __init__(
self,
block_hashes: list[BlockHash],
hash_block_size: int,
target_block_size: int,
):
self.block_hashes = block_hashes
assert target_block_size % hash_block_size == 0
self.scale_factor = target_block_size // hash_block_size
def __len__(self) -> int:
return len(self.block_hashes) // self.scale_factor
@overload
def __getitem__(self, idx: int) -> BlockHash: ...
@overload
def __getitem__(self, idx: slice) -> list[BlockHash]: ...
def __getitem__(self, idx):
if isinstance(idx, int):
return self._get_value_at(idx)
if isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
return [self._get_value_at(i) for i in range(start, stop, step)]
raise TypeError(f"Invalid index type: {type(idx)!r}")
def __iter__(self) -> Iterator[BlockHash]:
for i in range(len(self)):
yield self._get_value_at(i)
def _get_value_at(self, idx: int) -> BlockHash:
base = idx * self.scale_factor
end = base + self.scale_factor
merged_hash: bytes = self.block_hashes[base]
for i in range(base + 1, end):
merged_hash += self.block_hashes[i]
return BlockHash(merged_hash)
BlockHashList = list[BlockHash] | BlockHashListWithBlockSize

View File

@ -186,6 +186,7 @@ class Scheduler(SchedulerInterface):
enable_kv_cache_events=self.enable_kv_cache_events, enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size, dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size, pcp_world_size=self.pcp_world_size,
hash_block_size=self.block_size,
) )
sink_len = getattr(vllm_config.model_config.hf_config, "param_sink_number", 0) sink_len = getattr(vllm_config.model_config.hf_config, "param_sink_number", 0)
if sink_len > 0: if sink_len > 0:
@ -1093,8 +1094,6 @@ class Scheduler(SchedulerInterface):
and request.sampling_params.logprobs is not None and request.sampling_params.logprobs is not None
and logprobs and logprobs
): ):
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1) new_logprobs = logprobs.slice(req_index, req_index + 1)
if new_token_ids and self.structured_output_manager.should_advance(request): if new_token_ids and self.structured_output_manager.should_advance(request):

View File

@ -7,7 +7,7 @@ from collections.abc import Sequence
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec, ChunkedLocalAttentionSpec,
CrossAttentionSpec, CrossAttentionSpec,
@ -208,12 +208,13 @@ class SingleTypeKVCacheManager(ABC):
@abstractmethod @abstractmethod
def find_longest_cache_hit( def find_longest_cache_hit(
cls, cls,
block_hashes: list[BlockHash], block_hashes: BlockHashList,
max_length: int, max_length: int,
kv_cache_group_ids: list[int], kv_cache_group_ids: list[int],
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1, pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
@ -233,6 +234,11 @@ class SingleTypeKVCacheManager(ABC):
block_pool: The block pool. block_pool: The block pool.
kv_cache_spec: The kv cache spec. kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle. use_eagle: Whether to use eagle.
alignment_tokens: The returned cache hit length (in tokens) should
be a multiple of this value (in tokens). By default, it should
be set to the block_size.
dcp_world_size: The world size of decode context parallelism.
pcp_world_size: The world size of prefill context parallelism.
Returns: Returns:
A list of cached blocks with skipped blocks replaced by null block A list of cached blocks with skipped blocks replaced by null block
@ -300,18 +306,19 @@ class FullAttentionManager(SingleTypeKVCacheManager):
@classmethod @classmethod
def find_longest_cache_hit( def find_longest_cache_hit(
cls, cls,
block_hashes: list[BlockHash], block_hashes: BlockHashList,
max_length: int, max_length: int,
kv_cache_group_ids: list[int], kv_cache_group_ids: list[int],
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1, pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
assert isinstance( assert isinstance(
kv_cache_spec, kv_cache_spec,
(FullAttentionSpec, FullDiffkvAttentionSpec, ChunkedLocalAttentionSpec), FullAttentionSpec | ChunkedLocalAttentionSpec | FullDiffkvAttentionSpec
), ( ), (
"FullAttentionManager can only be used for full attention " "FullAttentionManager can only be used for full attention "
"and chunked local attention groups" "and chunked local attention groups"
@ -335,6 +342,13 @@ class FullAttentionManager(SingleTypeKVCacheManager):
else: else:
break break
if use_eagle and computed_blocks[0]: if use_eagle and computed_blocks[0]:
# Need to drop the last matched block if eagle is enabled.
for computed in computed_blocks:
computed.pop()
while (
block_size != alignment_tokens # Faster for common case.
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
):
for computed in computed_blocks: for computed in computed_blocks:
computed.pop() computed.pop()
return computed_blocks return computed_blocks
@ -361,12 +375,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
@classmethod @classmethod
def find_longest_cache_hit( def find_longest_cache_hit(
cls, cls,
block_hashes: list[BlockHash], block_hashes: BlockHashList,
max_length: int, max_length: int,
kv_cache_group_ids: list[int], kv_cache_group_ids: list[int],
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1, pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
@ -398,6 +413,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
[block_pool.null_block] * max_num_blocks [block_pool.null_block] * max_num_blocks
for _ in range(len(kv_cache_group_ids)) for _ in range(len(kv_cache_group_ids))
) )
block_size = kv_cache_spec.block_size
num_contiguous_blocks = 0 num_contiguous_blocks = 0
match_found = False match_found = False
# Search from right to left and early stop when a match is found. # Search from right to left and early stop when a match is found.
@ -405,6 +421,15 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
if cached_block := block_pool.get_cached_block( if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids block_hashes[i], kv_cache_group_ids
): ):
# Skip prefix matching check if the block is not aligned with
# `alignment_tokens`.
if (
num_contiguous_blocks == 0
and block_size != alignment_tokens # Faster for common case.
and (i + 1) * block_size % alignment_tokens != 0
):
continue
# Add the cached block to the computed blocks.
for computed, cached in zip(computed_blocks, cached_block): for computed, cached in zip(computed_blocks, cached_block):
computed[i] = cached computed[i] = cached
num_contiguous_blocks += 1 num_contiguous_blocks += 1
@ -423,7 +448,16 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
# `num_contiguous_blocks < sliding_window_contiguous_blocks`. # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
for computed in computed_blocks: for computed in computed_blocks:
del computed[num_contiguous_blocks:] del computed[num_contiguous_blocks:]
while (
block_size != alignment_tokens # Faster for common case.
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
):
for computed in computed_blocks:
computed.pop()
if use_eagle and computed_blocks[0]: if use_eagle and computed_blocks[0]:
assert kv_cache_spec.block_size == alignment_tokens, (
"aligned_length is not compatible with eagle now"
)
for computed in computed_blocks: for computed in computed_blocks:
computed.pop() computed.pop()
return computed_blocks return computed_blocks
@ -477,12 +511,13 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
@classmethod @classmethod
def find_longest_cache_hit( def find_longest_cache_hit(
cls, cls,
block_hashes: list[BlockHash], block_hashes: BlockHashList,
max_length: int, max_length: int,
kv_cache_group_ids: list[int], kv_cache_group_ids: list[int],
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1, pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
@ -513,6 +548,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
block_pool: The block pool. block_pool: The block pool.
kv_cache_spec: The kv cache spec. kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle. use_eagle: Whether to use eagle.
dcp_world_size: The world size of decode context parallelism.
pcp_world_size: The world size of prefill context parallelism.
alignment_tokens: The returned cache hit length (in tokens) should
be a multiple of this value (in tokens).
Returns: Returns:
A list of cached blocks A list of cached blocks
@ -526,6 +565,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
) )
assert dcp_world_size == 1, "DCP not support chunked local attn now." assert dcp_world_size == 1, "DCP not support chunked local attn now."
assert pcp_world_size == 1, "PCP not support chunked local attn now." assert pcp_world_size == 1, "PCP not support chunked local attn now."
assert kv_cache_spec.block_size == alignment_tokens, (
"KV cache groups with different block sizes are not compatible with "
"chunked local attention now"
)
max_num_blocks = max_length // kv_cache_spec.block_size max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0: if max_length > 0:
local_attention_start_idx = ( local_attention_start_idx = (
@ -614,12 +657,13 @@ class MambaManager(SingleTypeKVCacheManager):
@classmethod @classmethod
def find_longest_cache_hit( def find_longest_cache_hit(
cls, cls,
block_hashes: list[BlockHash], block_hashes: BlockHashList,
max_length: int, max_length: int,
kv_cache_group_ids: list[int], kv_cache_group_ids: list[int],
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1, pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
@ -632,12 +676,21 @@ class MambaManager(SingleTypeKVCacheManager):
[] for _ in range(len(kv_cache_group_ids)) [] for _ in range(len(kv_cache_group_ids))
) )
max_num_blocks = max_length // kv_cache_spec.block_size block_size = kv_cache_spec.block_size
max_num_blocks = max_length // block_size
# Search from right to left and early stop when a match is found. # Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1): for i in range(max_num_blocks - 1, -1, -1):
if cached_block := block_pool.get_cached_block( if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids block_hashes[i], kv_cache_group_ids
): ):
# When enable Mamba prefix caching, `block_size` will be aligned
# across full attention layers and Mamba layers to ensure the
# prefix hit length aligned at block
if (
block_size != alignment_tokens # Faster for common case.
and (i + 1) * block_size % alignment_tokens != 0
):
continue
for computed, cached in zip(computed_blocks, cached_block): for computed, cached in zip(computed_blocks, cached_block):
# the hit length logic later assumes: # the hit length logic later assumes:
# hit_length = len(hit_blocks_other_attn[0]) # hit_length = len(hit_blocks_other_attn[0])
@ -710,12 +763,13 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
@classmethod @classmethod
def find_longest_cache_hit( def find_longest_cache_hit(
cls, cls,
block_hashes: list[BlockHash], block_hashes: BlockHashList,
max_length: int, max_length: int,
kv_cache_group_ids: list[int], kv_cache_group_ids: list[int],
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1, pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:

View File

@ -31,7 +31,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils.async_utils import cancel_task_threadsafe from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list from vllm.utils.collection_utils import as_list
from vllm.utils.func_utils import deprecate_kwargs
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
@ -195,12 +194,6 @@ class AsyncLLM(EngineClient):
self.profiler = None self.profiler = None
@classmethod @classmethod
@deprecate_kwargs(
"disable_log_requests",
additional_message=(
"This argument will have no effect. Use `enable_log_requests` instead."
),
)
def from_vllm_config( def from_vllm_config(
cls, cls,
vllm_config: VllmConfig, vllm_config: VllmConfig,
@ -213,7 +206,6 @@ class AsyncLLM(EngineClient):
client_addresses: dict[str, str] | None = None, client_addresses: dict[str, str] | None = None,
client_count: int = 1, client_count: int = 1,
client_index: int = 0, client_index: int = 0,
disable_log_requests: bool = True, # Deprecated, will be removed
) -> "AsyncLLM": ) -> "AsyncLLM":
# Create the LLMEngine. # Create the LLMEngine.
return cls( return cls(

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from dataclasses import replace from dataclasses import replace
import torch import torch
@ -204,7 +205,9 @@ class RejectionSampler(nn.Module):
def parse_output( def parse_output(
output_token_ids: torch.Tensor, output_token_ids: torch.Tensor,
vocab_size: int, vocab_size: int,
) -> list[list[int]]: discard_req_indices: Sequence[int] = (),
return_cu_num_tokens: bool = False,
) -> tuple[list[list[int]], list[int] | None]:
"""Parse the output of the rejection sampler. """Parse the output of the rejection sampler.
Args: Args:
output_token_ids: The sampled token IDs in shape output_token_ids: The sampled token IDs in shape
@ -212,6 +215,8 @@ class RejectionSampler(nn.Module):
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function. and will be filtered out in this function.
vocab_size: The size of the vocabulary. vocab_size: The size of the vocabulary.
discard_req_indices: Optional row indices to discard tokens in.
return_cu_num_tokens: Whether to also return cumulative token counts.
Returns: Returns:
A list of lists of token IDs. A list of lists of token IDs.
""" """
@ -220,10 +225,15 @@ class RejectionSampler(nn.Module):
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & ( valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size output_token_ids_np < vocab_size
) )
cu_num_tokens = None
if return_cu_num_tokens:
cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
if len(discard_req_indices) > 0:
valid_mask[discard_req_indices] = False
outputs = [ outputs = [
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np) row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
] ]
return outputs return outputs, cu_num_tokens
def apply_logits_processors( def apply_logits_processors(
self, self,

View File

@ -1055,11 +1055,11 @@ class EagleProposer:
elif ( elif (
isinstance(target_embed_tokens.weight, torch.Tensor) isinstance(target_embed_tokens.weight, torch.Tensor)
and isinstance(self.model.model.embed_tokens.weight, torch.Tensor) and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
and torch.allclose( # TODO: Offload to CPU for comparison to avoid extra GPU memory
# usage in CI testing environments with limited GPU memory
and torch.equal(
target_embed_tokens.weight.cpu(), target_embed_tokens.weight.cpu(),
self.model.model.embed_tokens.weight.cpu(), self.model.model.embed_tokens.weight.cpu(),
rtol=1e-5,
atol=1e-7,
) )
): ):
share_embeddings = True share_embeddings = True
@ -1105,8 +1105,11 @@ class EagleProposer:
hasattr(target_language_model, "lm_head") hasattr(target_language_model, "lm_head")
and isinstance(target_language_model.lm_head.weight, torch.Tensor) and isinstance(target_language_model.lm_head.weight, torch.Tensor)
and isinstance(self.model.lm_head.weight, torch.Tensor) and isinstance(self.model.lm_head.weight, torch.Tensor)
# TODO: Offload to CPU for comparison to avoid extra GPU memory
# usage in CI testing environments with limited GPU memory
and torch.equal( and torch.equal(
target_language_model.lm_head.weight, self.model.lm_head.weight target_language_model.lm_head.weight.cpu(),
self.model.lm_head.weight.cpu(),
) )
): ):
share_lm_head = True share_lm_head = True

View File

@ -186,7 +186,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
self, self,
model_runner_output: ModelRunnerOutput, model_runner_output: ModelRunnerOutput,
sampled_token_ids: torch.Tensor, sampled_token_ids: torch.Tensor,
logprobs_tensors: torch.Tensor | None, logprobs_tensors: LogprobsTensors | None,
invalid_req_indices: list[int], invalid_req_indices: list[int],
async_output_copy_stream: torch.cuda.Stream, async_output_copy_stream: torch.cuda.Stream,
vocab_size: int, vocab_size: int,
@ -222,28 +222,29 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
This function blocks until the copy is finished. This function blocks until the copy is finished.
""" """
max_gen_len = self.sampled_token_ids_cpu.shape[-1]
self.async_copy_ready_event.synchronize() self.async_copy_ready_event.synchronize()
# Release the device tensors once the copy has completed. # Release the device tensors once the copy has completed.
del self._logprobs_tensors del self._logprobs_tensors
del self._sampled_token_ids del self._sampled_token_ids
max_gen_len = self.sampled_token_ids_cpu.shape[-1]
if max_gen_len == 1: if max_gen_len == 1:
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()
cu_num_tokens = None
else: else:
valid_sampled_token_ids = RejectionSampler.parse_output( valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
self.sampled_token_ids_cpu, self.sampled_token_ids_cpu,
self.vocab_size, self.vocab_size,
self._invalid_req_indices,
return_cu_num_tokens=self._logprobs_tensors_cpu is not None,
) )
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()
output = self._model_runner_output output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids output.sampled_token_ids = valid_sampled_token_ids
if self._logprobs_tensors_cpu: if self._logprobs_tensors_cpu:
# NOTE(nick): this will need to be updated to use cu_num_accepted_tokens output.logprobs = self._logprobs_tensors_cpu.tolists(cu_num_tokens)
# for async sched + spec decode + logprobs compatibility.
output.logprobs = self._logprobs_tensors_cpu.tolists()
return output return output
@ -2629,28 +2630,24 @@ class GPUModelRunner(
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
logprobs_tensors = sampler_output.logprobs_tensors logprobs_tensors = sampler_output.logprobs_tensors
invalid_req_indices = [] invalid_req_indices = []
cu_num_new_tokens: list[int] | None = None cu_num_tokens: list[int] | None = None
if not self.use_async_scheduling: if not self.use_async_scheduling:
# Get the valid generated tokens. # Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1] max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1: if max_gen_len == 1:
# No spec decode tokens. # No spec decode tokens.
valid_sampled_token_ids = self._to_list(sampled_token_ids) valid_sampled_token_ids = self._to_list(sampled_token_ids)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)].clear()
else: else:
# Includes spec decode tokens. # Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output( valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
sampled_token_ids, sampled_token_ids,
self.input_batch.vocab_size, self.input_batch.vocab_size,
discard_sampled_tokens_req_indices,
return_cu_num_tokens=logprobs_tensors is not None,
) )
if logprobs_tensors:
# Needed for extracting logprobs when spec decoding.
# This must be done prior to discarding sampled tokens.
cu_num_new_tokens = [0]
for toks in valid_sampled_token_ids:
cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)].clear()
else: else:
valid_sampled_token_ids = [] valid_sampled_token_ids = []
invalid_req_indices = discard_sampled_tokens_req_indices.tolist() invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
@ -2704,7 +2701,7 @@ class GPUModelRunner(
req_state.output_token_ids.extend(sampled_ids) req_state.output_token_ids.extend(sampled_ids)
logprobs_lists = ( logprobs_lists = (
logprobs_tensors.tolists(cu_num_new_tokens) logprobs_tensors.tolists(cu_num_tokens)
if not self.use_async_scheduling and logprobs_tensors is not None if not self.use_async_scheduling and logprobs_tensors is not None
else None else None
) )
@ -2824,7 +2821,7 @@ class GPUModelRunner(
# returns True. before returning early here we call # returns True. before returning early here we call
# dummy run to ensure coordinate_batch_across_dp # dummy run to ensure coordinate_batch_across_dp
# is called into to avoid out of sync issues. # is called into to avoid out of sync issues.
self._dummy_run(1) self._dummy_run(self._get_num_input_tokens(1))
if not has_kv_transfer_group(): if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if no work to do. # Return empty ModelRunnerOutput if no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT return EMPTY_MODEL_RUNNER_OUTPUT
@ -3495,6 +3492,10 @@ class GPUModelRunner(
scope="local", scope="local",
) )
prepare_communication_buffer_for_model(self.model) prepare_communication_buffer_for_model(self.model)
if (drafter := getattr(self, "drafter", None)) and (
drafter_model := getattr(drafter, "model", None)
):
prepare_communication_buffer_for_model(drafter_model)
mm_config = self.model_config.multimodal_config mm_config = self.model_config.multimodal_config
self.is_multimodal_pruning_enabled = ( self.is_multimodal_pruning_enabled = (
supports_multimodal_pruning(self.get_model()) supports_multimodal_pruning(self.get_model())
@ -4277,14 +4278,18 @@ class GPUModelRunner(
# NOTE: This happens when encoder cache needs to store # NOTE: This happens when encoder cache needs to store
# the embeddings that encoder outputs are scattered onto. # the embeddings that encoder outputs are scattered onto.
# In this case we create dummy embeddings of size # In this case we create dummy embeddings of size
# (encode_budget, hidden_size) and scatter encoder # (max_tokens_for_modality, hidden_size) and scatter
# output into it. # encoder output into it.
encoder_output_shape = dummy_encoder_outputs[0].shape encoder_output_shape = dummy_encoder_outputs[0].shape
if encoder_output_shape[0] < encoder_budget: max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[
dummy_modality
]
if encoder_output_shape[0] < max_mm_tokens_per_item:
encoder_hidden_size = encoder_output_shape[-1]
expanded_outputs = [] expanded_outputs = []
for output in dummy_encoder_outputs: for output in dummy_encoder_outputs:
expanded = output.new_zeros( expanded = output.new_zeros(
(encoder_budget, encoder_output_shape[-1]) (max_mm_tokens_per_item, encoder_hidden_size)
) )
num_tokens = output.shape[0] num_tokens = output.shape[0]
expanded[:num_tokens].copy_(output) expanded[:num_tokens].copy_(output)