mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-12 15:21:25 +08:00
Merge branch 'main' into imarkov/eplb_optimizations
This commit is contained in:
commit
b8533148ed
@ -903,11 +903,12 @@ steps:
|
|||||||
- label: Transformers Nightly Models Test
|
- label: Transformers Nightly Models Test
|
||||||
working_dir: "/vllm-workspace/"
|
working_dir: "/vllm-workspace/"
|
||||||
optional: true
|
optional: true
|
||||||
|
soft_fail: true
|
||||||
commands:
|
commands:
|
||||||
- pip install --upgrade git+https://github.com/huggingface/transformers
|
- pip install --upgrade git+https://github.com/huggingface/transformers
|
||||||
- pytest -v -s tests/models/test_initialization.py -k 'not (Ultravox or Phi4Multimodal or MiniCPMO or Lfm2Moe or RobertaForSequenceClassification or Ovis2_5 or DeepseekOCR or KimiVL)'
|
- pytest -v -s tests/models/test_initialization.py
|
||||||
- pytest -v -s tests/models/test_transformers.py
|
- pytest -v -s tests/models/test_transformers.py
|
||||||
# - pytest -v -s tests/models/multimodal/processing/
|
- pytest -v -s tests/models/multimodal/processing/
|
||||||
- pytest -v -s tests/models/multimodal/test_mapping.py
|
- pytest -v -s tests/models/multimodal/test_mapping.py
|
||||||
- python3 examples/offline_inference/basic/chat.py
|
- python3 examples/offline_inference/basic/chat.py
|
||||||
- python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
|
- python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
|
||||||
|
|||||||
2
.github/workflows/cleanup_pr_body.yml
vendored
2
.github/workflows/cleanup_pr_body.yml
vendored
@ -13,7 +13,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
|
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
|
||||||
|
|||||||
2
.github/workflows/macos-smoke-test.yml
vendored
2
.github/workflows/macos-smoke-test.yml
vendored
@ -12,7 +12,7 @@ jobs:
|
|||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
- uses: astral-sh/setup-uv@v7
|
- uses: astral-sh/setup-uv@v7
|
||||||
with:
|
with:
|
||||||
|
|||||||
2
.github/workflows/pre-commit.yml
vendored
2
.github/workflows/pre-commit.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
|||||||
pre-commit:
|
pre-commit:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
|
||||||
- uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
|
- uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
|
||||||
with:
|
with:
|
||||||
python-version: "3.12"
|
python-version: "3.12"
|
||||||
|
|||||||
@ -604,12 +604,15 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
set(SRCS
|
set(SRCS
|
||||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||||
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
|
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
|
||||||
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu")
|
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
|
||||||
|
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
|
||||||
|
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
||||||
set_gencode_flags_for_srcs(
|
set_gencode_flags_for_srcs(
|
||||||
SRCS "${SRCS}"
|
SRCS "${SRCS}"
|
||||||
CUDA_ARCHS "${FP4_ARCHS}")
|
CUDA_ARCHS "${FP4_ARCHS}")
|
||||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
|
||||||
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
|
||||||
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
||||||
else()
|
else()
|
||||||
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
||||||
|
|||||||
@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel(
|
|||||||
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
|
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
|
||||||
const float* prefix_lse, const scalar_t* suffix_output,
|
const float* prefix_lse, const scalar_t* suffix_output,
|
||||||
const float* suffix_lse, const uint num_tokens, const uint num_heads,
|
const float* suffix_lse, const uint num_tokens, const uint num_heads,
|
||||||
const uint head_size) {
|
const uint head_size, const uint prefix_head_stride,
|
||||||
|
const uint output_head_stride) {
|
||||||
using pack_128b_t = uint4;
|
using pack_128b_t = uint4;
|
||||||
const uint pack_size = 16 / sizeof(scalar_t);
|
const uint pack_size = 16 / sizeof(scalar_t);
|
||||||
const uint threads_per_head = head_size / pack_size;
|
const uint threads_per_head = head_size / pack_size;
|
||||||
@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel(
|
|||||||
const uint head_idx = token_head_idx % num_heads;
|
const uint head_idx = token_head_idx % num_heads;
|
||||||
|
|
||||||
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
|
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
|
||||||
const uint head_offset =
|
const uint src_head_offset = token_idx * num_heads * prefix_head_stride +
|
||||||
token_idx * num_heads * head_size + head_idx * head_size;
|
head_idx * prefix_head_stride;
|
||||||
const scalar_t* prefix_head_ptr = prefix_output + head_offset;
|
const uint dst_head_offset = token_idx * num_heads * output_head_stride +
|
||||||
const scalar_t* suffix_head_ptr = suffix_output + head_offset;
|
head_idx * output_head_stride;
|
||||||
scalar_t* output_head_ptr = output + head_offset;
|
const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
|
||||||
|
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
|
||||||
|
scalar_t* output_head_ptr = output + dst_head_offset;
|
||||||
|
|
||||||
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
|
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
|
||||||
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||||
@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel(
|
|||||||
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
||||||
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
||||||
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
|
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
|
||||||
num_heads, head_size); \
|
num_heads, head_size, prefix_head_stride, output_head_stride); \
|
||||||
}
|
}
|
||||||
|
|
||||||
/*@brief Merges the attention states from prefix and suffix
|
/*@brief Merges the attention states from prefix and suffix
|
||||||
@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output,
|
|||||||
const uint num_tokens = output.size(0);
|
const uint num_tokens = output.size(0);
|
||||||
const uint num_heads = output.size(1);
|
const uint num_heads = output.size(1);
|
||||||
const uint head_size = output.size(2);
|
const uint head_size = output.size(2);
|
||||||
|
const uint prefix_head_stride = prefix_output.stride(1);
|
||||||
|
const uint output_head_stride = output.stride(1);
|
||||||
const uint pack_size = 16 / sizeof(scalar_t);
|
const uint pack_size = 16 / sizeof(scalar_t);
|
||||||
TORCH_CHECK(head_size % pack_size == 0,
|
TORCH_CHECK(head_size % pack_size == 0,
|
||||||
"headsize must be multiple of pack_size:", pack_size);
|
"headsize must be multiple of pack_size:", pack_size);
|
||||||
TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
|
|
||||||
"output heads must be contiguous in memory");
|
|
||||||
TORCH_CHECK(
|
|
||||||
prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
|
|
||||||
"prefix_output heads must be contiguous in memory");
|
|
||||||
TORCH_CHECK(
|
|
||||||
suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
|
|
||||||
"suffix_output heads must be contiguous in memory");
|
|
||||||
float* output_lse_ptr = nullptr;
|
float* output_lse_ptr = nullptr;
|
||||||
if (output_lse.has_value()) {
|
if (output_lse.has_value()) {
|
||||||
output_lse_ptr = output_lse.value().data_ptr<float>();
|
output_lse_ptr = output_lse.value().data_ptr<float>();
|
||||||
|
|||||||
@ -52,14 +52,13 @@ void paged_attention_v2(
|
|||||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
const int64_t blocksparse_head_sliding_step);
|
const int64_t blocksparse_head_sliding_step);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
void merge_attn_states(torch::Tensor& output,
|
void merge_attn_states(torch::Tensor& output,
|
||||||
std::optional<torch::Tensor> output_lse,
|
std::optional<torch::Tensor> output_lse,
|
||||||
const torch::Tensor& prefix_output,
|
const torch::Tensor& prefix_output,
|
||||||
const torch::Tensor& prefix_lse,
|
const torch::Tensor& prefix_lse,
|
||||||
const torch::Tensor& suffix_output,
|
const torch::Tensor& suffix_output,
|
||||||
const torch::Tensor& suffix_lse);
|
const torch::Tensor& suffix_lse);
|
||||||
|
#ifndef USE_ROCM
|
||||||
void convert_vertical_slash_indexes(
|
void convert_vertical_slash_indexes(
|
||||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
||||||
|
|||||||
@ -22,6 +22,7 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
#include <c10/cuda/CUDAStream.h>
|
#include <c10/cuda/CUDAStream.h>
|
||||||
|
#include "cutlass_extensions/common.hpp"
|
||||||
|
|
||||||
#include "cute/tensor.hpp"
|
#include "cute/tensor.hpp"
|
||||||
#include "cutlass/tensor_ref.h"
|
#include "cutlass/tensor_ref.h"
|
||||||
@ -173,7 +174,7 @@ void run_get_group_gemm_starts(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename OutType>
|
template <typename OutType>
|
||||||
void run_fp4_blockwise_scaled_group_mm(
|
void run_fp4_blockwise_scaled_group_mm_sm100(
|
||||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||||
@ -343,17 +344,225 @@ void run_fp4_blockwise_scaled_group_mm(
|
|||||||
|
|
||||||
auto can_implement_status = gemm_op.can_implement(args);
|
auto can_implement_status = gemm_op.can_implement(args);
|
||||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||||
"Failed to implement GEMM");
|
"Failed to implement GEMM: status=", (int)can_implement_status);
|
||||||
|
|
||||||
// Run the GEMM
|
// Run the GEMM
|
||||||
auto status = gemm_op.initialize(args, workspace.data_ptr());
|
auto status = gemm_op.initialize(args, workspace.data_ptr());
|
||||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
TORCH_CHECK(status == cutlass::Status::kSuccess,
|
||||||
|
"Failed to initialize GEMM: status=", (int)status,
|
||||||
|
" workspace_size=", workspace_size, " num_experts=", num_experts,
|
||||||
|
" M=", M, " N=", N, " K=", K);
|
||||||
|
|
||||||
status = gemm_op.run(args, workspace.data_ptr(), stream);
|
status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void run_fp4_blockwise_scaled_group_mm_sm120(
|
||||||
|
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||||
|
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||||
|
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||||
|
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
|
||||||
|
int N, int K) {
|
||||||
|
using ProblemShape =
|
||||||
|
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
|
||||||
|
using ElementType = cutlass::float_e2m1_t;
|
||||||
|
using ElementSFType = cutlass::float_ue4m3_t;
|
||||||
|
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||||
|
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||||
|
|
||||||
|
// NOTE: For SM120 it seems templating the output type is not supported and
|
||||||
|
// we need to hardcode the output type to bfloat16
|
||||||
|
using ElementC = cutlass::bfloat16_t;
|
||||||
|
using ElementD = ElementC;
|
||||||
|
using ElementAccumulator = float;
|
||||||
|
// Layout definitions
|
||||||
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
|
using LayoutC = cutlass::layout::RowMajor;
|
||||||
|
using LayoutD = LayoutC;
|
||||||
|
|
||||||
|
// Alignment constraints
|
||||||
|
static constexpr int AlignmentA = 32;
|
||||||
|
static constexpr int AlignmentB = 32;
|
||||||
|
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||||
|
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||||
|
|
||||||
|
// Architecture definitions
|
||||||
|
using ArchTag = cutlass::arch::Sm120;
|
||||||
|
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
|
||||||
|
|
||||||
|
using ClusterShape = Shape<_1, _1, _1>;
|
||||||
|
using MmaTileShape = Shape<_128, _128, _128>;
|
||||||
|
|
||||||
|
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<
|
||||||
|
ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
|
||||||
|
|
||||||
|
using CollectiveEpilogue =
|
||||||
|
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
|
ArchTag, OperatorClass, MmaTileShape, ClusterShape,
|
||||||
|
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||||
|
ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
|
||||||
|
LayoutD*, AlignmentD,
|
||||||
|
cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||||
|
FusionOperation>::CollectiveOp;
|
||||||
|
|
||||||
|
using CollectiveMainloop =
|
||||||
|
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB,
|
||||||
|
LayoutB*, AlignmentB, ElementAccumulator, MmaTileShape, ClusterShape,
|
||||||
|
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||||
|
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||||
|
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
|
||||||
|
|
||||||
|
using GemmKernel =
|
||||||
|
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
|
||||||
|
CollectiveEpilogue>;
|
||||||
|
|
||||||
|
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||||
|
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||||
|
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||||
|
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||||
|
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||||
|
|
||||||
|
using LayoutSFA =
|
||||||
|
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
|
||||||
|
using LayoutSFB =
|
||||||
|
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
|
||||||
|
using ScaleConfig =
|
||||||
|
typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||||
|
|
||||||
|
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
||||||
|
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||||
|
auto options_int =
|
||||||
|
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||||
|
|
||||||
|
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
||||||
|
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
||||||
|
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
||||||
|
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
||||||
|
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
||||||
|
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
|
||||||
|
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
|
||||||
|
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
|
||||||
|
torch::Tensor c_strides1 =
|
||||||
|
torch::full({num_experts}, output.stride(0), options_int);
|
||||||
|
torch::Tensor a_strides1 =
|
||||||
|
torch::full({num_experts}, a.stride(0) * 2, options_int);
|
||||||
|
torch::Tensor b_strides1 =
|
||||||
|
torch::full({num_experts}, b.stride(1) * 2, options_int);
|
||||||
|
|
||||||
|
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
|
||||||
|
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
|
||||||
|
layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
|
||||||
|
expert_offsets, sf_offsets, problem_sizes, M, N, K);
|
||||||
|
|
||||||
|
// Create an instance of the GEMM
|
||||||
|
Gemm gemm_op;
|
||||||
|
|
||||||
|
// Initialize problem_sizes_as_shapes correctly
|
||||||
|
UnderlyingProblemShape* problem_sizes_as_shapes =
|
||||||
|
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
||||||
|
|
||||||
|
// Set the Scheduler info
|
||||||
|
cutlass::KernelHardwareInfo hw_info;
|
||||||
|
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
|
||||||
|
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
||||||
|
scheduler.raster_order = RasterOrderOptions::AlongM;
|
||||||
|
hw_info.device_id = a.get_device();
|
||||||
|
static std::unordered_map<int, int> cached_sm_counts;
|
||||||
|
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
|
||||||
|
cached_sm_counts[hw_info.device_id] =
|
||||||
|
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||||
|
hw_info.device_id);
|
||||||
|
}
|
||||||
|
hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
|
||||||
|
|
||||||
|
// Mainloop Arguments
|
||||||
|
typename GemmKernel::MainloopArguments mainloop_args{
|
||||||
|
static_cast<const ElementType**>(a_ptrs.data_ptr()),
|
||||||
|
static_cast<StrideA*>(a_strides1.data_ptr()),
|
||||||
|
static_cast<const ElementType**>(b_ptrs.data_ptr()),
|
||||||
|
static_cast<StrideB*>(b_strides1.data_ptr()),
|
||||||
|
static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
|
||||||
|
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
|
||||||
|
static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
|
||||||
|
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
|
||||||
|
|
||||||
|
// Epilogue Arguments
|
||||||
|
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||||
|
{}, // epilogue.thread
|
||||||
|
nullptr,
|
||||||
|
static_cast<StrideC*>(c_strides1.data_ptr()),
|
||||||
|
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||||
|
static_cast<StrideC*>(c_strides1.data_ptr())};
|
||||||
|
auto& fusion_args = epilogue_args.thread;
|
||||||
|
fusion_args.alpha_ptr_array =
|
||||||
|
reinterpret_cast<float**>(alpha_ptrs.data_ptr());
|
||||||
|
fusion_args.dAlpha = {_0{}, _0{}, 1};
|
||||||
|
fusion_args.beta = 0.0f;
|
||||||
|
|
||||||
|
// Gemm Arguments
|
||||||
|
typename GemmKernel::Arguments args{
|
||||||
|
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||||
|
{num_experts, problem_sizes_as_shapes, nullptr},
|
||||||
|
mainloop_args,
|
||||||
|
epilogue_args,
|
||||||
|
hw_info,
|
||||||
|
scheduler};
|
||||||
|
|
||||||
|
size_t workspace_size = Gemm::get_workspace_size(args);
|
||||||
|
auto const workspace_options =
|
||||||
|
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||||
|
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||||
|
|
||||||
|
auto can_implement_status = gemm_op.can_implement(args);
|
||||||
|
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||||
|
"Failed to implement GEMM: status=", (int)can_implement_status);
|
||||||
|
|
||||||
|
// Run the GEMM
|
||||||
|
auto status = gemm_op.initialize(args, workspace.data_ptr());
|
||||||
|
TORCH_CHECK(status == cutlass::Status::kSuccess,
|
||||||
|
"Failed to initialize GEMM: status=", (int)status,
|
||||||
|
" workspace_size=", workspace_size, " num_experts=", num_experts,
|
||||||
|
" M=", M, " N=", N, " K=", K);
|
||||||
|
|
||||||
|
status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||||
|
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename OutType>
|
||||||
|
void run_fp4_blockwise_scaled_group_mm(
|
||||||
|
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||||
|
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||||
|
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||||
|
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
|
||||||
|
int N, int K) {
|
||||||
|
int32_t version_num = get_sm_version_num();
|
||||||
|
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
|
||||||
|
if (version_num >= 120 && version_num < 130) {
|
||||||
|
run_fp4_blockwise_scaled_group_mm_sm120(
|
||||||
|
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
||||||
|
expert_offsets, sf_offsets, M, N, K);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
||||||
|
if (version_num >= 100 && version_num < 120) {
|
||||||
|
run_fp4_blockwise_scaled_group_mm_sm100<OutType>(
|
||||||
|
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
||||||
|
expert_offsets, sf_offsets, M, N, K);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
|
false,
|
||||||
|
"No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
|
||||||
|
version_num, ". Required capability: 100 or 120");
|
||||||
|
}
|
||||||
|
|
||||||
|
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
|
||||||
|
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
|
||||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
||||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
||||||
#endif
|
#endif
|
||||||
@ -374,7 +583,8 @@ void cutlass_fp4_group_mm(
|
|||||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
|
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
|
||||||
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
|
||||||
|
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
|
||||||
// Input validation
|
// Input validation
|
||||||
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
|
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
|
||||||
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
|
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
|
||||||
@ -408,6 +618,14 @@ void cutlass_fp4_group_mm(
|
|||||||
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
||||||
expert_offsets, sf_offsets, M, N, K);
|
expert_offsets, sf_offsets, M, N, K);
|
||||||
} else {
|
} else {
|
||||||
|
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
|
||||||
|
int32_t version_num = get_sm_version_num();
|
||||||
|
if (version_num >= 120 && version_num < 130) {
|
||||||
|
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
|
false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
|
||||||
|
output.scalar_type());
|
||||||
|
}
|
||||||
|
#endif
|
||||||
run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
|
run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
|
||||||
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
||||||
expert_offsets, sf_offsets, M, N, K);
|
expert_offsets, sf_offsets, M, N, K);
|
||||||
@ -416,8 +634,8 @@ void cutlass_fp4_group_mm(
|
|||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
false,
|
false,
|
||||||
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
|
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
|
||||||
"be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA "
|
"be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
|
||||||
"12.8 or above.");
|
"and CUDA 12.8 or above.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -307,7 +307,7 @@ constexpr auto FLOAT = at::ScalarType::Float;
|
|||||||
constexpr auto INT = at::ScalarType::Int;
|
constexpr auto INT = at::ScalarType::Int;
|
||||||
constexpr auto UINT8 = at::ScalarType::Byte;
|
constexpr auto UINT8 = at::ScalarType::Byte;
|
||||||
|
|
||||||
void scaled_fp4_experts_quant_sm100a(
|
void scaled_fp4_experts_quant_sm1xxa(
|
||||||
torch::Tensor& output, torch::Tensor& output_scale,
|
torch::Tensor& output, torch::Tensor& output_scale,
|
||||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||||
torch::Tensor const& input_offset_by_experts,
|
torch::Tensor const& input_offset_by_experts,
|
||||||
|
|||||||
@ -24,8 +24,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
|||||||
torch::Tensor const& input_sf);
|
torch::Tensor const& input_sf);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
void scaled_fp4_experts_quant_sm100a(
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
|
void scaled_fp4_experts_quant_sm1xxa(
|
||||||
torch::Tensor& output, torch::Tensor& output_scale,
|
torch::Tensor& output, torch::Tensor& output_scale,
|
||||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||||
torch::Tensor const& input_offset_by_experts,
|
torch::Tensor const& input_offset_by_experts,
|
||||||
@ -54,8 +55,9 @@ void scaled_fp4_experts_quant(
|
|||||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||||
torch::Tensor const& input_offset_by_experts,
|
torch::Tensor const& input_offset_by_experts,
|
||||||
torch::Tensor const& output_scale_offset_by_experts) {
|
torch::Tensor const& output_scale_offset_by_experts) {
|
||||||
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
return scaled_fp4_experts_quant_sm100a(
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
|
return scaled_fp4_experts_quant_sm1xxa(
|
||||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||||
output_scale_offset_by_experts);
|
output_scale_offset_by_experts);
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -67,9 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
|||||||
std::optional<torch::Tensor> const& bias);
|
std::optional<torch::Tensor> const& bias);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
|
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
|
||||||
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \
|
(defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
|
||||||
defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120
|
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
|
||||||
void get_cutlass_moe_mm_data_caller(
|
void get_cutlass_moe_mm_data_caller(
|
||||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||||
@ -284,8 +284,9 @@ void get_cutlass_moe_mm_data(
|
|||||||
// This function currently gets compiled only if we have a valid cutlass moe
|
// This function currently gets compiled only if we have a valid cutlass moe
|
||||||
// mm to run it for.
|
// mm to run it for.
|
||||||
int32_t version_num = get_sm_version_num();
|
int32_t version_num = get_sm_version_num();
|
||||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
|
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||||
|
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||||
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
||||||
problem_sizes2, input_permutation,
|
problem_sizes2, input_permutation,
|
||||||
output_permutation, num_experts, n, k,
|
output_permutation, num_experts, n, k,
|
||||||
@ -296,7 +297,7 @@ void get_cutlass_moe_mm_data(
|
|||||||
false,
|
false,
|
||||||
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
|
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
|
||||||
"CUDA device capability: ",
|
"CUDA device capability: ",
|
||||||
version_num, ". Required capability: 90 or 100");
|
version_num, ". Required capability: 90, 100, or 120");
|
||||||
}
|
}
|
||||||
|
|
||||||
void get_cutlass_moe_mm_problem_sizes(
|
void get_cutlass_moe_mm_problem_sizes(
|
||||||
@ -304,8 +305,9 @@ void get_cutlass_moe_mm_problem_sizes(
|
|||||||
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
||||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
|
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
|
||||||
int32_t version_num = get_sm_version_num();
|
int32_t version_num = get_sm_version_num();
|
||||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
|
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||||
|
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||||
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
|
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
|
||||||
problem_sizes2, num_experts, n, k,
|
problem_sizes2, num_experts, n, k,
|
||||||
blockscale_offsets);
|
blockscale_offsets);
|
||||||
@ -315,7 +317,7 @@ void get_cutlass_moe_mm_problem_sizes(
|
|||||||
false,
|
false,
|
||||||
"No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
|
"No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
|
||||||
"kernel for CUDA device capability: ",
|
"kernel for CUDA device capability: ",
|
||||||
version_num, ". Required capability: 90 or 100");
|
version_num, ". Required capability: 90, 100, or 120");
|
||||||
}
|
}
|
||||||
|
|
||||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||||
@ -328,8 +330,9 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
|||||||
// This function currently gets compiled only if we have a valid cutlass moe
|
// This function currently gets compiled only if we have a valid cutlass moe
|
||||||
// mm to run it for.
|
// mm to run it for.
|
||||||
int32_t version_num = get_sm_version_num();
|
int32_t version_num = get_sm_version_num();
|
||||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
|
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||||
|
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||||
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
|
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
|
||||||
problem_sizes2, expert_num_tokens,
|
problem_sizes2, expert_num_tokens,
|
||||||
num_local_experts, padded_m, n, k);
|
num_local_experts, padded_m, n, k);
|
||||||
@ -339,7 +342,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
|||||||
false,
|
false,
|
||||||
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
|
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
|
||||||
"for CUDA device capability: ",
|
"for CUDA device capability: ",
|
||||||
version_num, ". Required capability: 90 or 100");
|
version_num, ". Required capability: 90, 100, or 120");
|
||||||
}
|
}
|
||||||
|
|
||||||
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||||
|
|||||||
@ -63,7 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" int blocksparse_head_sliding_step) -> ()");
|
" int blocksparse_head_sliding_step) -> ()");
|
||||||
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
|
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
// Merge attn states
|
// Merge attn states
|
||||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||||
// can be used to combine partial attention results (in the split-KV case)
|
// can be used to combine partial attention results (in the split-KV case)
|
||||||
@ -76,7 +75,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" Tensor suffix_output,"
|
" Tensor suffix_output,"
|
||||||
" Tensor suffix_lse) -> ()");
|
" Tensor suffix_lse) -> ()");
|
||||||
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
|
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
|
||||||
|
#ifndef USE_ROCM
|
||||||
ops.def(
|
ops.def(
|
||||||
"convert_vertical_slash_indexes("
|
"convert_vertical_slash_indexes("
|
||||||
" Tensor! block_count, Tensor! block_offset, "
|
" Tensor! block_count, Tensor! block_offset, "
|
||||||
|
|||||||
@ -60,7 +60,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
|
|||||||
|
|
||||||
- [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod]
|
- [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod]
|
||||||
- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod]
|
- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod]
|
||||||
- [`CompressedTensorsW4A4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4MoeMethod]
|
- [`CompressedTensorsW4A4Nvfp4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoeMethod]
|
||||||
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod]
|
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod]
|
||||||
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod]
|
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod]
|
||||||
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]
|
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]
|
||||||
|
|||||||
@ -49,7 +49,8 @@ We currently support the following OpenAI APIs:
|
|||||||
- *Note: `suffix` parameter is not supported.*
|
- *Note: `suffix` parameter is not supported.*
|
||||||
- [Chat Completions API](#chat-api) (`/v1/chat/completions`)
|
- [Chat Completions API](#chat-api) (`/v1/chat/completions`)
|
||||||
- Only applicable to [text generation models](../models/generative_models.md) with a [chat template](../serving/openai_compatible_server.md#chat-template).
|
- Only applicable to [text generation models](../models/generative_models.md) with a [chat template](../serving/openai_compatible_server.md#chat-template).
|
||||||
- *Note: `parallel_tool_calls` and `user` parameters are ignored.*
|
- *Note: `user` parameter is ignored.*
|
||||||
|
- *Note:* Setting the `parallel_tool_calls` parameter to `false` ensures vLLM only returns zero or one tool call per request. Setting it to `true` (the default) allows returning more than one tool call per request. There is no guarantee more than one tool call will be returned if this is set to `true`, as that behavior is model dependent and not all models are designed to support parallel tool calls.
|
||||||
- [Embeddings API](#embeddings-api) (`/v1/embeddings`)
|
- [Embeddings API](#embeddings-api) (`/v1/embeddings`)
|
||||||
- Only applicable to [embedding models](../models/pooling_models.md).
|
- Only applicable to [embedding models](../models/pooling_models.md).
|
||||||
- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`)
|
- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`)
|
||||||
|
|||||||
@ -133,7 +133,7 @@ def main(args):
|
|||||||
tensor_parallel_size=args.tp,
|
tensor_parallel_size=args.tp,
|
||||||
enable_chunked_prefill=args.enable_chunked_prefill,
|
enable_chunked_prefill=args.enable_chunked_prefill,
|
||||||
enforce_eager=args.enforce_eager,
|
enforce_eager=args.enforce_eager,
|
||||||
gpu_memory_utilization=0.8,
|
gpu_memory_utilization=0.9,
|
||||||
speculative_config=speculative_config,
|
speculative_config=speculative_config,
|
||||||
disable_log_stats=False,
|
disable_log_stats=False,
|
||||||
max_model_len=args.max_model_len,
|
max_model_len=args.max_model_len,
|
||||||
|
|||||||
@ -55,7 +55,7 @@ class SillyModel(nn.Module):
|
|||||||
def _run_simple_model(
|
def _run_simple_model(
|
||||||
splitting_ops,
|
splitting_ops,
|
||||||
use_inductor_graph_partition,
|
use_inductor_graph_partition,
|
||||||
use_inductor,
|
backend,
|
||||||
expected_num_piecewise_graphs_seen,
|
expected_num_piecewise_graphs_seen,
|
||||||
expected_num_piecewise_capturable_graphs_seen,
|
expected_num_piecewise_capturable_graphs_seen,
|
||||||
expected_num_backend_compilations,
|
expected_num_backend_compilations,
|
||||||
@ -64,7 +64,7 @@ def _run_simple_model(
|
|||||||
vllm_config = VllmConfig(
|
vllm_config = VllmConfig(
|
||||||
compilation_config=CompilationConfig(
|
compilation_config=CompilationConfig(
|
||||||
mode=CompilationMode.VLLM_COMPILE,
|
mode=CompilationMode.VLLM_COMPILE,
|
||||||
use_inductor=use_inductor,
|
backend=backend,
|
||||||
splitting_ops=splitting_ops,
|
splitting_ops=splitting_ops,
|
||||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||||
cudagraph_copy_inputs=True,
|
cudagraph_copy_inputs=True,
|
||||||
@ -124,14 +124,14 @@ def _run_simple_model(
|
|||||||
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
|
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_inductor", [True, False])
|
@pytest.mark.parametrize("backend", ["inductor", "eager"])
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@create_new_process_for_each_test("spawn")
|
@create_new_process_for_each_test("spawn")
|
||||||
def test_simple_piecewise_compile(use_inductor):
|
def test_simple_piecewise_compile(backend):
|
||||||
_run_simple_model(
|
_run_simple_model(
|
||||||
splitting_ops=["silly::attention"],
|
splitting_ops=["silly::attention"],
|
||||||
use_inductor_graph_partition=False,
|
use_inductor_graph_partition=False,
|
||||||
use_inductor=use_inductor,
|
backend=backend,
|
||||||
# 2 * num_layers + 1
|
# 2 * num_layers + 1
|
||||||
expected_num_piecewise_graphs_seen=5,
|
expected_num_piecewise_graphs_seen=5,
|
||||||
# 1 + num_layers
|
# 1 + num_layers
|
||||||
@ -155,7 +155,7 @@ def test_simple_inductor_graph_partition(monkeypatch):
|
|||||||
_run_simple_model(
|
_run_simple_model(
|
||||||
splitting_ops=["silly::attention"],
|
splitting_ops=["silly::attention"],
|
||||||
use_inductor_graph_partition=True,
|
use_inductor_graph_partition=True,
|
||||||
use_inductor=True,
|
backend="inductor",
|
||||||
# Since not splitting at fx graph level
|
# Since not splitting at fx graph level
|
||||||
expected_num_piecewise_graphs_seen=1,
|
expected_num_piecewise_graphs_seen=1,
|
||||||
# Since not splitting at fx graph level
|
# Since not splitting at fx graph level
|
||||||
|
|||||||
@ -249,14 +249,13 @@ def test_compilation_config():
|
|||||||
args = parser.parse_args(
|
args = parser.parse_args(
|
||||||
[
|
[
|
||||||
"-O",
|
"-O",
|
||||||
'{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
'{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], "backend": "eager"}',
|
||||||
'"use_inductor": false}',
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
args.compilation_config.mode == 3
|
args.compilation_config.mode == 3
|
||||||
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||||
and not args.compilation_config.use_inductor
|
and args.compilation_config.backend == "eager"
|
||||||
)
|
)
|
||||||
|
|
||||||
# set to string form of a dict
|
# set to string form of a dict
|
||||||
@ -264,13 +263,13 @@ def test_compilation_config():
|
|||||||
[
|
[
|
||||||
"--compilation-config="
|
"--compilation-config="
|
||||||
'{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
'{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
||||||
'"use_inductor": true}',
|
'"backend": "inductor"}',
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
args.compilation_config.mode == 3
|
args.compilation_config.mode == 3
|
||||||
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||||
and args.compilation_config.use_inductor
|
and args.compilation_config.backend == "inductor"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -278,8 +277,9 @@ def test_prefix_cache_default():
|
|||||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||||
args = parser.parse_args([])
|
args = parser.parse_args([])
|
||||||
|
|
||||||
|
# should be None by default (depends on model).
|
||||||
engine_args = EngineArgs.from_cli_args(args=args)
|
engine_args = EngineArgs.from_cli_args(args=args)
|
||||||
assert engine_args.enable_prefix_caching, "prefix caching should default to on."
|
assert engine_args.enable_prefix_caching is None
|
||||||
|
|
||||||
# with flag to turn it on.
|
# with flag to turn it on.
|
||||||
args = parser.parse_args(["--enable-prefix-caching"])
|
args = parser.parse_args(["--enable-prefix-caching"])
|
||||||
|
|||||||
0
tests/entrypoints/pooling/pooling/__init__.py
Normal file
0
tests/entrypoints/pooling/pooling/__init__.py
Normal file
0
tests/entrypoints/pooling/reward/__init__.py
Normal file
0
tests/entrypoints/pooling/reward/__init__.py
Normal file
0
tests/entrypoints/pooling/score/__init__.py
Normal file
0
tests/entrypoints/pooling/score/__init__.py
Normal file
@ -2,6 +2,9 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from openai.types.responses.response_function_tool_call_output_item import (
|
||||||
|
ResponseFunctionToolCallOutputItem,
|
||||||
|
)
|
||||||
from openai.types.responses.response_reasoning_item import (
|
from openai.types.responses.response_reasoning_item import (
|
||||||
Content,
|
Content,
|
||||||
ResponseReasoningItem,
|
ResponseReasoningItem,
|
||||||
@ -76,6 +79,18 @@ class TestResponsesUtils:
|
|||||||
== 'Hmm, the user has just started with a simple "Hello,"'
|
== 'Hmm, the user has just started with a simple "Hello,"'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tool_call_output = ResponseFunctionToolCallOutputItem(
|
||||||
|
id="temp_id",
|
||||||
|
type="function_call_output",
|
||||||
|
call_id="temp",
|
||||||
|
output="1234",
|
||||||
|
status="completed",
|
||||||
|
)
|
||||||
|
formatted_item = construct_chat_message_with_tool_call(tool_call_output)
|
||||||
|
assert formatted_item["role"] == "tool"
|
||||||
|
assert formatted_item["content"] == "1234"
|
||||||
|
assert formatted_item["tool_call_id"] == "temp"
|
||||||
|
|
||||||
item = ResponseReasoningItem(
|
item = ResponseReasoningItem(
|
||||||
id="lol",
|
id="lol",
|
||||||
summary=[],
|
summary=[],
|
||||||
|
|||||||
240
tests/models/test_gguf_download.py
Normal file
240
tests/models/test_gguf_download.py
Normal file
@ -0,0 +1,240 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.config.load import LoadConfig
|
||||||
|
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import download_gguf
|
||||||
|
|
||||||
|
|
||||||
|
class TestGGUFDownload:
|
||||||
|
"""Test GGUF model downloading functionality."""
|
||||||
|
|
||||||
|
@patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
|
||||||
|
def test_download_gguf_single_file(self, mock_download):
|
||||||
|
"""Test downloading a single GGUF file."""
|
||||||
|
# Setup mock
|
||||||
|
mock_folder = "/tmp/mock_cache"
|
||||||
|
mock_download.return_value = mock_folder
|
||||||
|
|
||||||
|
# Mock glob to return a single file
|
||||||
|
with patch("glob.glob") as mock_glob:
|
||||||
|
mock_glob.side_effect = lambda pattern, **kwargs: (
|
||||||
|
[f"{mock_folder}/model-IQ1_S.gguf"] if "IQ1_S" in pattern else []
|
||||||
|
)
|
||||||
|
|
||||||
|
result = download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")
|
||||||
|
|
||||||
|
# Verify download_weights_from_hf was called with correct patterns
|
||||||
|
mock_download.assert_called_once_with(
|
||||||
|
model_name_or_path="unsloth/Qwen3-0.6B-GGUF",
|
||||||
|
cache_dir=None,
|
||||||
|
allow_patterns=[
|
||||||
|
"*-IQ1_S.gguf",
|
||||||
|
"*-IQ1_S-*.gguf",
|
||||||
|
"*/*-IQ1_S.gguf",
|
||||||
|
"*/*-IQ1_S-*.gguf",
|
||||||
|
],
|
||||||
|
revision=None,
|
||||||
|
ignore_patterns=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify result is the file path, not folder
|
||||||
|
assert result == f"{mock_folder}/model-IQ1_S.gguf"
|
||||||
|
|
||||||
|
@patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
|
||||||
|
def test_download_gguf_sharded_files(self, mock_download):
|
||||||
|
"""Test downloading sharded GGUF files."""
|
||||||
|
mock_folder = "/tmp/mock_cache"
|
||||||
|
mock_download.return_value = mock_folder
|
||||||
|
|
||||||
|
# Mock glob to return sharded files
|
||||||
|
with patch("glob.glob") as mock_glob:
|
||||||
|
mock_glob.side_effect = lambda pattern, **kwargs: (
|
||||||
|
[
|
||||||
|
f"{mock_folder}/model-Q2_K-00001-of-00002.gguf",
|
||||||
|
f"{mock_folder}/model-Q2_K-00002-of-00002.gguf",
|
||||||
|
]
|
||||||
|
if "Q2_K" in pattern
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K")
|
||||||
|
|
||||||
|
# Should return the first file after sorting
|
||||||
|
assert result == f"{mock_folder}/model-Q2_K-00001-of-00002.gguf"
|
||||||
|
|
||||||
|
@patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
|
||||||
|
def test_download_gguf_subdir(self, mock_download):
|
||||||
|
"""Test downloading GGUF files from subdirectory."""
|
||||||
|
mock_folder = "/tmp/mock_cache"
|
||||||
|
mock_download.return_value = mock_folder
|
||||||
|
|
||||||
|
with patch("glob.glob") as mock_glob:
|
||||||
|
mock_glob.side_effect = lambda pattern, **kwargs: (
|
||||||
|
[f"{mock_folder}/Q2_K/model-Q2_K.gguf"]
|
||||||
|
if "Q2_K" in pattern or "**/*.gguf" in pattern
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K")
|
||||||
|
|
||||||
|
assert result == f"{mock_folder}/Q2_K/model-Q2_K.gguf"
|
||||||
|
|
||||||
|
@patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
|
||||||
|
@patch("glob.glob", return_value=[])
|
||||||
|
def test_download_gguf_no_files_found(self, mock_glob, mock_download):
|
||||||
|
"""Test error when no GGUF files are found."""
|
||||||
|
mock_folder = "/tmp/mock_cache"
|
||||||
|
mock_download.return_value = mock_folder
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Downloaded GGUF files not found"):
|
||||||
|
download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")
|
||||||
|
|
||||||
|
|
||||||
|
class TestGGUFModelLoader:
|
||||||
|
"""Test GGUFModelLoader class methods."""
|
||||||
|
|
||||||
|
@patch("os.path.isfile", return_value=True)
|
||||||
|
def test_prepare_weights_local_file(self, mock_isfile):
|
||||||
|
"""Test _prepare_weights with local file."""
|
||||||
|
load_config = LoadConfig(load_format="gguf")
|
||||||
|
loader = GGUFModelLoader(load_config)
|
||||||
|
|
||||||
|
# Create a simple mock ModelConfig with only the model attribute
|
||||||
|
model_config = MagicMock()
|
||||||
|
model_config.model = "/path/to/model.gguf"
|
||||||
|
|
||||||
|
result = loader._prepare_weights(model_config)
|
||||||
|
assert result == "/path/to/model.gguf"
|
||||||
|
mock_isfile.assert_called_once_with("/path/to/model.gguf")
|
||||||
|
|
||||||
|
@patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download")
|
||||||
|
@patch("os.path.isfile", return_value=False)
|
||||||
|
def test_prepare_weights_https_url(self, mock_isfile, mock_hf_download):
|
||||||
|
"""Test _prepare_weights with HTTPS URL."""
|
||||||
|
load_config = LoadConfig(load_format="gguf")
|
||||||
|
loader = GGUFModelLoader(load_config)
|
||||||
|
|
||||||
|
mock_hf_download.return_value = "/downloaded/model.gguf"
|
||||||
|
|
||||||
|
# Create a simple mock ModelConfig with only the model attribute
|
||||||
|
model_config = MagicMock()
|
||||||
|
model_config.model = "https://huggingface.co/model.gguf"
|
||||||
|
|
||||||
|
result = loader._prepare_weights(model_config)
|
||||||
|
assert result == "/downloaded/model.gguf"
|
||||||
|
mock_hf_download.assert_called_once_with(
|
||||||
|
url="https://huggingface.co/model.gguf"
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download")
|
||||||
|
@patch("os.path.isfile", return_value=False)
|
||||||
|
def test_prepare_weights_repo_filename(self, mock_isfile, mock_hf_download):
|
||||||
|
"""Test _prepare_weights with repo_id/filename.gguf format."""
|
||||||
|
load_config = LoadConfig(load_format="gguf")
|
||||||
|
loader = GGUFModelLoader(load_config)
|
||||||
|
|
||||||
|
mock_hf_download.return_value = "/downloaded/model.gguf"
|
||||||
|
|
||||||
|
# Create a simple mock ModelConfig with only the model attribute
|
||||||
|
model_config = MagicMock()
|
||||||
|
model_config.model = "unsloth/Qwen3-0.6B-GGUF/model.gguf"
|
||||||
|
|
||||||
|
result = loader._prepare_weights(model_config)
|
||||||
|
assert result == "/downloaded/model.gguf"
|
||||||
|
mock_hf_download.assert_called_once_with(
|
||||||
|
repo_id="unsloth/Qwen3-0.6B-GGUF", filename="model.gguf"
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("vllm.config.model.get_hf_image_processor_config", return_value=None)
|
||||||
|
@patch("vllm.transformers_utils.config.file_or_path_exists", return_value=True)
|
||||||
|
@patch("vllm.config.model.get_config")
|
||||||
|
@patch("vllm.config.model.is_gguf", return_value=True)
|
||||||
|
@patch("vllm.model_executor.model_loader.gguf_loader.download_gguf")
|
||||||
|
@patch("os.path.isfile", return_value=False)
|
||||||
|
def test_prepare_weights_repo_quant_type(
|
||||||
|
self,
|
||||||
|
mock_isfile,
|
||||||
|
mock_download_gguf,
|
||||||
|
mock_is_gguf,
|
||||||
|
mock_get_config,
|
||||||
|
mock_file_exists,
|
||||||
|
mock_get_image_config,
|
||||||
|
):
|
||||||
|
"""Test _prepare_weights with repo_id:quant_type format."""
|
||||||
|
mock_hf_config = MagicMock()
|
||||||
|
mock_hf_config.architectures = ["Qwen3ForCausalLM"]
|
||||||
|
|
||||||
|
class MockTextConfig:
|
||||||
|
max_position_embeddings = 4096
|
||||||
|
sliding_window = None
|
||||||
|
model_type = "qwen3"
|
||||||
|
num_attention_heads = 32
|
||||||
|
|
||||||
|
mock_text_config = MockTextConfig()
|
||||||
|
mock_hf_config.get_text_config.return_value = mock_text_config
|
||||||
|
mock_hf_config.dtype = "bfloat16"
|
||||||
|
mock_get_config.return_value = mock_hf_config
|
||||||
|
|
||||||
|
load_config = LoadConfig(load_format="gguf")
|
||||||
|
loader = GGUFModelLoader(load_config)
|
||||||
|
|
||||||
|
mock_download_gguf.return_value = "/downloaded/model-IQ1_S.gguf"
|
||||||
|
|
||||||
|
model_config = ModelConfig(
|
||||||
|
model="unsloth/Qwen3-0.6B-GGUF:IQ1_S", tokenizer="Qwen/Qwen3-0.6B"
|
||||||
|
)
|
||||||
|
result = loader._prepare_weights(model_config)
|
||||||
|
# The actual result will be the downloaded file path from mock
|
||||||
|
assert result == "/downloaded/model-IQ1_S.gguf"
|
||||||
|
mock_download_gguf.assert_called_once_with(
|
||||||
|
"unsloth/Qwen3-0.6B-GGUF",
|
||||||
|
"IQ1_S",
|
||||||
|
cache_dir=None,
|
||||||
|
revision=None,
|
||||||
|
ignore_patterns=["original/**/*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("vllm.config.model.get_hf_image_processor_config", return_value=None)
|
||||||
|
@patch("vllm.config.model.get_config")
|
||||||
|
@patch("vllm.config.model.is_gguf", return_value=False)
|
||||||
|
@patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False)
|
||||||
|
@patch("os.path.isfile", return_value=False)
|
||||||
|
def test_prepare_weights_invalid_format(
|
||||||
|
self,
|
||||||
|
mock_isfile,
|
||||||
|
mock_check_gguf,
|
||||||
|
mock_is_gguf,
|
||||||
|
mock_get_config,
|
||||||
|
mock_get_image_config,
|
||||||
|
):
|
||||||
|
"""Test _prepare_weights with invalid format."""
|
||||||
|
mock_hf_config = MagicMock()
|
||||||
|
mock_hf_config.architectures = ["Qwen3ForCausalLM"]
|
||||||
|
|
||||||
|
class MockTextConfig:
|
||||||
|
max_position_embeddings = 4096
|
||||||
|
sliding_window = None
|
||||||
|
model_type = "qwen3"
|
||||||
|
num_attention_heads = 32
|
||||||
|
|
||||||
|
mock_text_config = MockTextConfig()
|
||||||
|
mock_hf_config.get_text_config.return_value = mock_text_config
|
||||||
|
mock_hf_config.dtype = "bfloat16"
|
||||||
|
mock_get_config.return_value = mock_hf_config
|
||||||
|
|
||||||
|
load_config = LoadConfig(load_format="gguf")
|
||||||
|
loader = GGUFModelLoader(load_config)
|
||||||
|
|
||||||
|
# Create ModelConfig with a valid repo_id to avoid validation errors
|
||||||
|
# Then test _prepare_weights with invalid format
|
||||||
|
model_config = ModelConfig(model="unsloth/Qwen3-0.6B")
|
||||||
|
# Manually set model to invalid format after creation
|
||||||
|
model_config.model = "invalid-format"
|
||||||
|
with pytest.raises(ValueError, match="Unrecognised GGUF reference"):
|
||||||
|
loader._prepare_weights(model_config)
|
||||||
@ -212,3 +212,60 @@ async def test_parallel_tool_calls_with_results(
|
|||||||
assert finish_reason_count == 1
|
assert finish_reason_count == 1
|
||||||
assert len(chunks)
|
assert len(chunks)
|
||||||
assert "".join(chunks) == choice.message.content
|
assert "".join(chunks) == choice.message.content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
|
||||||
|
"""
|
||||||
|
Ensure only one tool call is returned when parallel_tool_calls is False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
models = await client.models.list()
|
||||||
|
model_name: str = models.data[0].id
|
||||||
|
chat_completion = await client.chat.completions.create(
|
||||||
|
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||||
|
temperature=0,
|
||||||
|
max_completion_tokens=200,
|
||||||
|
model=model_name,
|
||||||
|
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||||
|
logprobs=False,
|
||||||
|
parallel_tool_calls=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
stop_reason = chat_completion.choices[0].finish_reason
|
||||||
|
non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls
|
||||||
|
|
||||||
|
# make sure only 1 tool call is present
|
||||||
|
assert len(non_streamed_tool_calls) == 1
|
||||||
|
assert stop_reason == "tool_calls"
|
||||||
|
|
||||||
|
# make the same request, streaming
|
||||||
|
stream = await client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
|
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||||
|
temperature=0,
|
||||||
|
max_completion_tokens=200,
|
||||||
|
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||||
|
logprobs=False,
|
||||||
|
parallel_tool_calls=False,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
finish_reason_count: int = 0
|
||||||
|
tool_call_id_count: int = 0
|
||||||
|
|
||||||
|
async for chunk in stream:
|
||||||
|
# if there's a finish reason make sure it's tools
|
||||||
|
if chunk.choices[0].finish_reason:
|
||||||
|
finish_reason_count += 1
|
||||||
|
assert chunk.choices[0].finish_reason == "tool_calls"
|
||||||
|
|
||||||
|
streamed_tool_calls = chunk.choices[0].delta.tool_calls
|
||||||
|
if streamed_tool_calls and len(streamed_tool_calls) > 0:
|
||||||
|
tool_call = streamed_tool_calls[0]
|
||||||
|
if tool_call.id:
|
||||||
|
tool_call_id_count += 1
|
||||||
|
|
||||||
|
# make sure only 1 streaming tool call is present
|
||||||
|
assert tool_call_id_count == 1
|
||||||
|
assert finish_reason_count == 1
|
||||||
|
|||||||
@ -1,11 +1,17 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from vllm.transformers_utils.utils import (
|
from vllm.transformers_utils.utils import (
|
||||||
is_cloud_storage,
|
is_cloud_storage,
|
||||||
is_gcs,
|
is_gcs,
|
||||||
|
is_gguf,
|
||||||
|
is_remote_gguf,
|
||||||
is_s3,
|
is_s3,
|
||||||
|
split_remote_gguf,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -28,3 +34,143 @@ def test_is_cloud_storage():
|
|||||||
assert is_cloud_storage("s3://model-path/path-to-model")
|
assert is_cloud_storage("s3://model-path/path-to-model")
|
||||||
assert not is_cloud_storage("/unix/local/path")
|
assert not is_cloud_storage("/unix/local/path")
|
||||||
assert not is_cloud_storage("nfs://nfs-fqdn.local")
|
assert not is_cloud_storage("nfs://nfs-fqdn.local")
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsRemoteGGUF:
|
||||||
|
"""Test is_remote_gguf utility function."""
|
||||||
|
|
||||||
|
def test_is_remote_gguf_with_colon_and_slash(self):
|
||||||
|
"""Test is_remote_gguf with repo_id:quant_type format."""
|
||||||
|
# Valid quant types
|
||||||
|
assert is_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
|
||||||
|
assert is_remote_gguf("user/repo:Q2_K")
|
||||||
|
assert is_remote_gguf("repo/model:Q4_K")
|
||||||
|
assert is_remote_gguf("repo/model:Q8_0")
|
||||||
|
|
||||||
|
# Invalid quant types should return False
|
||||||
|
assert not is_remote_gguf("repo/model:quant")
|
||||||
|
assert not is_remote_gguf("repo/model:INVALID")
|
||||||
|
assert not is_remote_gguf("repo/model:invalid_type")
|
||||||
|
|
||||||
|
def test_is_remote_gguf_without_colon(self):
|
||||||
|
"""Test is_remote_gguf without colon."""
|
||||||
|
assert not is_remote_gguf("repo/model")
|
||||||
|
assert not is_remote_gguf("unsloth/Qwen3-0.6B-GGUF")
|
||||||
|
|
||||||
|
def test_is_remote_gguf_without_slash(self):
|
||||||
|
"""Test is_remote_gguf without slash."""
|
||||||
|
assert not is_remote_gguf("model.gguf")
|
||||||
|
# Even with valid quant_type, no slash means not remote GGUF
|
||||||
|
assert not is_remote_gguf("model:IQ1_S")
|
||||||
|
assert not is_remote_gguf("model:quant")
|
||||||
|
|
||||||
|
def test_is_remote_gguf_local_path(self):
|
||||||
|
"""Test is_remote_gguf with local file path."""
|
||||||
|
assert not is_remote_gguf("/path/to/model.gguf")
|
||||||
|
assert not is_remote_gguf("./model.gguf")
|
||||||
|
|
||||||
|
def test_is_remote_gguf_with_path_object(self):
|
||||||
|
"""Test is_remote_gguf with Path object."""
|
||||||
|
assert is_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
|
||||||
|
assert not is_remote_gguf(Path("repo/model"))
|
||||||
|
|
||||||
|
def test_is_remote_gguf_with_http_https(self):
|
||||||
|
"""Test is_remote_gguf with HTTP/HTTPS URLs."""
|
||||||
|
# HTTP/HTTPS URLs should return False even with valid quant_type
|
||||||
|
assert not is_remote_gguf("http://example.com/repo/model:IQ1_S")
|
||||||
|
assert not is_remote_gguf("https://huggingface.co/repo/model:Q2_K")
|
||||||
|
assert not is_remote_gguf("http://repo/model:Q4_K")
|
||||||
|
assert not is_remote_gguf("https://repo/model:Q8_0")
|
||||||
|
|
||||||
|
def test_is_remote_gguf_with_cloud_storage(self):
|
||||||
|
"""Test is_remote_gguf with cloud storage paths."""
|
||||||
|
# Cloud storage paths should return False even with valid quant_type
|
||||||
|
assert not is_remote_gguf("s3://bucket/repo/model:IQ1_S")
|
||||||
|
assert not is_remote_gguf("gs://bucket/repo/model:Q2_K")
|
||||||
|
assert not is_remote_gguf("s3://repo/model:Q4_K")
|
||||||
|
assert not is_remote_gguf("gs://repo/model:Q8_0")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSplitRemoteGGUF:
|
||||||
|
"""Test split_remote_gguf utility function."""
|
||||||
|
|
||||||
|
def test_split_remote_gguf_valid(self):
|
||||||
|
"""Test split_remote_gguf with valid repo_id:quant_type format."""
|
||||||
|
repo_id, quant_type = split_remote_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
|
||||||
|
assert repo_id == "unsloth/Qwen3-0.6B-GGUF"
|
||||||
|
assert quant_type == "IQ1_S"
|
||||||
|
|
||||||
|
repo_id, quant_type = split_remote_gguf("repo/model:Q2_K")
|
||||||
|
assert repo_id == "repo/model"
|
||||||
|
assert quant_type == "Q2_K"
|
||||||
|
|
||||||
|
def test_split_remote_gguf_with_path_object(self):
|
||||||
|
"""Test split_remote_gguf with Path object."""
|
||||||
|
repo_id, quant_type = split_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
|
||||||
|
assert repo_id == "unsloth/Qwen3-0.6B-GGUF"
|
||||||
|
assert quant_type == "IQ1_S"
|
||||||
|
|
||||||
|
def test_split_remote_gguf_invalid(self):
|
||||||
|
"""Test split_remote_gguf with invalid format."""
|
||||||
|
# Invalid format (no colon) - is_remote_gguf returns False
|
||||||
|
with pytest.raises(ValueError, match="Wrong GGUF model"):
|
||||||
|
split_remote_gguf("repo/model")
|
||||||
|
|
||||||
|
# Invalid quant type - is_remote_gguf returns False
|
||||||
|
with pytest.raises(ValueError, match="Wrong GGUF model"):
|
||||||
|
split_remote_gguf("repo/model:INVALID_TYPE")
|
||||||
|
|
||||||
|
# HTTP URL - is_remote_gguf returns False
|
||||||
|
with pytest.raises(ValueError, match="Wrong GGUF model"):
|
||||||
|
split_remote_gguf("http://repo/model:IQ1_S")
|
||||||
|
|
||||||
|
# Cloud storage - is_remote_gguf returns False
|
||||||
|
with pytest.raises(ValueError, match="Wrong GGUF model"):
|
||||||
|
split_remote_gguf("s3://bucket/repo/model:Q2_K")
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsGGUF:
|
||||||
|
"""Test is_gguf utility function."""
|
||||||
|
|
||||||
|
@patch("vllm.transformers_utils.utils.check_gguf_file", return_value=True)
|
||||||
|
def test_is_gguf_with_local_file(self, mock_check_gguf):
|
||||||
|
"""Test is_gguf with local GGUF file."""
|
||||||
|
assert is_gguf("/path/to/model.gguf")
|
||||||
|
assert is_gguf("./model.gguf")
|
||||||
|
|
||||||
|
def test_is_gguf_with_remote_gguf(self):
|
||||||
|
"""Test is_gguf with remote GGUF format."""
|
||||||
|
# Valid remote GGUF format (repo_id:quant_type with valid quant_type)
|
||||||
|
assert is_gguf("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
|
||||||
|
assert is_gguf("repo/model:Q2_K")
|
||||||
|
assert is_gguf("repo/model:Q4_K")
|
||||||
|
|
||||||
|
# Invalid quant_type should return False
|
||||||
|
assert not is_gguf("repo/model:quant")
|
||||||
|
assert not is_gguf("repo/model:INVALID")
|
||||||
|
|
||||||
|
@patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False)
|
||||||
|
def test_is_gguf_false(self, mock_check_gguf):
|
||||||
|
"""Test is_gguf returns False for non-GGUF models."""
|
||||||
|
assert not is_gguf("unsloth/Qwen3-0.6B")
|
||||||
|
assert not is_gguf("repo/model")
|
||||||
|
assert not is_gguf("model")
|
||||||
|
|
||||||
|
def test_is_gguf_edge_cases(self):
|
||||||
|
"""Test is_gguf with edge cases."""
|
||||||
|
# Empty string
|
||||||
|
assert not is_gguf("")
|
||||||
|
|
||||||
|
# Only colon, no slash (even with valid quant_type)
|
||||||
|
assert not is_gguf("model:IQ1_S")
|
||||||
|
|
||||||
|
# Only slash, no colon
|
||||||
|
assert not is_gguf("repo/model")
|
||||||
|
|
||||||
|
# HTTP/HTTPS URLs
|
||||||
|
assert not is_gguf("http://repo/model:IQ1_S")
|
||||||
|
assert not is_gguf("https://repo/model:Q2_K")
|
||||||
|
|
||||||
|
# Cloud storage
|
||||||
|
assert not is_gguf("s3://bucket/repo/model:IQ1_S")
|
||||||
|
assert not is_gguf("gs://bucket/repo/model:Q2_K")
|
||||||
|
|||||||
@ -166,7 +166,7 @@ def test_dict_args(parser):
|
|||||||
"--hf-overrides.key2.key4",
|
"--hf-overrides.key2.key4",
|
||||||
"val3",
|
"val3",
|
||||||
# Test compile config and compilation mode
|
# Test compile config and compilation mode
|
||||||
"-O.use_inductor=true",
|
"-O.use_inductor_graph_partition=true",
|
||||||
"-O.backend",
|
"-O.backend",
|
||||||
"custom",
|
"custom",
|
||||||
"-O1",
|
"-O1",
|
||||||
@ -219,7 +219,7 @@ def test_dict_args(parser):
|
|||||||
}
|
}
|
||||||
assert parsed_args.compilation_config == {
|
assert parsed_args.compilation_config == {
|
||||||
"mode": 1,
|
"mode": 1,
|
||||||
"use_inductor": True,
|
"use_inductor_graph_partition": True,
|
||||||
"backend": "custom",
|
"backend": "custom",
|
||||||
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
|
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1248,7 +1248,9 @@ def test_allocate_with_lookahead():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Test case 1: Requires additional lookahead tokens
|
# Test case 1: Requires additional lookahead tokens
|
||||||
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
|
kv_cache_manager = KVCacheManager(
|
||||||
|
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
|
||||||
|
)
|
||||||
blocks = kv_cache_manager.allocate_slots(
|
blocks = kv_cache_manager.allocate_slots(
|
||||||
request,
|
request,
|
||||||
num_new_tokens=3,
|
num_new_tokens=3,
|
||||||
@ -1257,7 +1259,9 @@ def test_allocate_with_lookahead():
|
|||||||
assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks
|
assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks
|
||||||
|
|
||||||
# Test case 2: With precomputed blocks
|
# Test case 2: With precomputed blocks
|
||||||
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
|
kv_cache_manager = KVCacheManager(
|
||||||
|
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
|
||||||
|
)
|
||||||
# required_blocks = ceil((3 + 2) /4) = 2
|
# required_blocks = ceil((3 + 2) /4) = 2
|
||||||
blocks = kv_cache_manager.allocate_slots(
|
blocks = kv_cache_manager.allocate_slots(
|
||||||
request,
|
request,
|
||||||
@ -1268,7 +1272,9 @@ def test_allocate_with_lookahead():
|
|||||||
|
|
||||||
# Test case 3: With precomputed blocks
|
# Test case 3: With precomputed blocks
|
||||||
# required_blocks = ceil((3 + 4) / 4) = 2
|
# required_blocks = ceil((3 + 4) / 4) = 2
|
||||||
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
|
kv_cache_manager = KVCacheManager(
|
||||||
|
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
|
||||||
|
)
|
||||||
blocks = kv_cache_manager.allocate_slots(
|
blocks = kv_cache_manager.allocate_slots(
|
||||||
request,
|
request,
|
||||||
num_new_tokens=3,
|
num_new_tokens=3,
|
||||||
@ -1495,7 +1501,8 @@ def test_get_kv_cache_config_one_worker():
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
# different hidden size
|
|
||||||
|
# different hidden size but same type, use UniformTypeKVCacheSpecs
|
||||||
kv_cache_specs_hybrid = {
|
kv_cache_specs_hybrid = {
|
||||||
"layer_1": new_kv_cache_spec(head_size=128),
|
"layer_1": new_kv_cache_spec(head_size=128),
|
||||||
"layer_2": new_kv_cache_spec(head_size=64),
|
"layer_2": new_kv_cache_spec(head_size=64),
|
||||||
@ -1519,6 +1526,40 @@ def test_get_kv_cache_config_one_worker():
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Different hidden size and different type, align by different block size
|
||||||
|
kv_cache_specs_hybrid = {
|
||||||
|
"layer_1": new_kv_cache_spec(head_size=64),
|
||||||
|
"layer_2": new_sliding_window_spec(head_size=32),
|
||||||
|
}
|
||||||
|
kv_cache_config_hybrid = get_kv_cache_configs(
|
||||||
|
vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 32]
|
||||||
|
)[0]
|
||||||
|
assert kv_cache_config_hybrid == KVCacheConfig(
|
||||||
|
num_blocks=32,
|
||||||
|
kv_cache_tensors=[
|
||||||
|
KVCacheTensor(
|
||||||
|
size=mem_per_block_per_layer * 32, shared_by=["layer_1", "layer_2"]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
kv_cache_groups=[
|
||||||
|
KVCacheGroupSpec(["layer_1"], new_kv_cache_spec(head_size=64)),
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer_2"], new_sliding_window_spec(head_size=32, block_size=32)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# different hidden size that cannot be aligned by using different block size
|
||||||
|
kv_cache_specs_hybrid = {
|
||||||
|
"layer_1": new_kv_cache_spec(head_size=64),
|
||||||
|
"layer_2": new_sliding_window_spec(head_size=96),
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
get_kv_cache_configs(
|
||||||
|
vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32]
|
||||||
|
)[0]
|
||||||
|
|
||||||
# Test num_gpu_blocks_override
|
# Test num_gpu_blocks_override
|
||||||
vllm_config.cache_config.num_gpu_blocks_override = 16
|
vllm_config.cache_config.num_gpu_blocks_override = 16
|
||||||
kv_cache_config_override_blocks = get_kv_cache_configs(
|
kv_cache_config_override_blocks = get_kv_cache_configs(
|
||||||
|
|||||||
@ -134,6 +134,7 @@ def test_prefill(hash_fn):
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Complete 3 blocks (48 tokens)
|
# Complete 3 blocks (48 tokens)
|
||||||
@ -256,6 +257,7 @@ def test_prefill_hybrid_model():
|
|||||||
make_kv_cache_config_hybrid_model(block_size, 21),
|
make_kv_cache_config_hybrid_model(block_size, 21),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
hash_fn = sha256
|
hash_fn = sha256
|
||||||
@ -416,6 +418,7 @@ def test_prefill_plp():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
# the default hash function is sha256
|
# the default hash function is sha256
|
||||||
hash_fn = sha256
|
hash_fn = sha256
|
||||||
@ -523,6 +526,7 @@ def test_decode():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Complete 3 blocks (48 tokens)
|
# Complete 3 blocks (48 tokens)
|
||||||
@ -585,6 +589,7 @@ def test_evict():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
last_token_id = 5 * 16 + 7
|
last_token_id = 5 * 16 + 7
|
||||||
@ -643,6 +648,7 @@ def test_hash_block_correct_reuse():
|
|||||||
make_kv_cache_config(16, 2),
|
make_kv_cache_config(16, 2),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Allocate 1 block and cache it.
|
# Allocate 1 block and cache it.
|
||||||
@ -683,6 +689,7 @@ def test_computed_blocks_not_evicted():
|
|||||||
make_kv_cache_config(block_size, 3),
|
make_kv_cache_config(block_size, 3),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Allocate a block and cache it.
|
# Allocate a block and cache it.
|
||||||
@ -741,6 +748,7 @@ def test_basic_prefix_caching_disabled():
|
|||||||
make_kv_cache_config(block_size, 5),
|
make_kv_cache_config(block_size, 5),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=False,
|
enable_caching=False,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
req1 = make_request(
|
req1 = make_request(
|
||||||
@ -790,6 +798,7 @@ def test_cache_blocks(hash_fn):
|
|||||||
block_pool = BlockPool(
|
block_pool = BlockPool(
|
||||||
num_gpu_blocks=5,
|
num_gpu_blocks=5,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
# Req:
|
# Req:
|
||||||
# Block 0: [0, 1, 2, 3]
|
# Block 0: [0, 1, 2, 3]
|
||||||
@ -833,7 +842,9 @@ def test_cache_blocks_multi_group():
|
|||||||
This tests that blocks are cached correctly for different kv cache groups.
|
This tests that blocks are cached correctly for different kv cache groups.
|
||||||
"""
|
"""
|
||||||
block_size = 4
|
block_size = 4
|
||||||
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
|
block_pool = BlockPool(
|
||||||
|
num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size
|
||||||
|
)
|
||||||
|
|
||||||
# Req:
|
# Req:
|
||||||
# Block 0/4: [0, 1, 2, 3]
|
# Block 0/4: [0, 1, 2, 3]
|
||||||
@ -921,6 +932,7 @@ def test_mm_prefix_caching():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
|
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
|
||||||
@ -1020,6 +1032,7 @@ def test_cache_key_salting():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3 complete blocks and an incomplete block with 11 tokens.
|
# 3 complete blocks and an incomplete block with 11 tokens.
|
||||||
@ -1101,6 +1114,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
# Complete 3 blocks (48 tokens)
|
# Complete 3 blocks (48 tokens)
|
||||||
# | Common-0 | Common-1 | Common-2 | ... |
|
# | Common-0 | Common-1 | Common-2 | ... |
|
||||||
@ -1173,6 +1187,7 @@ def test_reset_prefix_cache():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
full_block_token_ids = [i for i in range(3) for _ in range(16)]
|
full_block_token_ids = [i for i in range(3) for _ in range(16)]
|
||||||
@ -1213,6 +1228,7 @@ def test_prefix_cache_stats_disabled():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
log_stats=False, # Disable logging stats
|
log_stats=False, # Disable logging stats
|
||||||
)
|
)
|
||||||
assert manager.prefix_cache_stats is None
|
assert manager.prefix_cache_stats is None
|
||||||
@ -1232,7 +1248,7 @@ def test_prefix_cache_stats_disabled():
|
|||||||
|
|
||||||
|
|
||||||
def test_maybe_evict_cached_block():
|
def test_maybe_evict_cached_block():
|
||||||
pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
|
pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16)
|
||||||
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
|
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
|
||||||
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
|
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
|
||||||
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
|
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
|
||||||
@ -1293,6 +1309,7 @@ def test_kv_cache_events(blocks_to_cache: int):
|
|||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
enable_kv_cache_events=True,
|
enable_kv_cache_events=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_tokens = block_size * blocks_to_cache
|
num_tokens = block_size * blocks_to_cache
|
||||||
@ -1351,6 +1368,7 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
|
|||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
enable_kv_cache_events=True,
|
enable_kv_cache_events=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test with LoRA request
|
# Test with LoRA request
|
||||||
@ -1405,6 +1423,7 @@ def test_eagle_enabled_removes_last_block():
|
|||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
use_eagle=True,
|
use_eagle=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Request with 3 full blocks (48 tokens)
|
# Request with 3 full blocks (48 tokens)
|
||||||
@ -1437,6 +1456,7 @@ def test_eagle_with_partial_blocks():
|
|||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
use_eagle=True,
|
use_eagle=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
# 2 full blocks + 5 tokens (non-divisible length)
|
# 2 full blocks + 5 tokens (non-divisible length)
|
||||||
token_ids = [0] * (2 * block_size + 5)
|
token_ids = [0] * (2 * block_size + 5)
|
||||||
@ -1476,6 +1496,7 @@ def test_eagle_with_sliding_window():
|
|||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
use_eagle=True,
|
use_eagle=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2 full blocks + 5 tokens (non-divisible length)
|
# 2 full blocks + 5 tokens (non-divisible length)
|
||||||
@ -1522,6 +1543,76 @@ def test_eagle_with_sliding_window():
|
|||||||
assert num_tokens == 0
|
assert num_tokens == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_different_block_size():
|
||||||
|
block_size = 16
|
||||||
|
# full attention and sliding window attention layers have the same page size:
|
||||||
|
# (32 tokens/block * float16 token, vs. 16 tokens/block * float32 token)
|
||||||
|
kv_cache_config = KVCacheConfig(
|
||||||
|
num_blocks=100,
|
||||||
|
kv_cache_tensors=[],
|
||||||
|
kv_cache_groups=[
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer1"],
|
||||||
|
FullAttentionSpec(block_size * 2, 1, 1, torch.float16),
|
||||||
|
),
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer2"],
|
||||||
|
SlidingWindowSpec(
|
||||||
|
block_size,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
torch.float32,
|
||||||
|
sliding_window=2 * block_size,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
manager = KVCacheManager(
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
max_model_len=8192,
|
||||||
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 10 blocks of 16 tokens each. Token ids are not strictly aligned for each block.
|
||||||
|
common_token_ids = [i for i in range(10) for _ in range(block_size)]
|
||||||
|
|
||||||
|
req0 = make_request("0", common_token_ids, block_size, sha256)
|
||||||
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||||
|
assert not computed_blocks.blocks[0]
|
||||||
|
assert not computed_blocks.blocks[1]
|
||||||
|
assert num_computed_tokens == 0
|
||||||
|
blocks = manager.allocate_slots(
|
||||||
|
req0, 7 * block_size, len(computed_blocks.blocks[0]) * 16, computed_blocks
|
||||||
|
)
|
||||||
|
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11])
|
||||||
|
req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256)
|
||||||
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
|
assert len(computed_blocks.blocks[0]) == 3
|
||||||
|
assert len(computed_blocks.blocks[1]) == 6
|
||||||
|
assert num_computed_tokens == 6 * 16
|
||||||
|
|
||||||
|
req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256)
|
||||||
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||||
|
assert len(computed_blocks.blocks[0]) == 3
|
||||||
|
assert len(computed_blocks.blocks[1]) == 6
|
||||||
|
assert num_computed_tokens == 6 * 16
|
||||||
|
|
||||||
|
# Evict some blocks to make sliding window cache hit length 5*16
|
||||||
|
# But should return 4 * 16 because full attention cache hit length must be
|
||||||
|
# a multiple of 32
|
||||||
|
manager.block_pool.cached_block_hash_to_block.pop(
|
||||||
|
make_block_hash_with_group_id(req1.block_hashes[6], 1), 11
|
||||||
|
)
|
||||||
|
manager.block_pool.cached_block_hash_to_block.pop(
|
||||||
|
make_block_hash_with_group_id(req1.block_hashes[5], 1), 10
|
||||||
|
)
|
||||||
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
|
assert len(computed_blocks.blocks[0]) == 2
|
||||||
|
assert len(computed_blocks.blocks[1]) == 4
|
||||||
|
assert num_computed_tokens == 4 * 16
|
||||||
|
|
||||||
|
|
||||||
def test_block_lookup_cache_single_block_per_key():
|
def test_block_lookup_cache_single_block_per_key():
|
||||||
cache = BlockHashToBlockMap()
|
cache = BlockHashToBlockMap()
|
||||||
key0 = BlockHashWithGroupId(b"hash0")
|
key0 = BlockHashWithGroupId(b"hash0")
|
||||||
|
|||||||
@ -41,7 +41,9 @@ def test_chunked_local_attention_possible_cached_prefix():
|
|||||||
attention_chunk_size=4,
|
attention_chunk_size=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
block_pool = BlockPool(
|
||||||
|
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
|
||||||
|
)
|
||||||
manager = get_chunked_local_attention_manager(
|
manager = get_chunked_local_attention_manager(
|
||||||
chunked_local_attention_spec, block_pool
|
chunked_local_attention_spec, block_pool
|
||||||
)
|
)
|
||||||
@ -70,6 +72,7 @@ def test_chunked_local_attention_possible_cached_prefix():
|
|||||||
block_pool=block_pool,
|
block_pool=block_pool,
|
||||||
kv_cache_spec=chunked_local_attention_spec,
|
kv_cache_spec=chunked_local_attention_spec,
|
||||||
use_eagle=False,
|
use_eagle=False,
|
||||||
|
alignment_tokens=block_size,
|
||||||
)[0]
|
)[0]
|
||||||
assert len(computed_blocks) == expect_length
|
assert len(computed_blocks) == expect_length
|
||||||
|
|
||||||
@ -111,7 +114,9 @@ def test_sliding_window_possible_cached_prefix():
|
|||||||
sliding_window=4,
|
sliding_window=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
block_pool = BlockPool(
|
||||||
|
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
|
||||||
|
)
|
||||||
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
||||||
|
|
||||||
def run_one_case(block_is_cached, expect_length):
|
def run_one_case(block_is_cached, expect_length):
|
||||||
@ -138,6 +143,7 @@ def test_sliding_window_possible_cached_prefix():
|
|||||||
block_pool=block_pool,
|
block_pool=block_pool,
|
||||||
kv_cache_spec=sliding_window_spec,
|
kv_cache_spec=sliding_window_spec,
|
||||||
use_eagle=False,
|
use_eagle=False,
|
||||||
|
alignment_tokens=block_size,
|
||||||
)[0]
|
)[0]
|
||||||
assert len(computed_blocks) == expect_length
|
assert len(computed_blocks) == expect_length
|
||||||
|
|
||||||
@ -178,7 +184,7 @@ def test_chunked_local_attention_remove_skipped_blocks():
|
|||||||
attention_chunk_size=4,
|
attention_chunk_size=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
|
||||||
|
|
||||||
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
|
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
|
||||||
|
|
||||||
@ -239,7 +245,7 @@ def test_sliding_window_remove_skipped_blocks():
|
|||||||
sliding_window=4,
|
sliding_window=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
|
||||||
|
|
||||||
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
||||||
|
|
||||||
@ -316,7 +322,9 @@ def test_get_num_blocks_to_allocate():
|
|||||||
sliding_window=4, # Placeholder value, not related to test result
|
sliding_window=4, # Placeholder value, not related to test result
|
||||||
)
|
)
|
||||||
|
|
||||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
block_pool = BlockPool(
|
||||||
|
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
|
||||||
|
)
|
||||||
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
||||||
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
|
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
|
||||||
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
|
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
|
||||||
@ -341,7 +349,9 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
|
|||||||
attention_chunk_size=4, # Placeholder value, not related to test result
|
attention_chunk_size=4, # Placeholder value, not related to test result
|
||||||
)
|
)
|
||||||
|
|
||||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
block_pool = BlockPool(
|
||||||
|
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
|
||||||
|
)
|
||||||
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
|
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
|
||||||
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
|
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
|
||||||
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
|
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
|
||||||
|
|||||||
@ -87,6 +87,11 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
|
|||||||
# Set small draft model len to force doesn't-fit-in-drafter case.
|
# Set small draft model len to force doesn't-fit-in-drafter case.
|
||||||
spec_config_short = spec_config | {"max_model_len": 50}
|
spec_config_short = spec_config | {"max_model_len": 50}
|
||||||
|
|
||||||
|
test_sampling_params = [
|
||||||
|
dict(),
|
||||||
|
dict(logprobs=2),
|
||||||
|
]
|
||||||
|
|
||||||
# test_preemption, executor, async_scheduling,
|
# test_preemption, executor, async_scheduling,
|
||||||
# spec_config, test_prefill_chunking
|
# spec_config, test_prefill_chunking
|
||||||
test_configs = [
|
test_configs = [
|
||||||
@ -103,7 +108,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
|
|||||||
(True, "uni", True, spec_config_short, True),
|
(True, "uni", True, spec_config_short, True),
|
||||||
]
|
]
|
||||||
|
|
||||||
run_tests(monkeypatch, MTP_MODEL, test_configs, [{}])
|
run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
|
||||||
|
|
||||||
|
|
||||||
@dynamo_config.patch(cache_size_limit=16)
|
@dynamo_config.patch(cache_size_limit=16)
|
||||||
|
|||||||
@ -25,15 +25,6 @@ from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _get_max_encoder_len(vllm_config: "VllmConfig") -> int:
|
|
||||||
"""Gets the max number of encoder input tokens from the config."""
|
|
||||||
sc = vllm_config.scheduler_config
|
|
||||||
assert sc and isinstance(sc.max_num_encoder_input_tokens, int), (
|
|
||||||
"max_num_encoder_input_tokens must be int for enc-dec models"
|
|
||||||
)
|
|
||||||
return sc.max_num_encoder_input_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def _get_cross_slot_mapping(
|
def _get_cross_slot_mapping(
|
||||||
encoder_seq_lens: np.ndarray,
|
encoder_seq_lens: np.ndarray,
|
||||||
block_table_tensor: torch.Tensor,
|
block_table_tensor: torch.Tensor,
|
||||||
@ -93,23 +84,32 @@ def create_cross_attention_backend(
|
|||||||
) -> AttentionMetadata:
|
) -> AttentionMetadata:
|
||||||
new_metadata = copy(common_attn_metadata)
|
new_metadata = copy(common_attn_metadata)
|
||||||
new_metadata.causal = False
|
new_metadata.causal = False
|
||||||
max_encoder_len = _get_max_encoder_len(self.vllm_config)
|
max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
|
||||||
new_metadata.max_seq_len = max_encoder_len
|
new_metadata.max_seq_len = max_encoder_len
|
||||||
|
# Any computed tokens indicated decode step>1 (no chunked prefill)
|
||||||
|
num_cache_decodes = (
|
||||||
|
(common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
|
||||||
|
)
|
||||||
|
if num_cache_decodes > 0:
|
||||||
|
# CrossAttn KV cache has already been populated on first decoder step,
|
||||||
|
# skip slot_mapping calculation for requests that do not need
|
||||||
|
# reshape_and_cache.
|
||||||
|
num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
|
||||||
|
new_metadata.encoder_seq_lens_cpu = np.where(
|
||||||
|
num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
|
||||||
|
)
|
||||||
|
|
||||||
new_metadata.seq_lens = torch.full(
|
# seq_lens is provided by model runner: initial encoder input length is
|
||||||
(new_metadata.num_reqs,),
|
# needed here to know how many tokens to attend to from the cached
|
||||||
max_encoder_len,
|
# cross-attention KV cache.
|
||||||
dtype=torch.int32,
|
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
|
||||||
device=self.device,
|
new_metadata.seq_lens_cpu = torch.from_numpy(
|
||||||
)
|
common_attn_metadata.encoder_seq_lens_cpu
|
||||||
new_metadata.seq_lens_cpu = torch.full(
|
|
||||||
(new_metadata.num_reqs,),
|
|
||||||
max_encoder_len,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device="cpu",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
|
||||||
new_metadata.slot_mapping = _get_cross_slot_mapping(
|
new_metadata.slot_mapping = _get_cross_slot_mapping(
|
||||||
new_metadata.encoder_seq_lens,
|
new_metadata.encoder_seq_lens_cpu,
|
||||||
new_metadata.block_table_tensor,
|
new_metadata.block_table_tensor,
|
||||||
self.kv_cache_spec,
|
self.kv_cache_spec,
|
||||||
self.device,
|
self.device,
|
||||||
|
|||||||
@ -20,7 +20,11 @@ def merge_attn_states(
|
|||||||
num_query_heads = output.shape[1]
|
num_query_heads = output.shape[1]
|
||||||
head_size = output.shape[2]
|
head_size = output.shape[2]
|
||||||
padded_head_size = triton.next_power_of_2(head_size)
|
padded_head_size = triton.next_power_of_2(head_size)
|
||||||
|
# We assume the output stride on num_head is not always as same as the
|
||||||
|
# `suffix_output` and `prefix_output`, as them might be padded by the attention
|
||||||
|
# backend.
|
||||||
|
prefix_head_stride = prefix_output.stride(1)
|
||||||
|
output_head_stride = output.stride(1)
|
||||||
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
|
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
|
||||||
merge_attn_states_kernel[(num_tokens, num_query_heads)](
|
merge_attn_states_kernel[(num_tokens, num_query_heads)](
|
||||||
output,
|
output,
|
||||||
@ -29,6 +33,8 @@ def merge_attn_states(
|
|||||||
prefix_lse,
|
prefix_lse,
|
||||||
suffix_output,
|
suffix_output,
|
||||||
suffix_lse,
|
suffix_lse,
|
||||||
|
prefix_head_stride,
|
||||||
|
output_head_stride,
|
||||||
head_size,
|
head_size,
|
||||||
padded_head_size,
|
padded_head_size,
|
||||||
output_lse is not None,
|
output_lse is not None,
|
||||||
@ -43,6 +49,8 @@ def merge_attn_states_kernel(
|
|||||||
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
|
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||||
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||||
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
|
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||||
|
prefix_head_stride,
|
||||||
|
output_head_stride,
|
||||||
HEAD_SIZE: tl.constexpr,
|
HEAD_SIZE: tl.constexpr,
|
||||||
PADDED_HEAD_SIZE: tl.constexpr,
|
PADDED_HEAD_SIZE: tl.constexpr,
|
||||||
OUTPUT_LSE: tl.constexpr,
|
OUTPUT_LSE: tl.constexpr,
|
||||||
@ -79,15 +87,15 @@ def merge_attn_states_kernel(
|
|||||||
head_mask = head_arange < HEAD_SIZE
|
head_mask = head_arange < HEAD_SIZE
|
||||||
p_out = tl.load(
|
p_out = tl.load(
|
||||||
prefix_output
|
prefix_output
|
||||||
+ token_idx * num_heads * HEAD_SIZE
|
+ token_idx * num_heads * prefix_head_stride
|
||||||
+ head_idx * HEAD_SIZE
|
+ head_idx * prefix_head_stride
|
||||||
+ head_arange,
|
+ head_arange,
|
||||||
mask=head_mask,
|
mask=head_mask,
|
||||||
)
|
)
|
||||||
s_out = tl.load(
|
s_out = tl.load(
|
||||||
suffix_output
|
suffix_output
|
||||||
+ token_idx * num_heads * HEAD_SIZE
|
+ token_idx * num_heads * prefix_head_stride
|
||||||
+ head_idx * HEAD_SIZE
|
+ head_idx * prefix_head_stride
|
||||||
+ head_arange,
|
+ head_arange,
|
||||||
mask=head_mask,
|
mask=head_mask,
|
||||||
)
|
)
|
||||||
@ -99,7 +107,10 @@ def merge_attn_states_kernel(
|
|||||||
s_scale = s_se / out_se
|
s_scale = s_se / out_se
|
||||||
out = p_out * p_scale + s_out * s_scale
|
out = p_out * p_scale + s_out * s_scale
|
||||||
tl.store(
|
tl.store(
|
||||||
output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
|
output
|
||||||
|
+ token_idx * num_heads * output_head_stride
|
||||||
|
+ head_idx * output_head_stride
|
||||||
|
+ head_arange,
|
||||||
out,
|
out,
|
||||||
mask=head_mask,
|
mask=head_mask,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import pprint
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -429,7 +430,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
|||||||
self.vllm_backend.compiler_manager.compile(
|
self.vllm_backend.compiler_manager.compile(
|
||||||
submod,
|
submod,
|
||||||
args,
|
args,
|
||||||
self.compilation_config.inductor_compile_config,
|
self.vllm_backend.inductor_config,
|
||||||
self.compilation_config,
|
self.compilation_config,
|
||||||
graph_index=index,
|
graph_index=index,
|
||||||
num_graphs=len(self.compile_submod_names),
|
num_graphs=len(self.compile_submod_names),
|
||||||
@ -531,6 +532,9 @@ class VllmBackend:
|
|||||||
sym_tensor_indices: list[int]
|
sym_tensor_indices: list[int]
|
||||||
input_buffers: list[torch.Tensor]
|
input_buffers: list[torch.Tensor]
|
||||||
compiler_manager: CompilerManager
|
compiler_manager: CompilerManager
|
||||||
|
# Copy of CompilationConfig.inductor_compile_config +
|
||||||
|
# an entry for PostGradPassManager
|
||||||
|
inductor_config: dict[str, Any]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -561,25 +565,30 @@ class VllmBackend:
|
|||||||
self.compilation_config
|
self.compilation_config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Deepcopy the inductor config to detach the post-grad custom pass
|
||||||
|
# from CompilationConfig.
|
||||||
|
# We want to avoid PostGradPassManager in CompilationConfig because
|
||||||
|
# in future we need PostGradPassManager.uuid() to be executed
|
||||||
|
# only at compile time.
|
||||||
|
self.inductor_config = deepcopy(self.compilation_config.inductor_compile_config)
|
||||||
# `torch.compile` is JIT compiled, so we don't need to
|
# `torch.compile` is JIT compiled, so we don't need to
|
||||||
# do anything here
|
# do anything here
|
||||||
|
|
||||||
def configure_post_pass(self):
|
def configure_post_pass(self):
|
||||||
config = self.compilation_config
|
|
||||||
self.pass_manager.configure(self.vllm_config)
|
self.pass_manager.configure(self.vllm_config)
|
||||||
|
|
||||||
# Post-grad custom passes are run using the post_grad_custom_post_pass
|
# Post-grad custom passes are run using the post_grad_custom_post_pass
|
||||||
# hook. If a pass for that hook exists, add it to the pass manager.
|
# hook. If a pass for that hook exists, add it to the pass manager.
|
||||||
inductor_config = config.inductor_compile_config
|
if self.pass_key in self.inductor_config:
|
||||||
if self.pass_key in inductor_config:
|
if isinstance(self.inductor_config[self.pass_key], PostGradPassManager):
|
||||||
if isinstance(inductor_config[self.pass_key], PostGradPassManager):
|
raise ValueError(
|
||||||
# PassManager already added to config, make sure it's correct
|
"PostGradPassManager can not be kept in CompilationConfig."
|
||||||
assert inductor_config[self.pass_key].uuid() == self.pass_manager.uuid()
|
)
|
||||||
else:
|
else:
|
||||||
# Config should automatically wrap all inductor passes
|
# Config should automatically wrap all inductor passes
|
||||||
assert isinstance(inductor_config[self.pass_key], InductorPass)
|
assert isinstance(self.inductor_config[self.pass_key], InductorPass)
|
||||||
self.pass_manager.add(inductor_config[self.pass_key])
|
self.pass_manager.add(self.inductor_config[self.pass_key])
|
||||||
inductor_config[self.pass_key] = self.pass_manager
|
self.inductor_config[self.pass_key] = self.pass_manager
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, graph: fx.GraphModule, example_inputs
|
self, graph: fx.GraphModule, example_inputs
|
||||||
@ -638,9 +647,7 @@ class VllmBackend:
|
|||||||
self.compilation_config.local_cache_dir = local_cache_dir
|
self.compilation_config.local_cache_dir = local_cache_dir
|
||||||
|
|
||||||
# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
|
# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
|
||||||
disable_cache = not is_compile_cache_enabled(
|
disable_cache = not is_compile_cache_enabled(self.inductor_config)
|
||||||
self.compilation_config.inductor_compile_config
|
|
||||||
)
|
|
||||||
|
|
||||||
if disable_cache:
|
if disable_cache:
|
||||||
logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
|
logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@ -14,6 +13,7 @@ import vllm.envs as envs
|
|||||||
from vllm.config import VllmConfig, get_current_vllm_config
|
from vllm.config import VllmConfig, get_current_vllm_config
|
||||||
from vllm.config.utils import hash_factors
|
from vllm.config.utils import hash_factors
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch._dynamo.aot_compile import SerializableCallable
|
from torch._dynamo.aot_compile import SerializableCallable
|
||||||
@ -160,7 +160,7 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
|
|||||||
# e.g. exec(). We can't actually check these.
|
# e.g. exec(). We can't actually check these.
|
||||||
continue
|
continue
|
||||||
hash_content.append(content)
|
hash_content.append(content)
|
||||||
return hashlib.md5(
|
return safe_hash(
|
||||||
"\n".join(hash_content).encode(), usedforsecurity=False
|
"\n".join(hash_content).encode(), usedforsecurity=False
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,6 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import hashlib
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
@ -16,6 +15,7 @@ import torch.fx as fx
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
|
|
||||||
@ -197,9 +197,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
|||||||
|
|
||||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||||
factors = get_inductor_factors()
|
factors = get_inductor_factors()
|
||||||
hash_str = hashlib.md5(
|
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
|
||||||
str(factors).encode(), usedforsecurity=False
|
:10
|
||||||
).hexdigest()[:10]
|
]
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
def initialize_cache(
|
def initialize_cache(
|
||||||
@ -286,9 +286,9 @@ class InductorAdaptor(CompilerInterface):
|
|||||||
|
|
||||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||||
factors = get_inductor_factors()
|
factors = get_inductor_factors()
|
||||||
hash_str = hashlib.md5(
|
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
|
||||||
str(factors).encode(), usedforsecurity=False
|
:10
|
||||||
).hexdigest()[:10]
|
]
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
def initialize_cache(
|
def initialize_cache(
|
||||||
|
|||||||
@ -107,7 +107,7 @@ class PiecewiseBackend:
|
|||||||
entry.runnable = self.vllm_backend.compiler_manager.compile(
|
entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||||
self.graph,
|
self.graph,
|
||||||
args,
|
args,
|
||||||
self.compilation_config.inductor_compile_config,
|
self.vllm_backend.inductor_config,
|
||||||
self.compilation_config,
|
self.compilation_config,
|
||||||
graph_index=self.piecewise_compile_index,
|
graph_index=self.piecewise_compile_index,
|
||||||
num_graphs=self.total_piecewise_compiles,
|
num_graphs=self.total_piecewise_compiles,
|
||||||
|
|||||||
@ -167,8 +167,6 @@ class CacheConfig:
|
|||||||
"num_gpu_blocks_override",
|
"num_gpu_blocks_override",
|
||||||
"enable_prefix_caching",
|
"enable_prefix_caching",
|
||||||
"prefix_caching_hash_algo",
|
"prefix_caching_hash_algo",
|
||||||
# `cpu_offload_gb` does not use `torch.compile` yet.
|
|
||||||
"cpu_offload_gb",
|
|
||||||
"cpu_kvcache_space_bytes",
|
"cpu_kvcache_space_bytes",
|
||||||
"mamba_page_size_padded",
|
"mamba_page_size_padded",
|
||||||
# Post-init/derived counters
|
# Post-init/derived counters
|
||||||
|
|||||||
@ -264,7 +264,6 @@ class CompilationConfig:
|
|||||||
- [`cudagraph_copy_inputs`]
|
- [`cudagraph_copy_inputs`]
|
||||||
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
|
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
|
||||||
- Inductor compilation:
|
- Inductor compilation:
|
||||||
- [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
|
|
||||||
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
|
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
|
||||||
- [`inductor_compile_config`]
|
- [`inductor_compile_config`]
|
||||||
[vllm.config.CompilationConfig.inductor_compile_config]
|
[vllm.config.CompilationConfig.inductor_compile_config]
|
||||||
@ -348,7 +347,7 @@ class CompilationConfig:
|
|||||||
- 'none,+op1,+op2' to enable only op1 and op2
|
- 'none,+op1,+op2' to enable only op1 and op2
|
||||||
|
|
||||||
By default, all custom ops are enabled when running without Inductor and
|
By default, all custom ops are enabled when running without Inductor and
|
||||||
disabled when running with Inductor: mode>=VLLM_COMPILE and use_inductor=True.
|
disabled when running with Inductor: mode>=VLLM_COMPILE and backend="inductor".
|
||||||
Inductor generates (fused) Triton kernels for disabled custom ops."""
|
Inductor generates (fused) Triton kernels for disabled custom ops."""
|
||||||
splitting_ops: list[str] | None = None
|
splitting_ops: list[str] | None = None
|
||||||
"""A list of ops to exclude from cudagraphs, used in piecewise compilation.
|
"""A list of ops to exclude from cudagraphs, used in piecewise compilation.
|
||||||
@ -374,24 +373,6 @@ class CompilationConfig:
|
|||||||
Disabled by default until more models are supported/tested to work."""
|
Disabled by default until more models are supported/tested to work."""
|
||||||
|
|
||||||
# Inductor capture
|
# Inductor capture
|
||||||
use_inductor: bool | None = None
|
|
||||||
"""
|
|
||||||
Whether to use inductor compilation.
|
|
||||||
|
|
||||||
This flag is deprecated and will be removed in the next release 0.12.0.
|
|
||||||
Please use the 'backend' option instead.
|
|
||||||
|
|
||||||
- False: inductor compilation is not used. graph runs in eager
|
|
||||||
(custom_ops enabled by default).
|
|
||||||
- True: inductor compilation is used (custom_ops disabled by default).
|
|
||||||
One graph for symbolic shape and one graph per size in compile_sizes
|
|
||||||
are compiled using configurations in inductor_compile_config.
|
|
||||||
|
|
||||||
This setting is ignored if mode<VLLM_COMPILE.
|
|
||||||
|
|
||||||
For future compatibility:
|
|
||||||
If use_inductor is True, backend="inductor" otherwise backend="eager".
|
|
||||||
"""
|
|
||||||
compile_sizes: list[int | str] | None = None
|
compile_sizes: list[int | str] | None = None
|
||||||
"""Sizes to compile for inductor. In addition
|
"""Sizes to compile for inductor. In addition
|
||||||
to integers, it also supports "cudagraph_capture_sizes" to
|
to integers, it also supports "cudagraph_capture_sizes" to
|
||||||
@ -759,14 +740,6 @@ class CompilationConfig:
|
|||||||
f"Invalid backend for piecewise compilation: {self.backend}"
|
f"Invalid backend for piecewise compilation: {self.backend}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.use_inductor is not None:
|
|
||||||
logger.warning_once(
|
|
||||||
"The 'use_inductor' flag is deprecated and will be "
|
|
||||||
"removed in the next release (v0.12.0). "
|
|
||||||
"Please use the 'backend' option instead.",
|
|
||||||
)
|
|
||||||
self.backend = "inductor" if self.use_inductor else "eager"
|
|
||||||
|
|
||||||
if self.backend == "":
|
if self.backend == "":
|
||||||
self.backend = current_platform.get_compile_backend()
|
self.backend = current_platform.get_compile_backend()
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
@ -10,6 +9,7 @@ from pydantic import ConfigDict, SkipValidation
|
|||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
|
|
||||||
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ class DeviceConfig:
|
|||||||
# the device/platform information will be summarized
|
# the device/platform information will be summarized
|
||||||
# by torch/vllm automatically.
|
# by torch/vllm automatically.
|
||||||
factors: list[Any] = []
|
factors: list[Any] = []
|
||||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
from typing import Any, Literal, get_args
|
from typing import Any, Literal, get_args
|
||||||
@ -9,6 +8,7 @@ from typing import Any, Literal, get_args
|
|||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
|
|
||||||
KVProducer = Literal["kv_producer", "kv_both"]
|
KVProducer = Literal["kv_producer", "kv_both"]
|
||||||
KVConsumer = Literal["kv_consumer", "kv_both"]
|
KVConsumer = Literal["kv_consumer", "kv_both"]
|
||||||
@ -79,7 +79,7 @@ class KVTransferConfig:
|
|||||||
# no factors to consider.
|
# no factors to consider.
|
||||||
# this config will not affect the computation graph.
|
# this config will not affect the computation graph.
|
||||||
factors: list[Any] = []
|
factors: list[Any] = []
|
||||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass
|
|||||||
|
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.model_executor.model_loader import LoadFormats
|
from vllm.model_executor.model_loader import LoadFormats
|
||||||
@ -104,7 +104,7 @@ class LoadConfig:
|
|||||||
# no factors to consider.
|
# no factors to consider.
|
||||||
# this config will not affect the computation graph.
|
# this config will not affect the computation graph.
|
||||||
factors: list[Any] = []
|
factors: list[Any] = []
|
||||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
@field_validator("load_format", mode="after")
|
@field_validator("load_format", mode="after")
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -11,6 +10,7 @@ from typing_extensions import Self
|
|||||||
|
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
@ -74,7 +74,7 @@ class LoRAConfig:
|
|||||||
factors.append(self.fully_sharded_loras)
|
factors.append(self.fully_sharded_loras)
|
||||||
factors.append(self.lora_dtype)
|
factors.append(self.lora_dtype)
|
||||||
|
|
||||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
|
|||||||
@ -39,7 +39,12 @@ from vllm.transformers_utils.gguf_utils import (
|
|||||||
maybe_patch_hf_config_from_gguf,
|
maybe_patch_hf_config_from_gguf,
|
||||||
)
|
)
|
||||||
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
|
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
|
||||||
from vllm.transformers_utils.utils import check_gguf_file, maybe_model_redirect
|
from vllm.transformers_utils.utils import (
|
||||||
|
is_gguf,
|
||||||
|
is_remote_gguf,
|
||||||
|
maybe_model_redirect,
|
||||||
|
split_remote_gguf,
|
||||||
|
)
|
||||||
from vllm.utils.import_utils import LazyLoader
|
from vllm.utils.import_utils import LazyLoader
|
||||||
from vllm.utils.torch_utils import common_broadcastable_dtype
|
from vllm.utils.torch_utils import common_broadcastable_dtype
|
||||||
|
|
||||||
@ -294,9 +299,6 @@ class ModelConfig:
|
|||||||
pooler_config: PoolerConfig | None = None
|
pooler_config: PoolerConfig | None = None
|
||||||
"""Pooler config which controls the behaviour of output pooling in pooling
|
"""Pooler config which controls the behaviour of output pooling in pooling
|
||||||
models."""
|
models."""
|
||||||
override_pooler_config: dict | PoolerConfig | None = None
|
|
||||||
"""[DEPRECATED] Use `pooler_config` instead. This field will be removed in
|
|
||||||
v0.12.0 or v1.0.0, whichever is sooner."""
|
|
||||||
|
|
||||||
# Multimodal config and init vars
|
# Multimodal config and init vars
|
||||||
multimodal_config: MultiModalConfig | None = None
|
multimodal_config: MultiModalConfig | None = None
|
||||||
@ -343,7 +345,6 @@ class ModelConfig:
|
|||||||
"logprobs_mode",
|
"logprobs_mode",
|
||||||
"disable_cascade_attn",
|
"disable_cascade_attn",
|
||||||
"skip_tokenizer_init",
|
"skip_tokenizer_init",
|
||||||
"enable_prompt_embeds",
|
|
||||||
"served_model_name",
|
"served_model_name",
|
||||||
"config_format",
|
"config_format",
|
||||||
"hf_token",
|
"hf_token",
|
||||||
@ -354,7 +355,6 @@ class ModelConfig:
|
|||||||
"logits_processors",
|
"logits_processors",
|
||||||
"io_processor_plugin",
|
"io_processor_plugin",
|
||||||
"pooler_config",
|
"pooler_config",
|
||||||
"override_pooler_config",
|
|
||||||
"multimodal_config",
|
"multimodal_config",
|
||||||
"limit_mm_per_prompt",
|
"limit_mm_per_prompt",
|
||||||
"media_io_kwargs",
|
"media_io_kwargs",
|
||||||
@ -440,7 +440,8 @@ class ModelConfig:
|
|||||||
self.model = maybe_model_redirect(self.model)
|
self.model = maybe_model_redirect(self.model)
|
||||||
# The tokenizer is consistent with the model by default.
|
# The tokenizer is consistent with the model by default.
|
||||||
if self.tokenizer is None:
|
if self.tokenizer is None:
|
||||||
if check_gguf_file(self.model):
|
# Check if this is a GGUF model (either local file or remote GGUF)
|
||||||
|
if is_gguf(self.model):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Using a tokenizer is mandatory when loading a GGUF model. "
|
"Using a tokenizer is mandatory when loading a GGUF model. "
|
||||||
"Please specify the tokenizer path or name using the "
|
"Please specify the tokenizer path or name using the "
|
||||||
@ -642,18 +643,6 @@ class ModelConfig:
|
|||||||
|
|
||||||
# Init pooler config if needed
|
# Init pooler config if needed
|
||||||
if self.runner_type == "pooling":
|
if self.runner_type == "pooling":
|
||||||
if self.override_pooler_config is not None:
|
|
||||||
logger.warning_once(
|
|
||||||
"`override_pooler_config` is deprecated and will be "
|
|
||||||
"removed in v0.12.0 or v1.0.0, whichever is sooner. "
|
|
||||||
"Please use `pooler_config` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(self.override_pooler_config, dict):
|
|
||||||
self.pooler_config = PoolerConfig(**self.override_pooler_config)
|
|
||||||
else:
|
|
||||||
self.pooler_config = self.override_pooler_config
|
|
||||||
|
|
||||||
if self.pooler_config is None:
|
if self.pooler_config is None:
|
||||||
self.pooler_config = PoolerConfig()
|
self.pooler_config = PoolerConfig()
|
||||||
|
|
||||||
@ -832,7 +821,10 @@ class ModelConfig:
|
|||||||
self.tokenizer = object_storage_tokenizer.dir
|
self.tokenizer = object_storage_tokenizer.dir
|
||||||
|
|
||||||
def _get_encoder_config(self):
|
def _get_encoder_config(self):
|
||||||
return get_sentence_transformer_tokenizer_config(self.model, self.revision)
|
model = self.model
|
||||||
|
if is_remote_gguf(model):
|
||||||
|
model, _ = split_remote_gguf(model)
|
||||||
|
return get_sentence_transformer_tokenizer_config(model, self.revision)
|
||||||
|
|
||||||
def _verify_tokenizer_mode(self) -> None:
|
def _verify_tokenizer_mode(self) -> None:
|
||||||
tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
|
tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
|
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
|
||||||
|
|
||||||
@ -9,6 +8,7 @@ from pydantic import ConfigDict, Field, field_validator, model_validator
|
|||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
@ -216,7 +216,7 @@ class MultiModalConfig:
|
|||||||
if self.mm_encoder_attn_backend is not None
|
if self.mm_encoder_attn_backend is not None
|
||||||
else None
|
else None
|
||||||
]
|
]
|
||||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
def get_limit_per_prompt(self, modality: str) -> int:
|
def get_limit_per_prompt(self, modality: str) -> int:
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, Literal, cast
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
@ -11,6 +10,7 @@ from pydantic.dataclasses import dataclass
|
|||||||
|
|
||||||
from vllm import version
|
from vllm import version
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
|
|
||||||
DetailedTraceModules = Literal["model", "worker", "all"]
|
DetailedTraceModules = Literal["model", "worker", "all"]
|
||||||
|
|
||||||
@ -78,7 +78,7 @@ class ObservabilityConfig:
|
|||||||
# no factors to consider.
|
# no factors to consider.
|
||||||
# this config will not affect the computation graph.
|
# this config will not affect the computation graph.
|
||||||
factors: list[Any] = []
|
factors: list[Any] = []
|
||||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
@field_validator("show_hidden_metrics_for_version")
|
@field_validator("show_hidden_metrics_for_version")
|
||||||
|
|||||||
@ -593,9 +593,10 @@ class ParallelConfig:
|
|||||||
"max_parallel_loading_workers is currently "
|
"max_parallel_loading_workers is currently "
|
||||||
"not supported and will be ignored."
|
"not supported and will be ignored."
|
||||||
)
|
)
|
||||||
if self.distributed_executor_backend != "mp" and self.nnodes > 1:
|
if self.distributed_executor_backend not in ("mp", "uni") and self.nnodes > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"nnodes > 1 can only be set when distributed exectuor backend is mp."
|
"nnodes > 1 can only be set when distributed executor "
|
||||||
|
"backend is mp or uni."
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@ -1,13 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -102,7 +102,7 @@ class PoolerConfig:
|
|||||||
# no factors to consider.
|
# no factors to consider.
|
||||||
# this config will not affect the computation graph.
|
# this config will not affect the computation graph.
|
||||||
factors: list[Any] = []
|
factors: list[Any] = []
|
||||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import InitVar
|
from dataclasses import InitVar
|
||||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
|
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
|
||||||
@ -12,6 +11,7 @@ from typing_extensions import Self, deprecated
|
|||||||
|
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -178,7 +178,7 @@ class SchedulerConfig:
|
|||||||
# no factors to consider.
|
# no factors to consider.
|
||||||
# this config will not affect the computation graph.
|
# this config will not affect the computation graph.
|
||||||
factors: list[Any] = []
|
factors: list[Any] = []
|
||||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
@field_validator("scheduler_cls", "async_scheduling", mode="wrap")
|
@field_validator("scheduler_cls", "async_scheduling", mode="wrap")
|
||||||
|
|||||||
@ -2,7 +2,6 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
import hashlib
|
|
||||||
from typing import TYPE_CHECKING, Any, Literal, get_args
|
from typing import TYPE_CHECKING, Any, Literal, get_args
|
||||||
|
|
||||||
from pydantic import Field, SkipValidation, model_validator
|
from pydantic import Field, SkipValidation, model_validator
|
||||||
@ -13,6 +12,7 @@ from vllm.config.model import ModelConfig
|
|||||||
from vllm.config.parallel import ParallelConfig
|
from vllm.config.parallel import ParallelConfig
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
|
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -162,7 +162,7 @@ class SpeculativeConfig:
|
|||||||
# Eagle3 affects the computation graph because it returns intermediate
|
# Eagle3 affects the computation graph because it returns intermediate
|
||||||
# hidden states in addition to the final hidden state.
|
# hidden states in addition to the final hidden state.
|
||||||
factors.append(self.method == "eagle3")
|
factors.append(self.method == "eagle3")
|
||||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import model_validator
|
from pydantic import model_validator
|
||||||
@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass
|
|||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
|
|
||||||
StructuredOutputsBackend = Literal[
|
StructuredOutputsBackend = Literal[
|
||||||
"auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer"
|
"auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer"
|
||||||
@ -58,7 +58,7 @@ class StructuredOutputsConfig:
|
|||||||
# no factors to consider.
|
# no factors to consider.
|
||||||
# this config will not affect the computation graph.
|
# this config will not affect the computation graph.
|
||||||
factors: list[Any] = []
|
factors: list[Any] = []
|
||||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
|
|||||||
@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import getpass
|
import getpass
|
||||||
import hashlib
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
@ -25,6 +24,7 @@ from vllm.config.speculative import EagleModelTypes
|
|||||||
from vllm.logger import enable_trace_function_call, init_logger
|
from vllm.logger import enable_trace_function_call, init_logger
|
||||||
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
|
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
|
||||||
from vllm.utils import random_uuid
|
from vllm.utils import random_uuid
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
|
|
||||||
from .cache import CacheConfig
|
from .cache import CacheConfig
|
||||||
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
|
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
|
||||||
@ -193,7 +193,7 @@ class VllmConfig:
|
|||||||
vllm_factors.append("None")
|
vllm_factors.append("None")
|
||||||
if self.additional_config:
|
if self.additional_config:
|
||||||
if isinstance(additional_config := self.additional_config, dict):
|
if isinstance(additional_config := self.additional_config, dict):
|
||||||
additional_config_hash = hashlib.md5(
|
additional_config_hash = safe_hash(
|
||||||
json.dumps(additional_config, sort_keys=True).encode(),
|
json.dumps(additional_config, sort_keys=True).encode(),
|
||||||
usedforsecurity=False,
|
usedforsecurity=False,
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
@ -204,9 +204,9 @@ class VllmConfig:
|
|||||||
vllm_factors.append("None")
|
vllm_factors.append("None")
|
||||||
factors.append(vllm_factors)
|
factors.append(vllm_factors)
|
||||||
|
|
||||||
hash_str = hashlib.md5(
|
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
|
||||||
str(factors).encode(), usedforsecurity=False
|
:10
|
||||||
).hexdigest()[:10]
|
]
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
def pad_for_cudagraph(self, batch_size: int) -> int:
|
def pad_for_cudagraph(self, batch_size: int) -> int:
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import hashlib
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
@ -15,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
|||||||
KVConnectorRole,
|
KVConnectorRole,
|
||||||
)
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
@ -423,7 +423,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
|
|||||||
if mm_hashes:
|
if mm_hashes:
|
||||||
mm_str = "-".join(mm_hashes)
|
mm_str = "-".join(mm_hashes)
|
||||||
token_bytes += mm_str.encode("utf-8")
|
token_bytes += mm_str.encode("utf-8")
|
||||||
input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest()
|
input_ids_hash = safe_hash(token_bytes, usedforsecurity=False).hexdigest()
|
||||||
|
|
||||||
foldername = os.path.join(self._storage_path, input_ids_hash)
|
foldername = os.path.join(self._storage_path, input_ids_hash)
|
||||||
if create_folder:
|
if create_folder:
|
||||||
|
|||||||
@ -51,6 +51,7 @@ from vllm.distributed.utils import StatelessProcessGroup
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||||
from vllm.utils.network_utils import get_distributed_init_method
|
from vllm.utils.network_utils import get_distributed_init_method
|
||||||
|
from vllm.utils.system_utils import suppress_stdout
|
||||||
from vllm.utils.torch_utils import (
|
from vllm.utils.torch_utils import (
|
||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
supports_custom_op,
|
supports_custom_op,
|
||||||
@ -329,7 +330,8 @@ class GroupCoordinator:
|
|||||||
)
|
)
|
||||||
# a group with `gloo` backend, to allow direct coordination between
|
# a group with `gloo` backend, to allow direct coordination between
|
||||||
# processes through the CPU.
|
# processes through the CPU.
|
||||||
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
with suppress_stdout():
|
||||||
|
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
||||||
if self.rank in ranks:
|
if self.rank in ranks:
|
||||||
self.ranks = ranks
|
self.ranks = ranks
|
||||||
self.world_size = len(ranks)
|
self.world_size = len(ranks)
|
||||||
|
|||||||
@ -30,6 +30,7 @@ from torch.distributed.rendezvous import rendezvous
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils.network_utils import get_tcp_uri
|
from vllm.utils.network_utils import get_tcp_uri
|
||||||
|
from vllm.utils.system_utils import suppress_stdout
|
||||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -427,33 +428,34 @@ def init_gloo_process_group(
|
|||||||
Stateless init ProcessGroup with gloo backend compatible with
|
Stateless init ProcessGroup with gloo backend compatible with
|
||||||
different torch versions.
|
different torch versions.
|
||||||
"""
|
"""
|
||||||
if is_torch_equal_or_newer("2.6"):
|
with suppress_stdout():
|
||||||
pg = ProcessGroup(
|
if is_torch_equal_or_newer("2.6"):
|
||||||
prefix_store,
|
pg = ProcessGroup(
|
||||||
group_rank,
|
prefix_store,
|
||||||
group_size,
|
group_rank,
|
||||||
)
|
group_size,
|
||||||
else:
|
)
|
||||||
options = ProcessGroup.Options(backend="gloo")
|
else:
|
||||||
pg = ProcessGroup(
|
options = ProcessGroup.Options(backend="gloo")
|
||||||
prefix_store,
|
pg = ProcessGroup(
|
||||||
group_rank,
|
prefix_store,
|
||||||
group_size,
|
group_rank,
|
||||||
options,
|
group_size,
|
||||||
)
|
options,
|
||||||
from torch.distributed.distributed_c10d import ProcessGroupGloo
|
)
|
||||||
|
from torch.distributed.distributed_c10d import ProcessGroupGloo
|
||||||
|
|
||||||
backend_class = ProcessGroupGloo(
|
backend_class = ProcessGroupGloo(
|
||||||
prefix_store, group_rank, group_size, timeout=timeout
|
prefix_store, group_rank, group_size, timeout=timeout
|
||||||
)
|
)
|
||||||
backend_type = ProcessGroup.BackendType.GLOO
|
backend_type = ProcessGroup.BackendType.GLOO
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if is_torch_equal_or_newer("2.6"):
|
if is_torch_equal_or_newer("2.6"):
|
||||||
# _set_default_backend is supported in torch >= 2.6
|
# _set_default_backend is supported in torch >= 2.6
|
||||||
pg._set_default_backend(backend_type)
|
pg._set_default_backend(backend_type)
|
||||||
backend_class._set_sequence_number_for_group()
|
backend_class._set_sequence_number_for_group()
|
||||||
|
|
||||||
pg._register_backend(device, backend_type, backend_class)
|
pg._register_backend(device, backend_type, backend_class)
|
||||||
return pg
|
return pg
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -29,7 +29,7 @@ import regex as re
|
|||||||
import torch
|
import torch
|
||||||
from pydantic import TypeAdapter, ValidationError
|
from pydantic import TypeAdapter, ValidationError
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo
|
||||||
from typing_extensions import TypeIs, deprecated
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
@ -77,7 +77,7 @@ from vllm.config.observability import DetailedTraceModules
|
|||||||
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
|
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
|
||||||
from vllm.config.scheduler import SchedulerPolicy
|
from vllm.config.scheduler import SchedulerPolicy
|
||||||
from vllm.config.utils import get_field
|
from vllm.config.utils import get_field
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger, suppress_logging
|
||||||
from vllm.platforms import CpuArchEnum, current_platform
|
from vllm.platforms import CpuArchEnum, current_platform
|
||||||
from vllm.plugins import load_general_plugins
|
from vllm.plugins import load_general_plugins
|
||||||
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
|
from vllm.ray.lazy_utils import is_in_ray_actor, is_ray_initialized
|
||||||
@ -86,7 +86,7 @@ from vllm.transformers_utils.config import (
|
|||||||
is_interleaved,
|
is_interleaved,
|
||||||
maybe_override_with_speculators,
|
maybe_override_with_speculators,
|
||||||
)
|
)
|
||||||
from vllm.transformers_utils.utils import check_gguf_file, is_cloud_storage
|
from vllm.transformers_utils.utils import is_cloud_storage, is_gguf
|
||||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
from vllm.utils.mem_constants import GiB_bytes
|
from vllm.utils.mem_constants import GiB_bytes
|
||||||
from vllm.utils.network_utils import get_ip
|
from vllm.utils.network_utils import get_ip
|
||||||
@ -247,11 +247,13 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
|
|||||||
default = field.default
|
default = field.default
|
||||||
# Handle pydantic.Field defaults
|
# Handle pydantic.Field defaults
|
||||||
if isinstance(default, FieldInfo):
|
if isinstance(default, FieldInfo):
|
||||||
default = (
|
if default.default_factory is None:
|
||||||
default.default
|
default = default.default
|
||||||
if default.default_factory is None
|
else:
|
||||||
else default.default_factory()
|
# VllmConfig's Fields have default_factory set to config classes.
|
||||||
)
|
# These could emit logs on init, which would be confusing.
|
||||||
|
with suppress_logging():
|
||||||
|
default = default.default_factory()
|
||||||
elif field.default_factory is not MISSING:
|
elif field.default_factory is not MISSING:
|
||||||
default = field.default_factory()
|
default = field.default_factory()
|
||||||
|
|
||||||
@ -518,9 +520,6 @@ class EngineArgs:
|
|||||||
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
|
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
|
||||||
|
|
||||||
pooler_config: PoolerConfig | None = ModelConfig.pooler_config
|
pooler_config: PoolerConfig | None = ModelConfig.pooler_config
|
||||||
override_pooler_config: dict | PoolerConfig | None = (
|
|
||||||
ModelConfig.override_pooler_config
|
|
||||||
)
|
|
||||||
compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
|
compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config")
|
||||||
worker_cls: str = ParallelConfig.worker_cls
|
worker_cls: str = ParallelConfig.worker_cls
|
||||||
worker_extension_cls: str = ParallelConfig.worker_extension_cls
|
worker_extension_cls: str = ParallelConfig.worker_extension_cls
|
||||||
@ -657,11 +656,6 @@ class EngineArgs:
|
|||||||
)
|
)
|
||||||
model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
|
model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"])
|
||||||
model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
|
model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"])
|
||||||
model_group.add_argument(
|
|
||||||
"--override-pooler-config",
|
|
||||||
**model_kwargs["override_pooler_config"],
|
|
||||||
deprecated=True,
|
|
||||||
)
|
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
|
"--logits-processor-pattern", **model_kwargs["logits_processor_pattern"]
|
||||||
)
|
)
|
||||||
@ -878,7 +872,11 @@ class EngineArgs:
|
|||||||
"--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
|
"--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"]
|
||||||
)
|
)
|
||||||
cache_group.add_argument(
|
cache_group.add_argument(
|
||||||
"--enable-prefix-caching", **cache_kwargs["enable_prefix_caching"]
|
"--enable-prefix-caching",
|
||||||
|
**{
|
||||||
|
**cache_kwargs["enable_prefix_caching"],
|
||||||
|
"default": None,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
cache_group.add_argument(
|
cache_group.add_argument(
|
||||||
"--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
|
"--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
|
||||||
@ -1142,8 +1140,8 @@ class EngineArgs:
|
|||||||
return engine_args
|
return engine_args
|
||||||
|
|
||||||
def create_model_config(self) -> ModelConfig:
|
def create_model_config(self) -> ModelConfig:
|
||||||
# gguf file needs a specific model loader and doesn't use hf_repo
|
# gguf file needs a specific model loader
|
||||||
if check_gguf_file(self.model):
|
if is_gguf(self.model):
|
||||||
self.quantization = self.load_format = "gguf"
|
self.quantization = self.load_format = "gguf"
|
||||||
|
|
||||||
# NOTE(woosuk): In V1, we use separate processes for workers (unless
|
# NOTE(woosuk): In V1, we use separate processes for workers (unless
|
||||||
@ -1237,7 +1235,6 @@ class EngineArgs:
|
|||||||
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
|
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
|
||||||
mm_encoder_attn_backend=self.mm_encoder_attn_backend,
|
mm_encoder_attn_backend=self.mm_encoder_attn_backend,
|
||||||
pooler_config=self.pooler_config,
|
pooler_config=self.pooler_config,
|
||||||
override_pooler_config=self.override_pooler_config,
|
|
||||||
logits_processor_pattern=self.logits_processor_pattern,
|
logits_processor_pattern=self.logits_processor_pattern,
|
||||||
generation_config=self.generation_config,
|
generation_config=self.generation_config,
|
||||||
override_generation_config=self.override_generation_config,
|
override_generation_config=self.override_generation_config,
|
||||||
@ -1810,9 +1807,11 @@ class EngineArgs:
|
|||||||
if model_config.runner_type != "pooling":
|
if model_config.runner_type != "pooling":
|
||||||
default_chunked_prefill = True
|
default_chunked_prefill = True
|
||||||
|
|
||||||
# Disable prefix caching default for hybrid models
|
# Disable prefix caching default for hybrid models and mamba-only
|
||||||
# since the feature is still experimental.
|
# models since the feature is still experimental.
|
||||||
default_prefix_caching = not model_config.is_hybrid
|
default_prefix_caching = not (
|
||||||
|
model_config.is_hybrid or model_config.is_attention_free
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
assert model_config.pooler_config is not None
|
assert model_config.pooler_config is not None
|
||||||
|
|
||||||
@ -2039,24 +2038,6 @@ class AsyncEngineArgs(EngineArgs):
|
|||||||
|
|
||||||
enable_log_requests: bool = False
|
enable_log_requests: bool = False
|
||||||
|
|
||||||
@property
|
|
||||||
@deprecated(
|
|
||||||
"`disable_log_requests` is deprecated and has been replaced with "
|
|
||||||
"`enable_log_requests`. This will be removed in v0.12.0. Please use "
|
|
||||||
"`enable_log_requests` instead."
|
|
||||||
)
|
|
||||||
def disable_log_requests(self) -> bool:
|
|
||||||
return not self.enable_log_requests
|
|
||||||
|
|
||||||
@disable_log_requests.setter
|
|
||||||
@deprecated(
|
|
||||||
"`disable_log_requests` is deprecated and has been replaced with "
|
|
||||||
"`enable_log_requests`. This will be removed in v0.12.0. Please use "
|
|
||||||
"`enable_log_requests` instead."
|
|
||||||
)
|
|
||||||
def disable_log_requests(self, value: bool):
|
|
||||||
self.enable_log_requests = not value
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(
|
def add_cli_args(
|
||||||
parser: FlexibleArgumentParser, async_args_only: bool = False
|
parser: FlexibleArgumentParser, async_args_only: bool = False
|
||||||
|
|||||||
@ -174,9 +174,6 @@ class LLM:
|
|||||||
For example, for Phi-3-Vision: `{"num_crops": 4}`.
|
For example, for Phi-3-Vision: `{"num_crops": 4}`.
|
||||||
pooler_config: Initialize non-default pooling config for the pooling
|
pooler_config: Initialize non-default pooling config for the pooling
|
||||||
model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
|
model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
|
||||||
override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
|
|
||||||
argument is deprecated and will be removed in v0.12.0 or v1.0.0,
|
|
||||||
whichever is sooner.
|
|
||||||
compilation_config: Either an integer or a dictionary. If it is an
|
compilation_config: Either an integer or a dictionary. If it is an
|
||||||
integer, it is used as the mode of compilation optimization. If it
|
integer, it is used as the mode of compilation optimization. If it
|
||||||
is a dictionary, it can specify the full compilation configuration.
|
is a dictionary, it can specify the full compilation configuration.
|
||||||
@ -214,7 +211,6 @@ class LLM:
|
|||||||
hf_overrides: HfOverrides | None = None,
|
hf_overrides: HfOverrides | None = None,
|
||||||
mm_processor_kwargs: dict[str, Any] | None = None,
|
mm_processor_kwargs: dict[str, Any] | None = None,
|
||||||
pooler_config: PoolerConfig | None = None,
|
pooler_config: PoolerConfig | None = None,
|
||||||
override_pooler_config: PoolerConfig | None = None,
|
|
||||||
structured_outputs_config: dict[str, Any]
|
structured_outputs_config: dict[str, Any]
|
||||||
| StructuredOutputsConfig
|
| StructuredOutputsConfig
|
||||||
| None = None,
|
| None = None,
|
||||||
@ -330,7 +326,6 @@ class LLM:
|
|||||||
hf_overrides=hf_overrides,
|
hf_overrides=hf_overrides,
|
||||||
mm_processor_kwargs=mm_processor_kwargs,
|
mm_processor_kwargs=mm_processor_kwargs,
|
||||||
pooler_config=pooler_config,
|
pooler_config=pooler_config,
|
||||||
override_pooler_config=override_pooler_config,
|
|
||||||
structured_outputs_config=structured_outputs_instance,
|
structured_outputs_config=structured_outputs_instance,
|
||||||
compilation_config=compilation_config_instance,
|
compilation_config=compilation_config_instance,
|
||||||
logits_processors=logits_processors,
|
logits_processors=logits_processors,
|
||||||
|
|||||||
@ -29,7 +29,6 @@ from openai.types.responses import (
|
|||||||
ResponseOutputItemAddedEvent,
|
ResponseOutputItemAddedEvent,
|
||||||
ResponseOutputItemDoneEvent,
|
ResponseOutputItemDoneEvent,
|
||||||
ResponsePrompt,
|
ResponsePrompt,
|
||||||
ResponseReasoningItem,
|
|
||||||
ResponseReasoningTextDeltaEvent,
|
ResponseReasoningTextDeltaEvent,
|
||||||
ResponseReasoningTextDoneEvent,
|
ResponseReasoningTextDoneEvent,
|
||||||
ResponseStatus,
|
ResponseStatus,
|
||||||
@ -304,9 +303,7 @@ def get_logits_processors(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
ResponseInputOutputItem: TypeAlias = (
|
ResponseInputOutputItem: TypeAlias = ResponseInputItemParam | ResponseOutputItem
|
||||||
ResponseInputItemParam | ResponseReasoningItem | ResponseFunctionToolCall
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ResponsesRequest(OpenAIBaseModel):
|
class ResponsesRequest(OpenAIBaseModel):
|
||||||
@ -559,9 +556,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
) = "none"
|
) = "none"
|
||||||
reasoning_effort: Literal["low", "medium", "high"] | None = None
|
reasoning_effort: Literal["low", "medium", "high"] | None = None
|
||||||
include_reasoning: bool = True
|
include_reasoning: bool = True
|
||||||
|
parallel_tool_calls: bool | None = True
|
||||||
|
|
||||||
# NOTE this will be ignored by vLLM -- the model determines the behavior
|
# NOTE this will be ignored by vLLM
|
||||||
parallel_tool_calls: bool | None = False
|
|
||||||
user: str | None = None
|
user: str | None = None
|
||||||
|
|
||||||
# --8<-- [start:chat-completion-sampling-params]
|
# --8<-- [start:chat-completion-sampling-params]
|
||||||
|
|||||||
@ -55,6 +55,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_l
|
|||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
|
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
|
||||||
|
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
|
||||||
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -1206,6 +1207,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
|
|
||||||
finish_reason_sent[i] = True
|
finish_reason_sent[i] = True
|
||||||
|
|
||||||
|
choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
|
||||||
chunk = ChatCompletionStreamResponse(
|
chunk = ChatCompletionStreamResponse(
|
||||||
id=request_id,
|
id=request_id,
|
||||||
object=chunk_object_type,
|
object=chunk_object_type,
|
||||||
@ -1531,6 +1533,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
as_list(output.token_ids) if request.return_token_ids else None
|
as_list(output.token_ids) if request.return_token_ids else None
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
choice_data = maybe_filter_parallel_tool_calls(choice_data, request)
|
||||||
|
|
||||||
choices.append(choice_data)
|
choices.append(choice_data)
|
||||||
|
|
||||||
|
|||||||
@ -296,11 +296,7 @@ class OpenAIServing:
|
|||||||
parser = None
|
parser = None
|
||||||
if not enable_auto_tools or tool_parser_name is None:
|
if not enable_auto_tools or tool_parser_name is None:
|
||||||
return parser
|
return parser
|
||||||
logger.info(
|
logger.info('"auto" tool choice has been enabled.')
|
||||||
'"auto" tool choice has been enabled please note that while'
|
|
||||||
" the parallel_tool_calls client option is preset for "
|
|
||||||
"compatibility reasons, it will be ignored."
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if tool_parser_name == "pythonic" and self.model_config.model.startswith(
|
if tool_parser_name == "pythonic" and self.model_config.model.startswith(
|
||||||
|
|||||||
@ -94,7 +94,7 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
from vllm.entrypoints.responses_utils import (
|
from vllm.entrypoints.responses_utils import (
|
||||||
construct_chat_message_with_tool_call,
|
construct_input_messages,
|
||||||
convert_tool_responses_to_completions_format,
|
convert_tool_responses_to_completions_format,
|
||||||
extract_tool_types,
|
extract_tool_types,
|
||||||
)
|
)
|
||||||
@ -504,7 +504,12 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
for tool in request.tools
|
for tool in request.tools
|
||||||
]
|
]
|
||||||
# Construct the input messages.
|
# Construct the input messages.
|
||||||
messages = self._construct_input_messages(request, prev_response)
|
messages = construct_input_messages(
|
||||||
|
request_instructions=request.instructions,
|
||||||
|
request_input=request.input,
|
||||||
|
prev_msg=self.msg_store.get(prev_response.id) if prev_response else None,
|
||||||
|
prev_response_output=prev_response.output if prev_response else None,
|
||||||
|
)
|
||||||
_, request_prompts, engine_prompts = await self._preprocess_chat(
|
_, request_prompts, engine_prompts = await self._preprocess_chat(
|
||||||
request,
|
request,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@ -869,47 +874,6 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
output_items.extend(last_items)
|
output_items.extend(last_items)
|
||||||
return output_items
|
return output_items
|
||||||
|
|
||||||
def _construct_input_messages(
|
|
||||||
self,
|
|
||||||
request: ResponsesRequest,
|
|
||||||
prev_response: ResponsesResponse | None = None,
|
|
||||||
) -> list[ChatCompletionMessageParam]:
|
|
||||||
messages: list[ChatCompletionMessageParam] = []
|
|
||||||
if request.instructions:
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": request.instructions,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepend the conversation history.
|
|
||||||
if prev_response is not None:
|
|
||||||
# Add the previous messages.
|
|
||||||
prev_msg = self.msg_store[prev_response.id]
|
|
||||||
messages.extend(prev_msg)
|
|
||||||
|
|
||||||
# Add the previous output.
|
|
||||||
for output_item in prev_response.output:
|
|
||||||
# NOTE: We skip the reasoning output.
|
|
||||||
if isinstance(output_item, ResponseOutputMessage):
|
|
||||||
for content in output_item.content:
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": content.text,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Append the new input.
|
|
||||||
# Responses API supports simple text inputs without chat format.
|
|
||||||
if isinstance(request.input, str):
|
|
||||||
messages.append({"role": "user", "content": request.input})
|
|
||||||
else:
|
|
||||||
for item in request.input:
|
|
||||||
messages.append(construct_chat_message_with_tool_call(item))
|
|
||||||
return messages
|
|
||||||
|
|
||||||
def _construct_harmony_system_input_message(
|
def _construct_harmony_system_input_message(
|
||||||
self, request: ResponsesRequest, with_custom_tools: bool, tool_types: set[str]
|
self, request: ResponsesRequest, with_custom_tools: bool, tool_types: set[str]
|
||||||
) -> OpenAIHarmonyMessage:
|
) -> OpenAIHarmonyMessage:
|
||||||
|
|||||||
37
vllm/entrypoints/openai/utils.py
Normal file
37
vllm/entrypoints/openai/utils.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import (
|
||||||
|
ChatCompletionRequest,
|
||||||
|
ChatCompletionResponseChoice,
|
||||||
|
ChatCompletionResponseStreamChoice,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Used internally
|
||||||
|
_ChatCompletionResponseChoiceT = TypeVar(
|
||||||
|
"_ChatCompletionResponseChoiceT",
|
||||||
|
ChatCompletionResponseChoice,
|
||||||
|
ChatCompletionResponseStreamChoice,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_filter_parallel_tool_calls(
|
||||||
|
choice: _ChatCompletionResponseChoiceT, request: ChatCompletionRequest
|
||||||
|
) -> _ChatCompletionResponseChoiceT:
|
||||||
|
"""Filter to first tool call only when parallel_tool_calls is False."""
|
||||||
|
|
||||||
|
if request.parallel_tool_calls:
|
||||||
|
return choice
|
||||||
|
|
||||||
|
if isinstance(choice, ChatCompletionResponseChoice) and choice.message.tool_calls:
|
||||||
|
choice.message.tool_calls = choice.message.tool_calls[:1]
|
||||||
|
elif (
|
||||||
|
isinstance(choice, ChatCompletionResponseStreamChoice)
|
||||||
|
and choice.delta.tool_calls
|
||||||
|
):
|
||||||
|
choice.delta.tool_calls = [
|
||||||
|
tool_call for tool_call in choice.delta.tool_calls if tool_call.index == 0
|
||||||
|
]
|
||||||
|
|
||||||
|
return choice
|
||||||
@ -9,7 +9,11 @@ from openai.types.chat import (
|
|||||||
from openai.types.chat.chat_completion_message_tool_call_param import (
|
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||||
Function as FunctionCallTool,
|
Function as FunctionCallTool,
|
||||||
)
|
)
|
||||||
from openai.types.responses import ResponseFunctionToolCall
|
from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
|
||||||
|
from openai.types.responses.response_function_tool_call_output_item import (
|
||||||
|
ResponseFunctionToolCallOutputItem,
|
||||||
|
)
|
||||||
|
from openai.types.responses.response_output_message import ResponseOutputMessage
|
||||||
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
||||||
from openai.types.responses.tool import Tool
|
from openai.types.responses.tool import Tool
|
||||||
|
|
||||||
@ -20,6 +24,49 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def construct_input_messages(
|
||||||
|
*,
|
||||||
|
request_instructions: str | None = None,
|
||||||
|
request_input: str | list[ResponseInputOutputItem],
|
||||||
|
prev_msg: list[ChatCompletionMessageParam] | None = None,
|
||||||
|
prev_response_output: list[ResponseOutputItem] | None = None,
|
||||||
|
):
|
||||||
|
messages: list[ChatCompletionMessageParam] = []
|
||||||
|
if request_instructions:
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": request_instructions,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepend the conversation history.
|
||||||
|
if prev_msg is not None:
|
||||||
|
# Add the previous messages.
|
||||||
|
messages.extend(prev_msg)
|
||||||
|
if prev_response_output is not None:
|
||||||
|
# Add the previous output.
|
||||||
|
for output_item in prev_response_output:
|
||||||
|
# NOTE: We skip the reasoning output.
|
||||||
|
if isinstance(output_item, ResponseOutputMessage):
|
||||||
|
for content in output_item.content:
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": content.text,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Append the new input.
|
||||||
|
# Responses API supports simple text inputs without chat format.
|
||||||
|
if isinstance(request_input, str):
|
||||||
|
messages.append({"role": "user", "content": request_input})
|
||||||
|
else:
|
||||||
|
for item in request_input:
|
||||||
|
messages.append(construct_chat_message_with_tool_call(item))
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def construct_chat_message_with_tool_call(
|
def construct_chat_message_with_tool_call(
|
||||||
item: ResponseInputOutputItem,
|
item: ResponseInputOutputItem,
|
||||||
) -> ChatCompletionMessageParam:
|
) -> ChatCompletionMessageParam:
|
||||||
@ -50,6 +97,12 @@ def construct_chat_message_with_tool_call(
|
|||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"reasoning": reasoning_content,
|
"reasoning": reasoning_content,
|
||||||
}
|
}
|
||||||
|
elif isinstance(item, ResponseFunctionToolCallOutputItem):
|
||||||
|
return ChatCompletionToolMessageParam(
|
||||||
|
role="tool",
|
||||||
|
content=item.output,
|
||||||
|
tool_call_id=item.call_id,
|
||||||
|
)
|
||||||
elif item.get("type") == "function_call_output":
|
elif item.get("type") == "function_call_output":
|
||||||
# Append the function call output as a tool message.
|
# Append the function call output as a tool message.
|
||||||
return ChatCompletionToolMessageParam(
|
return ChatCompletionToolMessageParam(
|
||||||
|
|||||||
@ -7,7 +7,8 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import Hashable
|
from collections.abc import Generator, Hashable
|
||||||
|
from contextlib import contextmanager
|
||||||
from functools import lru_cache, partial
|
from functools import lru_cache, partial
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from logging.config import dictConfig
|
from logging.config import dictConfig
|
||||||
@ -212,6 +213,14 @@ def init_logger(name: str) -> _VllmLogger:
|
|||||||
return cast(_VllmLogger, logger)
|
return cast(_VllmLogger, logger)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def suppress_logging(level: int = logging.INFO) -> Generator[None, Any, None]:
|
||||||
|
current_level = logging.root.manager.disable
|
||||||
|
logging.disable(level)
|
||||||
|
yield
|
||||||
|
logging.disable(current_level)
|
||||||
|
|
||||||
|
|
||||||
# The root logger is initialized when the module is imported.
|
# The root logger is initialized when the module is imported.
|
||||||
# This is thread-safe as the module is only imported once,
|
# This is thread-safe as the module is only imported once,
|
||||||
# guaranteed by the Python GIL.
|
# guaranteed by the Python GIL.
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
@ -12,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
|||||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
TopKWeightAndReduceNoOP,
|
TopKWeightAndReduceNoOP,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils.import_utils import has_triton_kernels
|
from vllm.utils.import_utils import has_triton_kernels
|
||||||
|
|
||||||
@ -88,14 +90,17 @@ def triton_kernel_moe_forward(
|
|||||||
gating_output, topk, sm_first=not renormalize
|
gating_output, topk, sm_first=not renormalize
|
||||||
)
|
)
|
||||||
|
|
||||||
|
output = torch.empty_like(hidden_states)
|
||||||
|
|
||||||
return triton_kernel_fused_experts(
|
return triton_kernel_fused_experts(
|
||||||
None,
|
output,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
routing_data,
|
routing_data,
|
||||||
gather_idx,
|
gather_idx,
|
||||||
scatter_idx,
|
scatter_idx,
|
||||||
|
topk=topk,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
@ -113,6 +118,7 @@ def triton_kernel_fused_experts(
|
|||||||
routing_data, # RoutingData
|
routing_data, # RoutingData
|
||||||
gather_indx, # GatherIndx
|
gather_indx, # GatherIndx
|
||||||
scatter_indx, # ScatterIndx
|
scatter_indx, # ScatterIndx
|
||||||
|
topk: int,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
quant_config: FusedMoEQuantConfig | None = None,
|
quant_config: FusedMoEQuantConfig | None = None,
|
||||||
swiglu_alpha: float = 1.702,
|
swiglu_alpha: float = 1.702,
|
||||||
@ -120,6 +126,7 @@ def triton_kernel_fused_experts(
|
|||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
expert_map: torch.Tensor | None = None,
|
expert_map: torch.Tensor | None = None,
|
||||||
|
intermediate_cache: torch.Tensor | None = None,
|
||||||
a1q_scale: torch.Tensor | None = None,
|
a1q_scale: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
@ -131,14 +138,30 @@ def triton_kernel_fused_experts(
|
|||||||
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
|
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
|
||||||
|
|
||||||
# Shape check, only check non-mxfp4
|
# Shape check, only check non-mxfp4
|
||||||
|
assert hidden_states.ndim == 2
|
||||||
assert hidden_states.shape[-1] == w1.shape[-2]
|
assert hidden_states.shape[-1] == w1.shape[-2]
|
||||||
assert w2.shape[-1] == w1.shape[1]
|
assert w2.shape[-1] == w1.shape[1]
|
||||||
|
|
||||||
|
batch_dim = 1
|
||||||
|
M, K = hidden_states.shape[-2:]
|
||||||
E, _, N = w1.shape
|
E, _, N = w1.shape
|
||||||
|
|
||||||
if global_num_experts == -1:
|
if global_num_experts == -1:
|
||||||
global_num_experts = E
|
global_num_experts = E
|
||||||
|
|
||||||
|
if intermediate_cache is None:
|
||||||
|
intermediate_cache = torch.empty(
|
||||||
|
(batch_dim, M * topk, N // 2),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add batch_dim to output buffer because matmul_ogs expects 3D output
|
||||||
|
intermediate_cache = _resize_cache(
|
||||||
|
intermediate_cache, (batch_dim, M * topk, N // 2)
|
||||||
|
)
|
||||||
|
output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
|
||||||
|
|
||||||
act = FusedActivation(
|
act = FusedActivation(
|
||||||
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
|
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
|
||||||
(swiglu_alpha, swiglu_limit),
|
(swiglu_alpha, swiglu_limit),
|
||||||
@ -146,7 +169,7 @@ def triton_kernel_fused_experts(
|
|||||||
)
|
)
|
||||||
gammas = routing_data.gate_scal if routing_data else None
|
gammas = routing_data.gate_scal if routing_data else None
|
||||||
|
|
||||||
intermediate_cache1 = matmul_ogs(
|
matmul_ogs(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
w1,
|
w1,
|
||||||
quant_config.w1_bias,
|
quant_config.w1_bias,
|
||||||
@ -155,10 +178,11 @@ def triton_kernel_fused_experts(
|
|||||||
precision_config=quant_config.w1_precision,
|
precision_config=quant_config.w1_precision,
|
||||||
gammas=gammas if apply_router_weight_on_input else None,
|
gammas=gammas if apply_router_weight_on_input else None,
|
||||||
fused_activation=act,
|
fused_activation=act,
|
||||||
|
y=intermediate_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
intermediate_cache3 = matmul_ogs(
|
matmul_ogs(
|
||||||
intermediate_cache1,
|
intermediate_cache.view(M * topk, N // 2),
|
||||||
w2,
|
w2,
|
||||||
quant_config.w2_bias,
|
quant_config.w2_bias,
|
||||||
routing_data,
|
routing_data,
|
||||||
@ -167,7 +191,8 @@ def triton_kernel_fused_experts(
|
|||||||
gammas=None if apply_router_weight_on_input else gammas,
|
gammas=None if apply_router_weight_on_input else gammas,
|
||||||
y=output_tensor,
|
y=output_tensor,
|
||||||
)
|
)
|
||||||
return intermediate_cache3
|
output_tensor = output_tensor.view(M, K)
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
def make_routing_data(
|
def make_routing_data(
|
||||||
@ -221,6 +246,42 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
def supports_expert_map(self) -> bool:
|
def supports_expert_map(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def moe_problem_size(
|
||||||
|
self,
|
||||||
|
a1: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
) -> tuple[int, int, int, int, int]:
|
||||||
|
"""
|
||||||
|
Extract the MoE problem size from the given tensor arguments:
|
||||||
|
- a: The hidden states, input to the MoE layer.
|
||||||
|
- w1: The first set of expert weights.
|
||||||
|
- w2: The second set of expert weights.
|
||||||
|
- topk_ids: The topk ids.
|
||||||
|
Note: extracting the problem shape from the weight and activation
|
||||||
|
tensors is not obvious. It needs to be done this way specifically
|
||||||
|
due to subtle issues with particular kernels, e.g. the int4 kernels
|
||||||
|
divide the trailing dimension by two, so it's not "correct" to
|
||||||
|
extract N or K from the trailing dimension of w1 or w2. Similarly,
|
||||||
|
some kernels transpose the weights, so this needs to be kept in mind.
|
||||||
|
Note: This implementation covers most cases. However, if experts
|
||||||
|
require a specialized implementation, like MarlinExperts, they are free
|
||||||
|
to override this function.
|
||||||
|
"""
|
||||||
|
assert w1.dim() == 3 and w2.dim() == 3
|
||||||
|
E, _, N = w1.size()
|
||||||
|
K = a1.size(-1)
|
||||||
|
|
||||||
|
assert a1.dim() == 2
|
||||||
|
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
|
||||||
|
M = a1.size(0)
|
||||||
|
|
||||||
|
assert topk_ids.dim() == 2
|
||||||
|
topk = topk_ids.size(1)
|
||||||
|
|
||||||
|
return E, M, N, K, topk
|
||||||
|
|
||||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||||
# Weight application and reduction happens in the fused_experts kernel.
|
# Weight application and reduction happens in the fused_experts kernel.
|
||||||
return TopKWeightAndReduceNoOP()
|
return TopKWeightAndReduceNoOP()
|
||||||
@ -263,8 +324,8 @@ class OAITritonExperts(BaseOAITritonExperts):
|
|||||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||||
# workspace are allocated inside the kernel
|
# workspace are allocated inside the kernel
|
||||||
workspace1 = (M, K)
|
workspace1 = (0, 0)
|
||||||
workspace2 = (0, 0)
|
workspace2 = (M * topk, N // 2)
|
||||||
output = (M, K)
|
output = (M, K)
|
||||||
return (workspace1, workspace2, output)
|
return (workspace1, workspace2, output)
|
||||||
|
|
||||||
@ -297,20 +358,21 @@ class OAITritonExperts(BaseOAITritonExperts):
|
|||||||
topk_ids, topk_weights, local_num_experts
|
topk_ids, topk_weights, local_num_experts
|
||||||
)
|
)
|
||||||
|
|
||||||
experts_output = triton_kernel_fused_experts(
|
topk = topk_ids.size(1)
|
||||||
None,
|
triton_kernel_fused_experts(
|
||||||
|
output,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
routing_data,
|
routing_data,
|
||||||
gather_indx,
|
gather_indx,
|
||||||
scatter_indx,
|
scatter_indx,
|
||||||
|
topk=topk,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
quant_config=self.quant_config,
|
quant_config=self.quant_config,
|
||||||
apply_router_weight_on_input=False,
|
apply_router_weight_on_input=False,
|
||||||
global_num_experts=local_num_experts,
|
global_num_experts=local_num_experts,
|
||||||
expert_map=None, # applied already
|
expert_map=None, # applied already
|
||||||
|
intermediate_cache=workspace2,
|
||||||
a1q_scale=a1q_scale,
|
a1q_scale=a1q_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
output.copy_(experts_output, non_blocking=True)
|
|
||||||
|
|||||||
@ -103,7 +103,7 @@ __all__ = [
|
|||||||
"CompressedTensorsW8A8Int8MoEMethod",
|
"CompressedTensorsW8A8Int8MoEMethod",
|
||||||
"CompressedTensorsWNA16MarlinMoEMethod",
|
"CompressedTensorsWNA16MarlinMoEMethod",
|
||||||
"CompressedTensorsWNA16MoEMethod",
|
"CompressedTensorsWNA16MoEMethod",
|
||||||
"CompressedTensorsW4A4MoeMethod",
|
"CompressedTensorsW4A4Nvfp4MoeMethod",
|
||||||
"CompressedTensorsW4A8Int8MoEMethod",
|
"CompressedTensorsW4A8Int8MoEMethod",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -171,7 +171,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
|||||||
quant_config, layer.moe_config
|
quant_config, layer.moe_config
|
||||||
)
|
)
|
||||||
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
|
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||||
return CompressedTensorsW4A4MoeMethod(layer.moe_config)
|
return CompressedTensorsW4A4Nvfp4MoeMethod(layer.moe_config)
|
||||||
elif (
|
elif (
|
||||||
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
|
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
|
||||||
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
|
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
|
||||||
@ -188,7 +188,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
|
||||||
def __init__(self, moe: FusedMoEConfig):
|
def __init__(self, moe: FusedMoEConfig):
|
||||||
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
|
||||||
detect_nvfp4_moe_support,
|
detect_nvfp4_moe_support,
|
||||||
@ -205,8 +205,12 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
|
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
|
||||||
" for CompressedTensorsW4A4MoeMethod."
|
" for CompressedTensorsW4A4Nvfp4MoeMethod."
|
||||||
)
|
)
|
||||||
|
elif self.use_marlin:
|
||||||
|
logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoeMethod.")
|
||||||
|
else:
|
||||||
|
logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoeMethod.")
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@ -612,7 +616,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
assert expert_map is None, (
|
assert expert_map is None, (
|
||||||
"Expert Parallelism / expert_map "
|
"Expert Parallelism / expert_map "
|
||||||
"is currently not supported for "
|
"is currently not supported for "
|
||||||
"CompressedTensorsW4A4MoeMethod."
|
"CompressedTensorsW4A4Nvfp4MoeMethod."
|
||||||
)
|
)
|
||||||
assert self.moe_quant_config is not None
|
assert self.moe_quant_config is not None
|
||||||
|
|
||||||
|
|||||||
@ -1132,6 +1132,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
|
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
|
||||||
" for ModelOptNvFp4FusedMoE."
|
" for ModelOptNvFp4FusedMoE."
|
||||||
)
|
)
|
||||||
|
elif self.use_marlin:
|
||||||
|
logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.")
|
||||||
|
else:
|
||||||
|
logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.")
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(
|
def maybe_make_prepare_finalize(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -196,9 +196,10 @@ class Mxfp4Config(QuantizationConfig):
|
|||||||
# TODO: Add support for MXFP4 Linear Method.
|
# TODO: Add support for MXFP4 Linear Method.
|
||||||
# MXFP4 LinearMethod is available in AMD-Quark, refer to that implementation
|
# MXFP4 LinearMethod is available in AMD-Quark, refer to that implementation
|
||||||
# if you are interested in enabling MXFP4 here.
|
# if you are interested in enabling MXFP4 here.
|
||||||
logger.warning_once(
|
logger.debug_once(
|
||||||
"MXFP4 linear layer is not implemented - falling back to "
|
"MXFP4 linear layer is not implemented - falling back to "
|
||||||
"UnquantizedLinearMethod."
|
"UnquantizedLinearMethod.",
|
||||||
|
scope="local",
|
||||||
)
|
)
|
||||||
return UnquantizedLinearMethod()
|
return UnquantizedLinearMethod()
|
||||||
elif isinstance(layer, FusedMoE):
|
elif isinstance(layer, FusedMoE):
|
||||||
@ -208,9 +209,10 @@ class Mxfp4Config(QuantizationConfig):
|
|||||||
return Mxfp4MoEMethod(layer.moe_config)
|
return Mxfp4MoEMethod(layer.moe_config)
|
||||||
elif isinstance(layer, Attention):
|
elif isinstance(layer, Attention):
|
||||||
# TODO: Add support for MXFP4 Attention.
|
# TODO: Add support for MXFP4 Attention.
|
||||||
logger.warning_once(
|
logger.debug_once(
|
||||||
"MXFP4 attention layer is not implemented. "
|
"MXFP4 attention layer is not implemented. "
|
||||||
"Skipping quantization for this layer."
|
"Skipping quantization for this layer.",
|
||||||
|
scope="local",
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from vllm.model_executor.model_loader.utils import (
|
|||||||
process_weights_after_loading,
|
process_weights_after_loading,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
download_gguf,
|
||||||
get_gguf_extra_tensor_names,
|
get_gguf_extra_tensor_names,
|
||||||
get_gguf_weight_type_map,
|
get_gguf_weight_type_map,
|
||||||
gguf_quant_weights_iterator,
|
gguf_quant_weights_iterator,
|
||||||
@ -43,7 +44,8 @@ class GGUFModelLoader(BaseModelLoader):
|
|||||||
f"load format {load_config.load_format}"
|
f"load format {load_config.load_format}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _prepare_weights(self, model_name_or_path: str):
|
def _prepare_weights(self, model_config: ModelConfig):
|
||||||
|
model_name_or_path = model_config.model
|
||||||
if os.path.isfile(model_name_or_path):
|
if os.path.isfile(model_name_or_path):
|
||||||
return model_name_or_path
|
return model_name_or_path
|
||||||
# for raw HTTPS link
|
# for raw HTTPS link
|
||||||
@ -55,12 +57,23 @@ class GGUFModelLoader(BaseModelLoader):
|
|||||||
if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
|
if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
|
||||||
repo_id, filename = model_name_or_path.rsplit("/", 1)
|
repo_id, filename = model_name_or_path.rsplit("/", 1)
|
||||||
return hf_hub_download(repo_id=repo_id, filename=filename)
|
return hf_hub_download(repo_id=repo_id, filename=filename)
|
||||||
else:
|
# repo_id:quant_type
|
||||||
raise ValueError(
|
elif "/" in model_name_or_path and ":" in model_name_or_path:
|
||||||
f"Unrecognised GGUF reference: {model_name_or_path} "
|
repo_id, quant_type = model_name_or_path.rsplit(":", 1)
|
||||||
"(expected local file, raw URL, or <repo_id>/<filename>.gguf)"
|
return download_gguf(
|
||||||
|
repo_id,
|
||||||
|
quant_type,
|
||||||
|
cache_dir=self.load_config.download_dir,
|
||||||
|
revision=model_config.revision,
|
||||||
|
ignore_patterns=self.load_config.ignore_patterns,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Unrecognised GGUF reference: {model_name_or_path} "
|
||||||
|
"(expected local file, raw URL, <repo_id>/<filename>.gguf, "
|
||||||
|
"or <repo_id>:<quant_type>)"
|
||||||
|
)
|
||||||
|
|
||||||
def _get_gguf_weights_map(self, model_config: ModelConfig):
|
def _get_gguf_weights_map(self, model_config: ModelConfig):
|
||||||
"""
|
"""
|
||||||
GGUF uses this naming convention for their tensors from HF checkpoint:
|
GGUF uses this naming convention for their tensors from HF checkpoint:
|
||||||
@ -244,7 +257,7 @@ class GGUFModelLoader(BaseModelLoader):
|
|||||||
gguf_to_hf_name_map: dict[str, str],
|
gguf_to_hf_name_map: dict[str, str],
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
weight_type_map = get_gguf_weight_type_map(
|
weight_type_map = get_gguf_weight_type_map(
|
||||||
model_config.model, gguf_to_hf_name_map
|
model_name_or_path, gguf_to_hf_name_map
|
||||||
)
|
)
|
||||||
is_multimodal = hasattr(model_config.hf_config, "vision_config")
|
is_multimodal = hasattr(model_config.hf_config, "vision_config")
|
||||||
if is_multimodal:
|
if is_multimodal:
|
||||||
@ -290,10 +303,10 @@ class GGUFModelLoader(BaseModelLoader):
|
|||||||
yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
|
yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
|
||||||
|
|
||||||
def download_model(self, model_config: ModelConfig) -> None:
|
def download_model(self, model_config: ModelConfig) -> None:
|
||||||
self._prepare_weights(model_config.model)
|
self._prepare_weights(model_config)
|
||||||
|
|
||||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||||
local_model_path = self._prepare_weights(model_config.model)
|
local_model_path = self._prepare_weights(model_config)
|
||||||
gguf_weights_map = self._get_gguf_weights_map(model_config)
|
gguf_weights_map = self._get_gguf_weights_map(model_config)
|
||||||
model.load_weights(
|
model.load_weights(
|
||||||
self._get_weights_iterator(model_config, local_model_path, gguf_weights_map)
|
self._get_weights_iterator(model_config, local_model_path, gguf_weights_map)
|
||||||
@ -303,7 +316,7 @@ class GGUFModelLoader(BaseModelLoader):
|
|||||||
self, vllm_config: VllmConfig, model_config: ModelConfig
|
self, vllm_config: VllmConfig, model_config: ModelConfig
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
device_config = vllm_config.device_config
|
device_config = vllm_config.device_config
|
||||||
local_model_path = self._prepare_weights(model_config.model)
|
local_model_path = self._prepare_weights(model_config)
|
||||||
gguf_weights_map = self._get_gguf_weights_map(model_config)
|
gguf_weights_map = self._get_gguf_weights_map(model_config)
|
||||||
# we can only know if tie word embeddings after mapping weights
|
# we can only know if tie word embeddings after mapping weights
|
||||||
if "lm_head.weight" in get_gguf_extra_tensor_names(
|
if "lm_head.weight" in get_gguf_extra_tensor_names(
|
||||||
|
|||||||
@ -369,6 +369,52 @@ def get_sparse_attention_config(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def download_gguf(
|
||||||
|
repo_id: str,
|
||||||
|
quant_type: str,
|
||||||
|
cache_dir: str | None = None,
|
||||||
|
revision: str | None = None,
|
||||||
|
ignore_patterns: str | list[str] | None = None,
|
||||||
|
) -> str:
|
||||||
|
# Use patterns that snapshot_download can handle directly
|
||||||
|
# Patterns to match:
|
||||||
|
# - *-{quant_type}.gguf (root)
|
||||||
|
# - *-{quant_type}-*.gguf (root sharded)
|
||||||
|
# - */*-{quant_type}.gguf (subdir)
|
||||||
|
# - */*-{quant_type}-*.gguf (subdir sharded)
|
||||||
|
allow_patterns = [
|
||||||
|
f"*-{quant_type}.gguf",
|
||||||
|
f"*-{quant_type}-*.gguf",
|
||||||
|
f"*/*-{quant_type}.gguf",
|
||||||
|
f"*/*-{quant_type}-*.gguf",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Use download_weights_from_hf which handles caching and downloading
|
||||||
|
folder = download_weights_from_hf(
|
||||||
|
model_name_or_path=repo_id,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
revision=revision,
|
||||||
|
ignore_patterns=ignore_patterns,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find the downloaded file(s) in the folder
|
||||||
|
local_files = []
|
||||||
|
for pattern in allow_patterns:
|
||||||
|
# Convert pattern to glob pattern for local filesystem
|
||||||
|
glob_pattern = os.path.join(folder, pattern)
|
||||||
|
local_files.extend(glob.glob(glob_pattern))
|
||||||
|
|
||||||
|
if not local_files:
|
||||||
|
raise ValueError(
|
||||||
|
f"Downloaded GGUF files not found in {folder} for quant_type {quant_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sort to ensure consistent ordering (prefer non-sharded files)
|
||||||
|
local_files.sort(key=lambda x: (x.count("-"), x))
|
||||||
|
return local_files[0]
|
||||||
|
|
||||||
|
|
||||||
def download_weights_from_hf(
|
def download_weights_from_hf(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: str | None,
|
cache_dir: str | None,
|
||||||
|
|||||||
@ -233,7 +233,7 @@ class BaiChuanDecoderLayer(nn.Module):
|
|||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
position_embedding=position_embedding,
|
position_embedding=position_embedding,
|
||||||
rope_parameters=config.rope_parameters,
|
rope_parameters=getattr(config, "rope_parameters", None),
|
||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
|||||||
@ -289,9 +289,6 @@ class MambaModelConfig(VerifyAndUpdateConfig):
|
|||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
|
|
||||||
if cache_config.mamba_block_size is None:
|
|
||||||
cache_config.mamba_block_size = model_config.max_model_len
|
|
||||||
|
|
||||||
if cache_config.enable_prefix_caching:
|
if cache_config.enable_prefix_caching:
|
||||||
if model_config.supports_mamba_prefix_caching:
|
if model_config.supports_mamba_prefix_caching:
|
||||||
logger.info(
|
logger.info(
|
||||||
@ -299,6 +296,11 @@ class MambaModelConfig(VerifyAndUpdateConfig):
|
|||||||
"Its support for Mamba layers is experimental. "
|
"Its support for Mamba layers is experimental. "
|
||||||
"Please report any issues you may observe."
|
"Please report any issues you may observe."
|
||||||
)
|
)
|
||||||
|
# By default, mamba block size will be set to max_model_len (see
|
||||||
|
# below). When enabling prefix caching, we align mamba block size
|
||||||
|
# to the block size as the basic granularity for prefix caching.
|
||||||
|
if cache_config.mamba_block_size is None:
|
||||||
|
cache_config.mamba_block_size = cache_config.block_size
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Hybrid or mamba-based model detected without "
|
"Hybrid or mamba-based model detected without "
|
||||||
@ -306,6 +308,9 @@ class MambaModelConfig(VerifyAndUpdateConfig):
|
|||||||
)
|
)
|
||||||
cache_config.enable_prefix_caching = False
|
cache_config.enable_prefix_caching = False
|
||||||
|
|
||||||
|
if cache_config.mamba_block_size is None:
|
||||||
|
cache_config.mamba_block_size = model_config.max_model_len
|
||||||
|
|
||||||
# TODO(tdoublep): remove once cascade attention is supported
|
# TODO(tdoublep): remove once cascade attention is supported
|
||||||
logger.info(
|
logger.info(
|
||||||
"Disabling cascade attention since it is not supported for hybrid models."
|
"Disabling cascade attention since it is not supported for hybrid models."
|
||||||
|
|||||||
@ -100,7 +100,7 @@ class GPTJAttention(nn.Module):
|
|||||||
self.head_size,
|
self.head_size,
|
||||||
rotary_dim=config.rotary_dim,
|
rotary_dim=config.rotary_dim,
|
||||||
max_position=max_position_embeddings,
|
max_position=max_position_embeddings,
|
||||||
rope_parameters=config.rope_parameters,
|
rope_parameters=getattr(config, "rope_parameters", None),
|
||||||
is_neox_style=False,
|
is_neox_style=False,
|
||||||
)
|
)
|
||||||
self.attn = Attention(
|
self.attn = Attention(
|
||||||
|
|||||||
@ -239,7 +239,7 @@ class Grok1DecoderLayer(nn.Module):
|
|||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
max_position=config.max_position_embeddings,
|
max_position=config.max_position_embeddings,
|
||||||
num_kv_heads=config.num_key_value_heads,
|
num_kv_heads=config.num_key_value_heads,
|
||||||
rope_parameters=config.rope_parameters,
|
rope_parameters=getattr(config, "rope_parameters", None),
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
|
|||||||
@ -262,7 +262,7 @@ class LlamaAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
rotary_dim=self.head_dim,
|
rotary_dim=self.head_dim,
|
||||||
max_position=self.max_position_embeddings,
|
max_position=self.max_position_embeddings,
|
||||||
rope_parameters=config.rope_parameters,
|
rope_parameters=getattr(config, "rope_parameters", None),
|
||||||
is_neox_style=is_neox_style,
|
is_neox_style=is_neox_style,
|
||||||
partial_rotary_factor=self.partial_rotary_factor,
|
partial_rotary_factor=self.partial_rotary_factor,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -5,7 +5,6 @@ Whenever you add an architecture to this page, please also update
|
|||||||
`tests/models/registry.py` with example HuggingFace models for it.
|
`tests/models/registry.py` with example HuggingFace models for it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@ -32,6 +31,7 @@ from vllm.config import (
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.logging_utils import logtime
|
from vllm.logging_utils import logtime
|
||||||
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
|
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
|
||||||
|
from vllm.utils.hashing import safe_hash
|
||||||
|
|
||||||
from .interfaces import (
|
from .interfaces import (
|
||||||
has_inner_state,
|
has_inner_state,
|
||||||
@ -654,7 +654,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
|
|||||||
|
|
||||||
if model_path.exists():
|
if model_path.exists():
|
||||||
with open(model_path, "rb") as f:
|
with open(model_path, "rb") as f:
|
||||||
module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest()
|
module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()
|
||||||
|
|
||||||
mi = self._load_modelinfo_from_cache(module_hash)
|
mi = self._load_modelinfo_from_cache(module_hash)
|
||||||
if mi is not None:
|
if mi is not None:
|
||||||
|
|||||||
@ -407,9 +407,6 @@ class CudaPlatformBase(Platform):
|
|||||||
|
|
||||||
# We have found some valid backends. Select the one with the
|
# We have found some valid backends. Select the one with the
|
||||||
# highest priority.
|
# highest priority.
|
||||||
logger.info(
|
|
||||||
"Valid backends: %s", [b[0].name for b in valid_backends_priorities]
|
|
||||||
)
|
|
||||||
sorted_indices = sorted(
|
sorted_indices = sorted(
|
||||||
range(len(valid_backends_priorities)),
|
range(len(valid_backends_priorities)),
|
||||||
key=lambda i: valid_backends_priorities[i][1],
|
key=lambda i: valid_backends_priorities[i][1],
|
||||||
@ -417,8 +414,9 @@ class CudaPlatformBase(Platform):
|
|||||||
selected_index = sorted_indices[0]
|
selected_index = sorted_indices[0]
|
||||||
selected_backend = valid_backends_priorities[selected_index][0]
|
selected_backend = valid_backends_priorities[selected_index][0]
|
||||||
logger.info(
|
logger.info(
|
||||||
"Using %s backend.",
|
"Using %s attention backend out of potential backends: %s",
|
||||||
selected_backend.name,
|
selected_backend.name,
|
||||||
|
[b[0].name for b in valid_backends_priorities],
|
||||||
)
|
)
|
||||||
|
|
||||||
return selected_backend.get_path()
|
return selected_backend.get_path()
|
||||||
|
|||||||
@ -42,7 +42,10 @@ from vllm.logger import init_logger
|
|||||||
from vllm.transformers_utils.config_parser_base import ConfigParserBase
|
from vllm.transformers_utils.config_parser_base import ConfigParserBase
|
||||||
from vllm.transformers_utils.utils import (
|
from vllm.transformers_utils.utils import (
|
||||||
check_gguf_file,
|
check_gguf_file,
|
||||||
|
is_gguf,
|
||||||
|
is_remote_gguf,
|
||||||
parse_safetensors_file_metadata,
|
parse_safetensors_file_metadata,
|
||||||
|
split_remote_gguf,
|
||||||
)
|
)
|
||||||
|
|
||||||
if envs.VLLM_USE_MODELSCOPE:
|
if envs.VLLM_USE_MODELSCOPE:
|
||||||
@ -453,51 +456,55 @@ def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> No
|
|||||||
|
|
||||||
def patch_rope_parameters(config: PretrainedConfig) -> None:
|
def patch_rope_parameters(config: PretrainedConfig) -> None:
|
||||||
"""Provide backwards compatibility for RoPE."""
|
"""Provide backwards compatibility for RoPE."""
|
||||||
# Retrieve rope_parameters differently based on Transformers version
|
# Patch rope_parameters differently based on Transformers version
|
||||||
if Version(version("transformers")) >= Version("5.0.0.dev0"):
|
if Version(version("transformers")) >= Version("5.0.0.dev0"):
|
||||||
from transformers.modeling_rope_utils import RopeParameters
|
from transformers.modeling_rope_utils import (
|
||||||
|
rope_config_validation,
|
||||||
rope_parameters: RopeParameters | dict[str, RopeParameters] | None = getattr(
|
standardize_rope_params,
|
||||||
config, "rope_parameters", None
|
|
||||||
)
|
)
|
||||||
elif hasattr(config, "rope_parameters"):
|
|
||||||
# We are in Transformers v4 and rope_parameters
|
# When Transformers v5 is installed, legacy rope_theta may be present
|
||||||
# has already been patched for this config
|
# when using custom code models written for Transformers v4
|
||||||
return
|
if (rope_theta := getattr(config, "rope_theta", None)) is not None:
|
||||||
|
standardize_rope_params(config, rope_theta=rope_theta)
|
||||||
|
rope_config_validation(config)
|
||||||
|
# Delete rope_theta to avoid confusion in downstream code
|
||||||
|
del config.rope_theta
|
||||||
else:
|
else:
|
||||||
# Convert Transformers v4 rope_theta and rope_scaling into rope_parameters
|
# When Transformers v4 is installed, legacy rope_scaling may be present
|
||||||
rope_theta: float | None = getattr(config, "rope_theta", None)
|
if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
|
||||||
rope_scaling: dict | None = getattr(config, "rope_scaling", None)
|
config.rope_parameters = rope_scaling
|
||||||
rope_parameters = rope_scaling
|
# When Transformers v4 is installed, legacy rope_theta may be present
|
||||||
# Move rope_theta into rope_parameters
|
if (rope_theta := getattr(config, "rope_theta", None)) is not None:
|
||||||
if rope_theta is not None:
|
if not hasattr(config, "rope_parameters"):
|
||||||
rope_parameters = rope_parameters or {"rope_type": "default"}
|
config.rope_parameters = {"rope_type": "default"}
|
||||||
rope_parameters["rope_theta"] = rope_theta
|
config.rope_parameters["rope_theta"] = rope_theta
|
||||||
# Add original_max_position_embeddings if present
|
|
||||||
if rope_parameters and (
|
|
||||||
ompe := getattr(config, "original_max_position_embeddings", None)
|
|
||||||
):
|
|
||||||
rope_parameters["original_max_position_embeddings"] = ompe
|
|
||||||
# Write back to config
|
|
||||||
config.rope_parameters = rope_parameters
|
|
||||||
|
|
||||||
# No RoPE parameters to patch
|
# No RoPE parameters to patch
|
||||||
if rope_parameters is None:
|
if not hasattr(config, "rope_parameters"):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Add original_max_position_embeddings if present
|
||||||
|
if ompe := getattr(config, "original_max_position_embeddings", None):
|
||||||
|
config.rope_parameters["original_max_position_embeddings"] = ompe
|
||||||
|
|
||||||
# Handle nested rope_parameters in interleaved sliding attention models
|
# Handle nested rope_parameters in interleaved sliding attention models
|
||||||
if set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
|
if set(config.rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
|
||||||
for rope_parameters_layer_type in rope_parameters.values():
|
for rope_parameters_layer_type in config.rope_parameters.values():
|
||||||
patch_rope_parameters_dict(rope_parameters_layer_type)
|
patch_rope_parameters_dict(rope_parameters_layer_type)
|
||||||
else:
|
else:
|
||||||
patch_rope_parameters_dict(rope_parameters)
|
patch_rope_parameters_dict(config.rope_parameters)
|
||||||
|
|
||||||
|
|
||||||
def patch_rope_parameters_dict(rope_parameters: dict[str, Any]) -> None:
|
def patch_rope_parameters_dict(rope_parameters: dict[str, Any]) -> None:
|
||||||
if "rope_type" in rope_parameters and "type" in rope_parameters:
|
if "rope_type" in rope_parameters and "type" in rope_parameters:
|
||||||
rope_type = rope_parameters["rope_type"]
|
rope_type = rope_parameters["rope_type"]
|
||||||
rope_type_legacy = rope_parameters["type"]
|
rope_type_legacy = rope_parameters["type"]
|
||||||
if rope_type != rope_type_legacy:
|
if (rope_type_legacy == "su" and rope_type == "longrope") or (
|
||||||
|
rope_type_legacy == "mrope" and rope_type == "default"
|
||||||
|
):
|
||||||
|
pass # No action needed
|
||||||
|
elif rope_type != rope_type_legacy:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Found conflicts between 'rope_type={rope_type}' (modern "
|
f"Found conflicts between 'rope_type={rope_type}' (modern "
|
||||||
f"field) and 'type={rope_type_legacy}' (legacy field). "
|
f"field) and 'type={rope_type_legacy}' (legacy field). "
|
||||||
@ -629,10 +636,12 @@ def maybe_override_with_speculators(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (resolved_model, resolved_tokenizer, speculative_config)
|
Tuple of (resolved_model, resolved_tokenizer, speculative_config)
|
||||||
"""
|
"""
|
||||||
is_gguf = check_gguf_file(model)
|
if check_gguf_file(model):
|
||||||
if is_gguf:
|
|
||||||
kwargs["gguf_file"] = Path(model).name
|
kwargs["gguf_file"] = Path(model).name
|
||||||
gguf_model_repo = Path(model).parent
|
gguf_model_repo = Path(model).parent
|
||||||
|
elif is_remote_gguf(model):
|
||||||
|
repo_id, _ = split_remote_gguf(model)
|
||||||
|
gguf_model_repo = Path(repo_id)
|
||||||
else:
|
else:
|
||||||
gguf_model_repo = None
|
gguf_model_repo = None
|
||||||
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
|
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
|
||||||
@ -678,10 +687,18 @@ def get_config(
|
|||||||
) -> PretrainedConfig:
|
) -> PretrainedConfig:
|
||||||
# Separate model folder from file path for GGUF models
|
# Separate model folder from file path for GGUF models
|
||||||
|
|
||||||
is_gguf = check_gguf_file(model)
|
_is_gguf = is_gguf(model)
|
||||||
if is_gguf:
|
_is_remote_gguf = is_remote_gguf(model)
|
||||||
kwargs["gguf_file"] = Path(model).name
|
if _is_gguf:
|
||||||
model = Path(model).parent
|
if check_gguf_file(model):
|
||||||
|
# Local GGUF file
|
||||||
|
kwargs["gguf_file"] = Path(model).name
|
||||||
|
model = Path(model).parent
|
||||||
|
elif _is_remote_gguf:
|
||||||
|
# Remote GGUF - extract repo_id from repo_id:quant_type format
|
||||||
|
# The actual GGUF file will be downloaded later by GGUFModelLoader
|
||||||
|
# Keep model as repo_id:quant_type for download, but use repo_id for config
|
||||||
|
model, _ = split_remote_gguf(model)
|
||||||
|
|
||||||
if config_format == "auto":
|
if config_format == "auto":
|
||||||
try:
|
try:
|
||||||
@ -689,10 +706,25 @@ def get_config(
|
|||||||
# Transformers implementation.
|
# Transformers implementation.
|
||||||
if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
|
if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
|
||||||
config_format = "mistral"
|
config_format = "mistral"
|
||||||
elif is_gguf or file_or_path_exists(
|
elif (_is_gguf and not _is_remote_gguf) or file_or_path_exists(
|
||||||
model, HF_CONFIG_NAME, revision=revision
|
model, HF_CONFIG_NAME, revision=revision
|
||||||
):
|
):
|
||||||
config_format = "hf"
|
config_format = "hf"
|
||||||
|
# Remote GGUF models must have config.json in repo,
|
||||||
|
# otherwise the config can't be parsed correctly.
|
||||||
|
# FIXME(Isotr0py): Support remote GGUF repos without config.json
|
||||||
|
elif _is_remote_gguf and not file_or_path_exists(
|
||||||
|
model, HF_CONFIG_NAME, revision=revision
|
||||||
|
):
|
||||||
|
err_msg = (
|
||||||
|
"Could not find config.json for remote GGUF model repo. "
|
||||||
|
"To load remote GGUF model through `<repo_id>:<quant_type>`, "
|
||||||
|
"ensure your model has config.json (HF format) file. "
|
||||||
|
"Otherwise please specify --hf-config-path <original_repo> "
|
||||||
|
"in engine args to fetch config from unquantized hf model."
|
||||||
|
)
|
||||||
|
logger.error(err_msg)
|
||||||
|
raise ValueError(err_msg)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Could not detect config format for no config file found. "
|
"Could not detect config format for no config file found. "
|
||||||
@ -713,9 +745,6 @@ def get_config(
|
|||||||
"'config.json'.\n"
|
"'config.json'.\n"
|
||||||
" - For Mistral models: ensure the presence of a "
|
" - For Mistral models: ensure the presence of a "
|
||||||
"'params.json'.\n"
|
"'params.json'.\n"
|
||||||
"3. For GGUF: pass the local path of the GGUF checkpoint.\n"
|
|
||||||
" Loading GGUF from a remote repo directly is not yet "
|
|
||||||
"supported.\n"
|
|
||||||
).format(model=model)
|
).format(model=model)
|
||||||
|
|
||||||
raise ValueError(error_message) from e
|
raise ValueError(error_message) from e
|
||||||
@ -729,7 +758,7 @@ def get_config(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
# Special architecture mapping check for GGUF models
|
# Special architecture mapping check for GGUF models
|
||||||
if is_gguf:
|
if _is_gguf:
|
||||||
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||||
raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
|
raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
|
||||||
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
|
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
|
||||||
@ -889,6 +918,8 @@ def get_pooling_config(model: str, revision: str | None = "main") -> dict | None
|
|||||||
A dictionary containing the pooling type and whether
|
A dictionary containing the pooling type and whether
|
||||||
normalization is used, or None if no pooling configuration is found.
|
normalization is used, or None if no pooling configuration is found.
|
||||||
"""
|
"""
|
||||||
|
if is_remote_gguf(model):
|
||||||
|
model, _ = split_remote_gguf(model)
|
||||||
|
|
||||||
modules_file_name = "modules.json"
|
modules_file_name = "modules.json"
|
||||||
|
|
||||||
@ -1108,6 +1139,8 @@ def get_hf_image_processor_config(
|
|||||||
# Separate model folder from file path for GGUF models
|
# Separate model folder from file path for GGUF models
|
||||||
if check_gguf_file(model):
|
if check_gguf_file(model):
|
||||||
model = Path(model).parent
|
model = Path(model).parent
|
||||||
|
elif is_remote_gguf(model):
|
||||||
|
model, _ = split_remote_gguf(model)
|
||||||
return get_image_processor_config(
|
return get_image_processor_config(
|
||||||
model, token=hf_token, revision=revision, **kwargs
|
model, token=hf_token, revision=revision, **kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from transformers.processing_utils import ProcessorMixin
|
|||||||
from transformers.video_processing_utils import BaseVideoProcessor
|
from transformers.video_processing_utils import BaseVideoProcessor
|
||||||
from typing_extensions import TypeVar
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
from vllm.transformers_utils.utils import check_gguf_file, convert_model_repo_to_path
|
from vllm.transformers_utils.utils import convert_model_repo_to_path, is_gguf
|
||||||
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
|
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -236,8 +236,8 @@ def cached_processor_from_config(
|
|||||||
processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
|
processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> _P:
|
) -> _P:
|
||||||
if check_gguf_file(model_config.model):
|
if is_gguf(model_config.model):
|
||||||
assert not check_gguf_file(model_config.tokenizer), (
|
assert not is_gguf(model_config.tokenizer), (
|
||||||
"For multimodal GGUF models, the original tokenizer "
|
"For multimodal GGUF models, the original tokenizer "
|
||||||
"should be used to correctly load processor."
|
"should be used to correctly load processor."
|
||||||
)
|
)
|
||||||
@ -350,8 +350,8 @@ def cached_image_processor_from_config(
|
|||||||
model_config: "ModelConfig",
|
model_config: "ModelConfig",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
if check_gguf_file(model_config.model):
|
if is_gguf(model_config.model):
|
||||||
assert not check_gguf_file(model_config.tokenizer), (
|
assert not is_gguf(model_config.tokenizer), (
|
||||||
"For multimodal GGUF models, the original tokenizer "
|
"For multimodal GGUF models, the original tokenizer "
|
||||||
"should be used to correctly load image processor."
|
"should be used to correctly load image processor."
|
||||||
)
|
)
|
||||||
|
|||||||
@ -20,7 +20,12 @@ from vllm.transformers_utils.config import (
|
|||||||
list_filtered_repo_files,
|
list_filtered_repo_files,
|
||||||
)
|
)
|
||||||
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
||||||
from vllm.transformers_utils.utils import check_gguf_file
|
from vllm.transformers_utils.utils import (
|
||||||
|
check_gguf_file,
|
||||||
|
is_gguf,
|
||||||
|
is_remote_gguf,
|
||||||
|
split_remote_gguf,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
@ -180,10 +185,12 @@ def get_tokenizer(
|
|||||||
kwargs["truncation_side"] = "left"
|
kwargs["truncation_side"] = "left"
|
||||||
|
|
||||||
# Separate model folder from file path for GGUF models
|
# Separate model folder from file path for GGUF models
|
||||||
is_gguf = check_gguf_file(tokenizer_name)
|
if is_gguf(tokenizer_name):
|
||||||
if is_gguf:
|
if check_gguf_file(tokenizer_name):
|
||||||
kwargs["gguf_file"] = Path(tokenizer_name).name
|
kwargs["gguf_file"] = Path(tokenizer_name).name
|
||||||
tokenizer_name = Path(tokenizer_name).parent
|
tokenizer_name = Path(tokenizer_name).parent
|
||||||
|
elif is_remote_gguf(tokenizer_name):
|
||||||
|
tokenizer_name, _ = split_remote_gguf(tokenizer_name)
|
||||||
|
|
||||||
# if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format
|
# if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format
|
||||||
# first to use official Mistral tokenizer if possible.
|
# first to use official Mistral tokenizer if possible.
|
||||||
|
|||||||
@ -9,6 +9,8 @@ from os import PathLike
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from gguf import GGMLQuantizationType
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
@ -46,6 +48,57 @@ def check_gguf_file(model: str | PathLike) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def is_remote_gguf(model: str | Path) -> bool:
|
||||||
|
"""Check if the model is a remote GGUF model."""
|
||||||
|
model = str(model)
|
||||||
|
return (
|
||||||
|
(not is_cloud_storage(model))
|
||||||
|
and (not model.startswith(("http://", "https://")))
|
||||||
|
and ("/" in model and ":" in model)
|
||||||
|
and is_valid_gguf_quant_type(model.rsplit(":", 1)[1])
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool:
|
||||||
|
"""Check if the quant type is a valid GGUF quant type."""
|
||||||
|
return getattr(GGMLQuantizationType, gguf_quant_type, None) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def split_remote_gguf(model: str | Path) -> tuple[str, str]:
|
||||||
|
"""Split the model into repo_id and quant type."""
|
||||||
|
model = str(model)
|
||||||
|
if is_remote_gguf(model):
|
||||||
|
parts = model.rsplit(":", 1)
|
||||||
|
return (parts[0], parts[1])
|
||||||
|
raise ValueError(
|
||||||
|
"Wrong GGUF model or invalid GGUF quant type: %s.\n"
|
||||||
|
"- It should be in repo_id:quant_type format.\n"
|
||||||
|
"- Valid GGMLQuantizationType values: %s",
|
||||||
|
model,
|
||||||
|
GGMLQuantizationType._member_names_,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_gguf(model: str | Path) -> bool:
|
||||||
|
"""Check if the model is a GGUF model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name, path, or Path object to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the model is a GGUF model, False otherwise.
|
||||||
|
"""
|
||||||
|
model = str(model)
|
||||||
|
|
||||||
|
# Check if it's a local GGUF file
|
||||||
|
if check_gguf_file(model):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check if it's a remote GGUF model (repo_id:quant_type format)
|
||||||
|
return is_remote_gguf(model)
|
||||||
|
|
||||||
|
|
||||||
def modelscope_list_repo_files(
|
def modelscope_list_repo_files(
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
revision: str | None = None,
|
revision: str | None = None,
|
||||||
|
|||||||
@ -73,14 +73,6 @@ class FlexibleArgumentParser(ArgumentParser):
|
|||||||
# Enable the deprecated kwarg for Python 3.12 and below
|
# Enable the deprecated kwarg for Python 3.12 and below
|
||||||
|
|
||||||
def parse_known_args(self, args=None, namespace=None):
|
def parse_known_args(self, args=None, namespace=None):
|
||||||
if args is not None and "--disable-log-requests" in args:
|
|
||||||
# Special case warning because the warning below won't trigger
|
|
||||||
# if –-disable-log-requests because its value is default.
|
|
||||||
logger.warning_once(
|
|
||||||
"argument '--disable-log-requests' is deprecated and "
|
|
||||||
"replaced with '--enable-log-requests'. This will be "
|
|
||||||
"removed in v0.12.0."
|
|
||||||
)
|
|
||||||
namespace, args = super().parse_known_args(args, namespace)
|
namespace, args = super().parse_known_args(args, namespace)
|
||||||
for action in FlexibleArgumentParser._deprecated:
|
for action in FlexibleArgumentParser._deprecated:
|
||||||
if (
|
if (
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user