mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-09 12:27:12 +08:00
Merge branch 'main' of github.com:vllm-project/vllm into conftest/generate_beam_search/simplify-return-value
This commit is contained in:
commit
876c0ed340
@ -71,6 +71,20 @@ steps:
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
# x86 CPU wheel build
|
||||
- label: "Build x86 CPU wheel"
|
||||
depends_on: ~
|
||||
id: build-wheel-x86-cpu
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --build-arg VLLM_CPU_AMXBF16=true --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
# Build release images (12.9)
|
||||
- label: "Build release image (x86)"
|
||||
depends_on: ~
|
||||
|
||||
@ -326,10 +326,10 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
|
||||
|
||||
- label: V1 Test e2e + engine # 30min
|
||||
timeout_in_minutes: 45
|
||||
- label: V1 Test e2e + engine # 65min
|
||||
timeout_in_minutes: 90
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
agent_pool: mi325_4
|
||||
# grade: Blocking
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -435,7 +435,7 @@ steps:
|
||||
|
||||
- label: Examples Test # 30min
|
||||
timeout_in_minutes: 45
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
@ -455,7 +455,6 @@ steps:
|
||||
# for multi-modal models
|
||||
- python3 offline_inference/audio_language.py --seed 0
|
||||
- python3 offline_inference/vision_language.py --seed 0
|
||||
- python3 offline_inference/vision_language_pooling.py --seed 0
|
||||
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
||||
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
|
||||
# for pooling models
|
||||
|
||||
@ -692,6 +692,7 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/test_initialization.py
|
||||
- tests/models/registry.py
|
||||
commands:
|
||||
# Run a subset of model initialization tests
|
||||
- pytest -v -s models/test_initialization.py::test_can_initialize_small_subset
|
||||
@ -704,6 +705,7 @@ steps:
|
||||
- vllm/model_executor/models/
|
||||
- vllm/transformers_utils/
|
||||
- tests/models/test_initialization.py
|
||||
- tests/models/registry.py
|
||||
commands:
|
||||
# Only when vLLM model source is modified - test initialization of a large
|
||||
# subset of supported models (the complement of the small subset in the above
|
||||
@ -836,7 +838,7 @@ steps:
|
||||
- tests/models/multimodal
|
||||
no_gpu: true
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- "pip install git+https://github.com/TIGER-AI-Lab/Mantis.git || echo 'Mantis installation skipped (decord not available on CPU-only environment)'"
|
||||
- pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
|
||||
|
||||
- label: Multi-Modal Processor Test
|
||||
@ -1346,6 +1348,7 @@ steps:
|
||||
- label: Prime-RL Integration Test # 15min
|
||||
timeout_in_minutes: 30
|
||||
optional: true
|
||||
soft_fail: true
|
||||
num_gpus: 2
|
||||
working_dir: "/vllm-workspace"
|
||||
source_file_dependencies:
|
||||
@ -1379,4 +1382,4 @@ steps:
|
||||
num_gpus: 2
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
|
||||
|
||||
@ -384,7 +384,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=$PYTHONPATH
|
||||
PYTHONPATH=$ENV{PYTHONPATH}
|
||||
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
|
||||
RESULT_VARIABLE marlin_generation_result
|
||||
OUTPUT_VARIABLE marlin_generation_result
|
||||
@ -822,7 +822,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
|
||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$ENV{PYTHONPATH}
|
||||
${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT}
|
||||
RESULT_VARIABLE machete_generation_result
|
||||
OUTPUT_VARIABLE machete_generation_output
|
||||
@ -1004,7 +1004,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=$PYTHONPATH
|
||||
PYTHONPATH=$ENV{PYTHONPATH}
|
||||
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
|
||||
RESULT_VARIABLE moe_marlin_generation_result
|
||||
OUTPUT_VARIABLE moe_marlin_generation_output
|
||||
|
||||
@ -143,11 +143,13 @@ Compute Resources:
|
||||
- Databricks
|
||||
- DeepInfra
|
||||
- Google Cloud
|
||||
- IBM
|
||||
- Intel
|
||||
- Lambda Lab
|
||||
- Nebius
|
||||
- Novita AI
|
||||
- NVIDIA
|
||||
- Red Hat
|
||||
- Replicate
|
||||
- Roblox
|
||||
- RunPod
|
||||
|
||||
@ -620,7 +620,7 @@ def get_tokenizer(
|
||||
kwargs["use_fast"] = False
|
||||
if tokenizer_mode == "mistral":
|
||||
try:
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"MistralTokenizer requires vllm package.\n"
|
||||
|
||||
@ -99,7 +99,6 @@ def benchmark_mrope(
|
||||
# the parameters to compute the q k v size based on tp_size
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=rope_parameters,
|
||||
|
||||
@ -32,8 +32,8 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device):
|
||||
def benchmark(batch_size, seq_len, num_heads, provider):
|
||||
dtype = torch.bfloat16
|
||||
max_position = 8192
|
||||
base = 10000
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
|
||||
rope_parameters = {"partial_rotary_factor": rotary_dim / head_size}
|
||||
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
|
||||
rope = rope.to(dtype=dtype, device=device)
|
||||
cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)
|
||||
|
||||
|
||||
@ -140,16 +140,21 @@ function(vllm_prepare_torch_gomp_shim TORCH_GOMP_SHIM_DIR)
|
||||
run_python(_VLLM_TORCH_GOMP_PATH
|
||||
"
|
||||
import os, glob
|
||||
try:
|
||||
import torch
|
||||
torch_pkg = os.path.dirname(torch.__file__)
|
||||
site_root = os.path.dirname(torch_pkg)
|
||||
torch_libs = os.path.join(site_root, 'torch.libs')
|
||||
print(glob.glob(os.path.join(torch_libs, 'libgomp-*.so*'))[0])
|
||||
except:
|
||||
print('')
|
||||
import torch
|
||||
torch_pkg = os.path.dirname(torch.__file__)
|
||||
site_root = os.path.dirname(torch_pkg)
|
||||
|
||||
# Search both torch.libs and torch/lib
|
||||
roots = [os.path.join(site_root, 'torch.libs'), os.path.join(torch_pkg, 'lib')]
|
||||
candidates = []
|
||||
for root in roots:
|
||||
if not os.path.isdir(root):
|
||||
continue
|
||||
candidates.extend(glob.glob(os.path.join(root, 'libgomp*.so*')))
|
||||
|
||||
print(candidates[0] if candidates else '')
|
||||
"
|
||||
"failed to probe torch.libs for libgomp")
|
||||
"failed to probe for libgomp")
|
||||
|
||||
if(_VLLM_TORCH_GOMP_PATH STREQUAL "" OR NOT EXISTS "${_VLLM_TORCH_GOMP_PATH}")
|
||||
return()
|
||||
|
||||
12
csrc/cache.h
12
csrc/cache.h
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
@ -58,6 +59,15 @@ void cp_gather_cache(
|
||||
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
|
||||
|
||||
// Gather and upconvert FP8 KV cache to BF16 workspace
|
||||
void cp_gather_and_upconvert_fp8_kv_cache(
|
||||
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
|
||||
torch::Tensor const& dst, // [TOT_TOKENS, 576]
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& seq_lens, // [BATCH]
|
||||
torch::Tensor const& workspace_starts, // [BATCH]
|
||||
int64_t batch_size);
|
||||
|
||||
// Indexer K quantization and cache function
|
||||
void indexer_k_quant_and_cache(
|
||||
torch::Tensor& k, // [num_tokens, head_dim]
|
||||
@ -72,4 +82,4 @@ void cp_gather_indexer_k_quant_cache(
|
||||
torch::Tensor& dst_k, // [num_tokens, head_dim]
|
||||
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
|
||||
const torch::Tensor& block_table, // [batch_size, num_blocks]
|
||||
const torch::Tensor& cu_seq_lens); // [batch_size + 1]
|
||||
const torch::Tensor& cu_seq_lens); // [batch_size + 1]
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "cuda_compat.h"
|
||||
@ -514,7 +515,8 @@ __global__ void indexer_k_quant_and_cache_kernel(
|
||||
const int quant_block_size, // quantization block size
|
||||
const int cache_block_size, // cache block size
|
||||
const int cache_stride, // stride for each token in kv_cache
|
||||
const bool use_ue8m0 // use ue8m0 scale format
|
||||
|
||||
const bool use_ue8m0 // use ue8m0 scale format
|
||||
) {
|
||||
constexpr int VEC_SIZE = 4;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
@ -1061,6 +1063,82 @@ void gather_and_maybe_dequant_cache(
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Gather and upconvert FP8 KV cache tokens to BF16 workspace
|
||||
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion
|
||||
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
|
||||
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
|
||||
__nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576]
|
||||
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
|
||||
const int32_t* __restrict__ seq_lens, // [BATCH]
|
||||
const int32_t* __restrict__ workspace_starts, // [BATCH]
|
||||
const int32_t block_size, const int32_t head_dim,
|
||||
const int64_t block_table_stride, const int64_t cache_block_stride,
|
||||
const int64_t cache_entry_stride, const int64_t dst_entry_stride) {
|
||||
const int64_t bid = blockIdx.x; // Batch ID
|
||||
const int32_t num_splits = gridDim.y;
|
||||
const int32_t split = blockIdx.y;
|
||||
const int32_t seq_start = workspace_starts[bid];
|
||||
const int32_t seq_len = seq_lens[bid];
|
||||
const int32_t tot_slots = seq_len;
|
||||
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);
|
||||
|
||||
const int32_t split_start = split * split_slots;
|
||||
const int32_t split_end = min((split + 1) * split_slots, tot_slots);
|
||||
|
||||
const bool is_active_split = (split_start < tot_slots);
|
||||
|
||||
if (!is_active_split) return;
|
||||
|
||||
// Adjust the pointer for the block_table for this batch
|
||||
const int32_t batch_offset = bid * block_table_stride;
|
||||
int32_t offset = split_start;
|
||||
int32_t offset_div = offset / block_size;
|
||||
offset = offset % block_size;
|
||||
const int32_t* batch_block_table = block_table + batch_offset;
|
||||
|
||||
// Adjust dst pointer based on the cumulative sequence lengths
|
||||
dst += seq_start * dst_entry_stride;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
// Process each token in this split
|
||||
for (int pid = split_start; pid < split_end; ++pid) {
|
||||
auto block_id = batch_block_table[offset_div];
|
||||
const uint8_t* token_ptr =
|
||||
src_cache + block_id * cache_block_stride + offset * cache_entry_stride;
|
||||
__nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride;
|
||||
|
||||
// FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16)
|
||||
const uint8_t* no_pe_ptr = token_ptr;
|
||||
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
|
||||
const __nv_bfloat16* rope_ptr =
|
||||
reinterpret_cast<const __nv_bfloat16*>(token_ptr + 512 + 16);
|
||||
|
||||
// Parallelize fp8 dequant (512 elements) and rope copy (64 elements)
|
||||
if (tid < 512) {
|
||||
// FP8 dequantization
|
||||
const int tile = tid >> 7; // each tile is 128 elements
|
||||
const float scale = scales_ptr[tile];
|
||||
const uint8_t val = no_pe_ptr[tid];
|
||||
dst_ptr[tid] =
|
||||
fp8::scaled_convert<__nv_bfloat16, uint8_t,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale);
|
||||
} else if (tid < 576) {
|
||||
// Rope copy (64 bf16 elements)
|
||||
const int rope_idx = tid - 512;
|
||||
dst_ptr[512 + rope_idx] = rope_ptr[rope_idx];
|
||||
}
|
||||
|
||||
// Move to next token
|
||||
offset += 1;
|
||||
if (offset == block_size) {
|
||||
offset_div += 1;
|
||||
offset = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
|
||||
// block_size.
|
||||
@ -1202,6 +1280,57 @@ void cp_gather_cache(
|
||||
}
|
||||
}
|
||||
|
||||
void cp_gather_and_upconvert_fp8_kv_cache(
|
||||
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
|
||||
torch::Tensor const& dst, // [TOT_TOKENS, 576]
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& seq_lens, // [BATCH]
|
||||
torch::Tensor const& workspace_starts, // [BATCH]
|
||||
int64_t batch_size) {
|
||||
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
int32_t block_size = src_cache.size(1);
|
||||
int32_t head_dim = dst.size(1);
|
||||
|
||||
TORCH_CHECK(block_table.dtype() == torch::kInt32,
|
||||
"block_table must be int32");
|
||||
TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32");
|
||||
TORCH_CHECK(workspace_starts.dtype() == torch::kInt32,
|
||||
"workspace_starts must be int32");
|
||||
|
||||
TORCH_CHECK(src_cache.device() == dst.device(),
|
||||
"src_cache and dst must be on the same device");
|
||||
TORCH_CHECK(src_cache.device() == block_table.device(),
|
||||
"src_cache and block_table must be on the same device");
|
||||
TORCH_CHECK(src_cache.device() == seq_lens.device(),
|
||||
"src_cache and seq_lens must be on the same device");
|
||||
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
|
||||
"src_cache and workspace_starts must be on the same device");
|
||||
|
||||
TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8");
|
||||
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
|
||||
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
|
||||
|
||||
int64_t block_table_stride = block_table.stride(0);
|
||||
int64_t cache_block_stride = src_cache.stride(0);
|
||||
int64_t cache_entry_stride = src_cache.stride(1);
|
||||
int64_t dst_entry_stride = dst.stride(0);
|
||||
|
||||
// Decide on the number of splits based on the batch size
|
||||
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
|
||||
dim3 grid(batch_size, num_splits);
|
||||
dim3 block(576);
|
||||
|
||||
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
|
||||
src_cache.data_ptr<uint8_t>(),
|
||||
reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
|
||||
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
|
||||
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
|
||||
block_table_stride, cache_block_stride, cache_entry_stride,
|
||||
dst_entry_stride);
|
||||
}
|
||||
|
||||
// Macro to dispatch the kernel based on the data type.
|
||||
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
|
||||
@ -481,8 +481,6 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
|
||||
largest = value;
|
||||
}
|
||||
}
|
||||
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
// Get the top2 warpwise
|
||||
T max1 = cg::reduce(tile, largest, cg::greater<T>());
|
||||
|
||||
@ -589,7 +587,6 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
int pre_count_equal_to_top_value = 0;
|
||||
// Use loop to find the largset top_group
|
||||
while (count_equal_to_top_value < target_num_min) {
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||
if (value == topk_group_value) {
|
||||
value = neg_inf<T>();
|
||||
@ -644,10 +641,8 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
}
|
||||
}
|
||||
queue.done();
|
||||
__syncwarp();
|
||||
// Get the topk_idx
|
||||
queue.dumpIdx(s_topk_idx);
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Load the valid score value
|
||||
|
||||
@ -860,4 +860,4 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
|
||||
}
|
||||
}
|
||||
@ -22,6 +22,62 @@ __device__ __forceinline__ float GroupReduceMax(float val) {
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T, bool SCALE_UE8M0>
|
||||
__device__ __forceinline__ float ComputeGroupScale(
|
||||
const T* __restrict__ group_input, T* __restrict__ smem_group,
|
||||
const int group_size, const int lane_id, const int threads_per_group,
|
||||
const float eps, const float max_8bit) {
|
||||
float local_absmax = eps;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
|
||||
// copy global -> shared & compute absmax
|
||||
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
|
||||
float abs_v = fabsf(static_cast<float>(src));
|
||||
local_absmax = fmaxf(local_absmax, abs_v);
|
||||
dst = src;
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
group_input, // in
|
||||
smem_group, // out (shared)
|
||||
group_size, // elements per group
|
||||
lane_id, // thread id
|
||||
threads_per_group, // stride in group
|
||||
scalar_op_cache); // scalar handler
|
||||
|
||||
local_absmax = GroupReduceMax(local_absmax);
|
||||
|
||||
float y_s = local_absmax / max_8bit;
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
|
||||
}
|
||||
|
||||
return y_s;
|
||||
}
|
||||
|
||||
template <typename T, typename DST_DTYPE>
|
||||
__device__ __forceinline__ void QuantizeGroup(
|
||||
const T* __restrict__ smem_group, DST_DTYPE* __restrict__ group_output,
|
||||
const int group_size, const int lane_id, const int threads_per_group,
|
||||
const float y_s, const float min_8bit, const float max_8bit) {
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
|
||||
// quantize shared -> global 8-bit
|
||||
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
|
||||
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
|
||||
dst = DST_DTYPE(q);
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
smem_group, // in (shared)
|
||||
group_output, // out (global quant tensor)
|
||||
group_size, // elements
|
||||
lane_id, // tid
|
||||
threads_per_group, // stride
|
||||
scalar_op_quant); // scalar handler
|
||||
}
|
||||
|
||||
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
|
||||
bool SCALE_UE8M0 = false, typename scale_packed_t = float>
|
||||
__global__ void per_token_group_quant_8bit_kernel(
|
||||
@ -38,8 +94,6 @@ __global__ void per_token_group_quant_8bit_kernel(
|
||||
const int64_t global_group_id = block_group_id + local_group_id;
|
||||
const int64_t block_group_offset = global_group_id * group_size;
|
||||
|
||||
float local_absmax = eps;
|
||||
|
||||
using scale_element_t = float;
|
||||
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
|
||||
|
||||
@ -68,30 +122,9 @@ __global__ void per_token_group_quant_8bit_kernel(
|
||||
T* smem = reinterpret_cast<T*>(smem_raw);
|
||||
T* smem_group = smem + local_group_id * group_size;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
using vec_t = vllm::vec_n_t<T, vec_size>;
|
||||
|
||||
// copy global -> shared & compute absmax
|
||||
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
|
||||
float abs_v = fabsf(static_cast<float>(src));
|
||||
local_absmax = fmaxf(local_absmax, abs_v);
|
||||
dst = src;
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
group_input, // in
|
||||
smem_group, // out (shared)
|
||||
group_size, // elements per group
|
||||
lane_id, // thread id
|
||||
threads_per_group, // stride in group
|
||||
scalar_op_cache); // scalar handler
|
||||
|
||||
local_absmax = GroupReduceMax(local_absmax);
|
||||
|
||||
float y_s = local_absmax / max_8bit;
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
|
||||
}
|
||||
const float y_s = ComputeGroupScale<T, SCALE_UE8M0>(
|
||||
group_input, smem_group, group_size, lane_id, threads_per_group, eps,
|
||||
max_8bit);
|
||||
|
||||
scale_element_t y_s_quant = y_s;
|
||||
|
||||
@ -101,19 +134,24 @@ __global__ void per_token_group_quant_8bit_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// quantize shared -> global 8-bit
|
||||
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
|
||||
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
|
||||
dst = DST_DTYPE(q);
|
||||
};
|
||||
QuantizeGroup<T, DST_DTYPE>(smem_group, group_output, group_size, lane_id,
|
||||
threads_per_group, y_s, min_8bit, max_8bit);
|
||||
}
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
smem_group, // in (shared)
|
||||
group_output, // out (global quant tensor)
|
||||
group_size, // elements
|
||||
lane_id, // tid
|
||||
threads_per_group, // stride
|
||||
scalar_op_quant); // scalar handler
|
||||
inline int GetGroupsPerBlock(int64_t num_groups) {
|
||||
if (num_groups % 16 == 0) {
|
||||
return 16;
|
||||
}
|
||||
if (num_groups % 8 == 0) {
|
||||
return 8;
|
||||
}
|
||||
if (num_groups % 4 == 0) {
|
||||
return 4;
|
||||
}
|
||||
if (num_groups % 2 == 0) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
void per_token_group_quant_8bit(const torch::Tensor& input,
|
||||
@ -133,17 +171,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
|
||||
|
||||
constexpr int THREADS_PER_GROUP = 16;
|
||||
|
||||
int groups_per_block = 1;
|
||||
|
||||
if (num_groups % 16 == 0) {
|
||||
groups_per_block = 16;
|
||||
} else if (num_groups % 8 == 0) {
|
||||
groups_per_block = 8;
|
||||
} else if (num_groups % 4 == 0) {
|
||||
groups_per_block = 4;
|
||||
} else if (num_groups % 2 == 0) {
|
||||
groups_per_block = 2;
|
||||
}
|
||||
const int groups_per_block = GetGroupsPerBlock(num_groups);
|
||||
|
||||
auto dst_type = output_q.scalar_type();
|
||||
const int num_blocks = num_groups / groups_per_block;
|
||||
@ -225,8 +253,6 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
|
||||
|
||||
const int64_t block_group_offset = global_group_id * group_size;
|
||||
|
||||
float local_absmax = eps;
|
||||
|
||||
const T* group_input = input + block_group_offset;
|
||||
DST_DTYPE* group_output =
|
||||
static_cast<DST_DTYPE*>(output_q) + block_group_offset;
|
||||
@ -235,29 +261,9 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
|
||||
extern __shared__ __align__(16) char smem_raw[];
|
||||
T* smem = reinterpret_cast<T*>(smem_raw);
|
||||
T* smem_group = smem + local_group_id * group_size;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
using vec_t = vllm::vec_n_t<T, vec_size>;
|
||||
|
||||
// copy global -> shared & compute absmax
|
||||
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
|
||||
float abs_v = fabsf(static_cast<float>(src));
|
||||
local_absmax = fmaxf(local_absmax, abs_v);
|
||||
dst = src;
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
group_input, // in
|
||||
smem_group, // out (shared)
|
||||
group_size, // elements per group
|
||||
lane_id, // thread id
|
||||
threads_per_group, // stride in group
|
||||
scalar_op_cache); // scalar handler
|
||||
|
||||
local_absmax = GroupReduceMax(local_absmax);
|
||||
|
||||
float y_s = local_absmax / max_8bit;
|
||||
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
|
||||
const float y_s =
|
||||
ComputeGroupScale<T, true>(group_input, smem_group, group_size, lane_id,
|
||||
threads_per_group, eps, max_8bit);
|
||||
|
||||
// pack 4 scales into a uint32
|
||||
if (lane_id == 0) {
|
||||
@ -284,19 +290,8 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// quantize shared -> global 8-bit
|
||||
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
|
||||
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
|
||||
dst = DST_DTYPE(q);
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
smem_group, // in (shared)
|
||||
group_output, // out (global quant tensor)
|
||||
group_size, // elements
|
||||
lane_id, // tid
|
||||
threads_per_group, // stride
|
||||
scalar_op_quant); // scalar handler
|
||||
QuantizeGroup<T, DST_DTYPE>(smem_group, group_output, group_size, lane_id,
|
||||
threads_per_group, y_s, min_8bit, max_8bit);
|
||||
}
|
||||
|
||||
void per_token_group_quant_8bit_packed(const torch::Tensor& input,
|
||||
@ -337,17 +332,7 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
|
||||
|
||||
constexpr int THREADS_PER_GROUP = 16;
|
||||
|
||||
int groups_per_block = 1;
|
||||
|
||||
if (num_groups % 16 == 0) {
|
||||
groups_per_block = 16;
|
||||
} else if (num_groups % 8 == 0) {
|
||||
groups_per_block = 8;
|
||||
} else if (num_groups % 4 == 0) {
|
||||
groups_per_block = 4;
|
||||
} else if (num_groups % 2 == 0) {
|
||||
groups_per_block = 2;
|
||||
}
|
||||
const int groups_per_block = GetGroupsPerBlock(num_groups);
|
||||
|
||||
auto dst_type = output_q.scalar_type();
|
||||
const int num_blocks = num_groups / groups_per_block;
|
||||
|
||||
@ -754,6 +754,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
|
||||
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
|
||||
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
|
||||
"batch_size) -> ()");
|
||||
cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
|
||||
&cp_gather_and_upconvert_fp8_kv_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
|
||||
"slot_mapping, "
|
||||
|
||||
@ -24,11 +24,13 @@ Compute Resources:
|
||||
- Databricks
|
||||
- DeepInfra
|
||||
- Google Cloud
|
||||
- IBM
|
||||
- Intel
|
||||
- Lambda Lab
|
||||
- Nebius
|
||||
- Novita AI
|
||||
- NVIDIA
|
||||
- Red Hat
|
||||
- Replicate
|
||||
- Roblox
|
||||
- RunPod
|
||||
|
||||
@ -100,7 +100,23 @@ Testing has been conducted on AWS Graviton3 instances for compatibility.
|
||||
# --8<-- [end:build-wheel-from-source]
|
||||
# --8<-- [start:pre-built-images]
|
||||
|
||||
Currently, there are no pre-built Arm CPU images.
|
||||
See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image.
|
||||
|
||||
Stable vLLM Docker images are being pre-built for Arm from version 0.12.0. Available image tags are here: [https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo).
|
||||
Please replace `<version>` in the command below with a specific version string (e.g., `0.12.0`).
|
||||
|
||||
```bash
|
||||
docker pull public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:v<version>
|
||||
```
|
||||
|
||||
You can also access the latest code with Docker images. These are not intended for production use and are meant for CI and testing only. They will expire after several days.
|
||||
|
||||
The latest code can contain bugs and may not be stable. Please use it with caution.
|
||||
|
||||
```bash
|
||||
export VLLM_COMMIT=6299628d326f429eba78736acb44e76749b281f5 # use full commit hash from the main branch
|
||||
docker pull public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:${VLLM_COMMIT}-arm64-cpu
|
||||
```
|
||||
|
||||
# --8<-- [end:pre-built-images]
|
||||
# --8<-- [start:build-image-from-source]
|
||||
|
||||
@ -281,17 +281,27 @@ Alternatively, you can use the `openai` Python package:
|
||||
|
||||
Currently, vLLM supports multiple backends for efficient Attention computation across different platforms and accelerator architectures. It automatically selects the most performant backend compatible with your system and model specifications.
|
||||
|
||||
If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options:
|
||||
If desired, you can also manually set the backend of your choice using the `--attention-backend` CLI argument:
|
||||
|
||||
```bash
|
||||
# For online serving
|
||||
vllm serve Qwen/Qwen2.5-1.5B-Instruct --attention-backend FLASH_ATTN
|
||||
|
||||
# For offline inference
|
||||
python script.py --attention-backend FLASHINFER
|
||||
```
|
||||
|
||||
Some of the available backend options include:
|
||||
|
||||
- On NVIDIA CUDA: `FLASH_ATTN` or `FLASHINFER`.
|
||||
- On AMD ROCm: `TRITON_ATTN`, `ROCM_ATTN`, `ROCM_AITER_FA` or `ROCM_AITER_UNIFIED_ATTN`.
|
||||
|
||||
For AMD ROCm, you can further control the specific Attention implementation using the following variables:
|
||||
For AMD ROCm, you can further control the specific Attention implementation using the following options:
|
||||
|
||||
- Triton Unified Attention: `VLLM_ROCM_USE_AITER=0 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0`
|
||||
- AITER Unified Attention: `VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0`
|
||||
- Triton Prefill-Decode Attention: `VLLM_ROCM_USE_AITER=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 VLLM_ROCM_USE_AITER_MHA=0`
|
||||
- AITER Multi-head Attention: `VLLM_ROCM_USE_AITER=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=1`
|
||||
- Triton Unified Attention: Set the environment variables `VLLM_ROCM_USE_AITER=0 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument.
|
||||
- AITER Unified Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument.
|
||||
- Triton Prefill-Decode Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=true` as a CLI argument.
|
||||
- AITER Multi-head Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=1` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument.
|
||||
|
||||
!!! warning
|
||||
There are no pre-built vllm wheels containing Flash Infer, so you must install it in your environment first. Refer to the [Flash Infer official docs](https://docs.flashinfer.ai/) or see [docker/Dockerfile](../../docker/Dockerfile) for instructions on how to install it.
|
||||
|
||||
@ -4,6 +4,9 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import AttentionConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@ -20,6 +23,11 @@ def parse_args():
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
if current_platform.is_rocm():
|
||||
args.attention_config = AttentionConfig(
|
||||
backend=AttentionBackendEnum.FLEX_ATTENTION
|
||||
)
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
|
||||
@ -4,6 +4,9 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import AttentionConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@ -20,6 +23,11 @@ def parse_args():
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
if current_platform.is_rocm():
|
||||
args.attention_config = AttentionConfig(
|
||||
backend=AttentionBackendEnum.FLEX_ATTENTION
|
||||
)
|
||||
|
||||
# Sample prompts.
|
||||
text_1 = "What is the capital of France?"
|
||||
texts_2 = [
|
||||
|
||||
@ -33,6 +33,7 @@ import os
|
||||
from time import sleep
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
|
||||
@ -222,6 +223,11 @@ if __name__ == "__main__":
|
||||
|
||||
from multiprocessing import Process
|
||||
|
||||
if current_platform.is_rocm():
|
||||
from multiprocessing import set_start_method
|
||||
|
||||
set_start_method("spawn", force=True)
|
||||
|
||||
procs = []
|
||||
for local_dp_rank, global_dp_rank in enumerate(
|
||||
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
|
||||
|
||||
@ -21,7 +21,7 @@
|
||||
# --worker \
|
||||
# /abs/path/to/huggingface/cache \
|
||||
# -e VLLM_HOST_IP=<worker_node_ip>
|
||||
#
|
||||
#
|
||||
# Each worker requires a unique VLLM_HOST_IP value.
|
||||
# Keep each terminal session open. Closing a session stops the associated Ray
|
||||
# node and thereby shuts down the entire cluster.
|
||||
@ -59,6 +59,34 @@ if [ "${NODE_TYPE}" != "--head" ] && [ "${NODE_TYPE}" != "--worker" ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Extract VLLM_HOST_IP from ADDITIONAL_ARGS (e.g. "-e VLLM_HOST_IP=...").
|
||||
VLLM_HOST_IP=""
|
||||
for ((i = 0; i < ${#ADDITIONAL_ARGS[@]}; i++)); do
|
||||
arg="${ADDITIONAL_ARGS[$i]}"
|
||||
case "${arg}" in
|
||||
-e)
|
||||
next="${ADDITIONAL_ARGS[$((i + 1))]:-}"
|
||||
if [[ "${next}" == VLLM_HOST_IP=* ]]; then
|
||||
VLLM_HOST_IP="${next#VLLM_HOST_IP=}"
|
||||
break
|
||||
fi
|
||||
;;
|
||||
-eVLLM_HOST_IP=* | VLLM_HOST_IP=*)
|
||||
VLLM_HOST_IP="${arg#*=}"
|
||||
break
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# For the head node, HEAD_NODE_ADDRESS and VLLM_HOST_IP should be consistent.
|
||||
if [[ "${NODE_TYPE}" == "--head" && -n "${VLLM_HOST_IP}" ]]; then
|
||||
if [[ "${VLLM_HOST_IP}" != "${HEAD_NODE_ADDRESS}" ]]; then
|
||||
echo "Warning: VLLM_HOST_IP (${VLLM_HOST_IP}) differs from head_node_ip (${HEAD_NODE_ADDRESS})."
|
||||
echo "Using VLLM_HOST_IP as the head node address."
|
||||
HEAD_NODE_ADDRESS="${VLLM_HOST_IP}"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Generate a unique container name with random suffix.
|
||||
# Docker container names must be unique on each host.
|
||||
# The random suffix allows multiple Ray containers to run simultaneously on the same machine,
|
||||
@ -74,36 +102,17 @@ cleanup() {
|
||||
trap cleanup EXIT
|
||||
|
||||
# Build the Ray start command based on the node role.
|
||||
# The head node manages the cluster and accepts connections on port 6379,
|
||||
# The head node manages the cluster and accepts connections on port 6379,
|
||||
# while workers connect to the head's address.
|
||||
RAY_START_CMD="ray start --block"
|
||||
if [ "${NODE_TYPE}" == "--head" ]; then
|
||||
RAY_START_CMD+=" --head --port=6379"
|
||||
RAY_START_CMD+=" --head --node-ip-address=${HEAD_NODE_ADDRESS} --port=6379"
|
||||
else
|
||||
|
||||
RAY_START_CMD+=" --address=${HEAD_NODE_ADDRESS}:6379"
|
||||
fi
|
||||
|
||||
# Parse VLLM_HOST_IP from additional args if present.
|
||||
# This is needed for multi-NIC configurations where Ray needs explicit IP bindings.
|
||||
VLLM_HOST_IP=""
|
||||
for arg in "${ADDITIONAL_ARGS[@]}"; do
|
||||
if [[ $arg == "-e" ]]; then
|
||||
continue
|
||||
if [ -n "${VLLM_HOST_IP}" ]; then
|
||||
RAY_START_CMD+=" --node-ip-address=${VLLM_HOST_IP}"
|
||||
fi
|
||||
if [[ $arg == VLLM_HOST_IP=* ]]; then
|
||||
VLLM_HOST_IP="${arg#VLLM_HOST_IP=}"
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
# Build Ray IP environment variables if VLLM_HOST_IP is set.
|
||||
# These variables ensure Ray binds to the correct network interface on multi-NIC systems.
|
||||
RAY_IP_VARS=()
|
||||
if [ -n "${VLLM_HOST_IP}" ]; then
|
||||
RAY_IP_VARS=(
|
||||
-e "RAY_NODE_IP_ADDRESS=${VLLM_HOST_IP}"
|
||||
-e "RAY_OVERRIDE_NODE_IP_ADDRESS=${VLLM_HOST_IP}"
|
||||
)
|
||||
fi
|
||||
|
||||
# Launch the container with the assembled parameters.
|
||||
@ -118,6 +127,5 @@ docker run \
|
||||
--shm-size 10.24g \
|
||||
--gpus all \
|
||||
-v "${PATH_TO_HF_HOME}:/root/.cache/huggingface" \
|
||||
"${RAY_IP_VARS[@]}" \
|
||||
"${ADDITIONAL_ARGS[@]}" \
|
||||
"${DOCKER_IMAGE}" -c "${RAY_START_CMD}"
|
||||
|
||||
@ -50,4 +50,5 @@ ijson # Required for mistral streaming tool parser
|
||||
setproctitle # Used to set process names for better debugging and monitoring
|
||||
openai-harmony >= 0.0.3 # Required for gpt-oss
|
||||
anthropic == 0.71.0
|
||||
model-hosting-container-standards >= 0.1.9, < 1.0.0
|
||||
model-hosting-container-standards >= 0.1.9, < 1.0.0
|
||||
mcp
|
||||
@ -1,2 +1,2 @@
|
||||
lmcache >= 0.3.10.post1
|
||||
lmcache
|
||||
nixl >= 0.7.1 # Required for disaggregated prefill
|
||||
|
||||
@ -20,7 +20,7 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from ...utils import flat_product, multi_gpu_test
|
||||
|
||||
is_blackwell = lambda: current_platform.is_device_capability(100)
|
||||
is_blackwell = lambda: current_platform.is_device_capability_family(100)
|
||||
"""Are we running on Blackwell, a lot of tests depend on it"""
|
||||
|
||||
|
||||
@ -138,6 +138,17 @@ elif current_platform.is_rocm():
|
||||
CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
|
||||
|
||||
|
||||
def has_cuda_graph_wrapper_metadata() -> bool:
|
||||
from importlib import import_module
|
||||
|
||||
try:
|
||||
module = import_module("torch._inductor.utils")
|
||||
module.CUDAGraphWrapperMetadata # noqa B018
|
||||
except AttributeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, model_kwargs, backend, matches, custom_ops",
|
||||
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
|
||||
@ -145,7 +156,20 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
|
||||
# quant_fp4 only has the custom impl
|
||||
+ list(flat_product(MODELS_FP4, [""])),
|
||||
)
|
||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"inductor_graph_partition",
|
||||
[
|
||||
pytest.param(
|
||||
True,
|
||||
marks=pytest.mark.skipif(
|
||||
not has_cuda_graph_wrapper_metadata(),
|
||||
reason="This test requires"
|
||||
"torch._inductor.utils.CUDAGraphWrapperMetadata to run",
|
||||
),
|
||||
),
|
||||
False,
|
||||
],
|
||||
)
|
||||
def test_attn_quant(
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
|
||||
@ -128,14 +128,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
|
||||
|
||||
class TestRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
|
||||
def __init__(self, head_dim=64, max_position=2048, base=10000):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.rotary_dim = rotary_dim or head_dim
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.rotary_dim,
|
||||
max_position=max_position,
|
||||
rope_parameters={"rope_type": "default", "rope_theta": base},
|
||||
)
|
||||
@ -170,7 +168,6 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position,
|
||||
rope_parameters={"rope_type": "default", "rope_theta": base},
|
||||
)
|
||||
|
||||
@ -200,6 +200,27 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workspace_init():
|
||||
"""Initialize the workspace manager for tests that need it.
|
||||
|
||||
This fixture initializes the workspace manager with a CUDA device
|
||||
if available, and resets it after the test completes. Tests that
|
||||
create a full vLLM engine should NOT use this fixture as the engine
|
||||
will initialize the workspace manager itself.
|
||||
"""
|
||||
from vllm.v1.worker.workspace import (
|
||||
init_workspace_manager,
|
||||
reset_workspace_manager,
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda:0")
|
||||
init_workspace_manager(device)
|
||||
yield
|
||||
reset_workspace_manager()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def dynamo_reset():
|
||||
yield
|
||||
@ -679,10 +700,16 @@ class HfRunner:
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Encoder-decoder models return decoder_hidden_states instead of
|
||||
# hidden_states
|
||||
hidden_states = (
|
||||
getattr(output, "hidden_states", None) or output.decoder_hidden_states
|
||||
)
|
||||
|
||||
(
|
||||
seq_logprobs_lst,
|
||||
output_len,
|
||||
) = self._hidden_states_to_logprobs(output.hidden_states, num_logprobs)
|
||||
) = self._hidden_states_to_logprobs(hidden_states, num_logprobs)
|
||||
|
||||
all_logprobs.append(seq_logprobs_lst)
|
||||
seq_ids = output.sequences[0]
|
||||
|
||||
276
tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
Normal file
276
tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
Normal file
@ -0,0 +1,276 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Test that the interaction between EPLB and FusedMoE Layer is okay for DP w/ NVFP4
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
get_dp_group,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptNvFp4Config,
|
||||
ModelOptNvFp4FusedMoE,
|
||||
)
|
||||
|
||||
from .eplb_utils import distributed_run, set_env_vars_and_device
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestConfig:
|
||||
num_layers: int
|
||||
num_experts: int
|
||||
num_local_experts: int
|
||||
num_topk: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
num_tokens: int
|
||||
|
||||
|
||||
def make_fused_moe_layer(
|
||||
rank: int,
|
||||
layer_idx: int,
|
||||
test_config: TestConfig,
|
||||
) -> FusedMoE:
|
||||
quant_config = None
|
||||
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
|
||||
quant_config = ModelOptNvFp4Config(
|
||||
is_checkpoint_nvfp4_serialized=True,
|
||||
kv_cache_quant_algo=None,
|
||||
exclude_modules=[],
|
||||
)
|
||||
|
||||
fml = FusedMoE(
|
||||
num_experts=test_config.num_experts,
|
||||
top_k=test_config.num_topk,
|
||||
hidden_size=test_config.hidden_size,
|
||||
intermediate_size=test_config.intermediate_size,
|
||||
prefix=f"dummy_layer_{layer_idx}",
|
||||
activation="silu",
|
||||
is_act_and_mul=True,
|
||||
params_dtype=torch.bfloat16,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml)
|
||||
nvfp4_fused_moe.create_weights(
|
||||
fml,
|
||||
test_config.num_local_experts,
|
||||
test_config.hidden_size,
|
||||
test_config.intermediate_size,
|
||||
params_dtype=torch.uint8,
|
||||
global_num_experts=test_config.num_experts,
|
||||
)
|
||||
|
||||
fml = fml.to(device)
|
||||
w1_q, w2_q, quant_config = make_test_quant_config(
|
||||
test_config.num_local_experts,
|
||||
test_config.intermediate_size,
|
||||
test_config.hidden_size,
|
||||
in_dtype=torch.bfloat16,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
fml.w13_weight.data = w1_q
|
||||
fml.w2_weight.data = w2_q
|
||||
|
||||
fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5
|
||||
fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5
|
||||
fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5
|
||||
fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5
|
||||
fml.w2_weight_scale.data = (
|
||||
torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5
|
||||
).to(fml.w2_weight_scale.data.dtype)
|
||||
fml.w13_weight_scale.data = (
|
||||
torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5
|
||||
).to(fml.w13_weight_scale.data.dtype)
|
||||
|
||||
nvfp4_fused_moe.process_weights_after_loading(fml)
|
||||
|
||||
fml.maybe_init_modular_kernel()
|
||||
|
||||
return fml
|
||||
|
||||
|
||||
def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
|
||||
set_env_vars_and_device(env)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.parallel_config.data_parallel_size = world_size
|
||||
vllm_config.parallel_config.enable_expert_parallel = True
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=1, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group = get_dp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
fml_layers = [
|
||||
make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device)
|
||||
for layer_idx in range(test_config.num_layers)
|
||||
]
|
||||
rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]
|
||||
|
||||
hidden_states = []
|
||||
router_logits = []
|
||||
for layer_idx in range(test_config.num_layers):
|
||||
hidden_states.append(
|
||||
torch.randn(
|
||||
(test_config.num_tokens, test_config.hidden_size),
|
||||
dtype=torch.bfloat16,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
router_logits.append(
|
||||
torch.randn(
|
||||
(test_config.num_tokens, test_config.num_experts),
|
||||
dtype=torch.bfloat16,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
out_before_shuffle = []
|
||||
with set_forward_context(
|
||||
{},
|
||||
num_tokens=test_config.num_tokens,
|
||||
num_tokens_across_dp=torch.tensor(
|
||||
[test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
|
||||
),
|
||||
vllm_config=vllm_config,
|
||||
):
|
||||
for lidx, fml in enumerate(fml_layers):
|
||||
out_before_shuffle.append(
|
||||
fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
|
||||
)
|
||||
|
||||
indices = torch.zeros(
|
||||
test_config.num_layers, test_config.num_experts, dtype=torch.long
|
||||
)
|
||||
for lidx in range(test_config.num_layers):
|
||||
indices[lidx] = torch.Tensor(range(test_config.num_experts))
|
||||
|
||||
shuffled_indices = torch.zeros_like(indices)
|
||||
for lidx in range(test_config.num_layers):
|
||||
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
|
||||
|
||||
rearrange_expert_weights_inplace(
|
||||
indices,
|
||||
shuffled_indices,
|
||||
rank_expert_weights,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
)
|
||||
|
||||
num_global_experts = test_config.num_experts
|
||||
|
||||
logical_to_physical_map_list = []
|
||||
for lidx, fml in enumerate(fml_layers):
|
||||
physical_to_logical_map = shuffled_indices[lidx].to(device)
|
||||
logical_to_physical_map = torch.empty(
|
||||
(num_global_experts,), dtype=torch.int32, device=device
|
||||
)
|
||||
logical_to_physical_map[physical_to_logical_map] = torch.arange(
|
||||
0, num_global_experts, dtype=torch.int32, device=device
|
||||
)
|
||||
logical_to_physical_map_list.append(
|
||||
logical_to_physical_map.reshape(num_global_experts, 1)
|
||||
)
|
||||
|
||||
logical_to_physical_map = torch.stack(logical_to_physical_map_list)
|
||||
|
||||
for lidx, fml in enumerate(fml_layers):
|
||||
logical_replica_count = torch.ones(
|
||||
(test_config.num_layers, num_global_experts),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
fml.enable_eplb = True
|
||||
fml.set_eplb_state(
|
||||
lidx,
|
||||
torch.zeros(
|
||||
(test_config.num_layers, num_global_experts),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
),
|
||||
logical_to_physical_map,
|
||||
logical_replica_count,
|
||||
)
|
||||
|
||||
out_after_shuffle = []
|
||||
with set_forward_context(
|
||||
{},
|
||||
num_tokens=test_config.num_tokens,
|
||||
num_tokens_across_dp=torch.tensor(
|
||||
[test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
|
||||
),
|
||||
vllm_config=vllm_config,
|
||||
):
|
||||
for lidx, fml in enumerate(fml_layers):
|
||||
out_after_shuffle.append(
|
||||
fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
|
||||
)
|
||||
|
||||
for lidx in range(test_config.num_layers):
|
||||
torch.testing.assert_close(
|
||||
out_before_shuffle[lidx], out_after_shuffle[lidx], atol=1e-1, rtol=1e-1
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
@pytest.mark.parametrize("num_layers", [8])
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("hidden_size", [256])
|
||||
@pytest.mark.parametrize("intermediate_size", [256])
|
||||
@pytest.mark.parametrize("num_tokens", [256])
|
||||
@pytest.mark.parametrize("backend", ["latency", "throughput"])
|
||||
def test_eplb_fml(
|
||||
world_size: int,
|
||||
num_layers: int,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_tokens: int,
|
||||
backend: str,
|
||||
monkeypatch,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
|
||||
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend)
|
||||
|
||||
if torch.cuda.device_count() < world_size:
|
||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||
|
||||
num_local_experts = num_experts // world_size
|
||||
num_topk = 4
|
||||
|
||||
test_config = TestConfig(
|
||||
num_layers=num_layers,
|
||||
num_experts=num_experts,
|
||||
num_local_experts=num_local_experts,
|
||||
num_topk=num_topk,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_tokens=num_tokens,
|
||||
)
|
||||
|
||||
distributed_run(
|
||||
_test_eplb_fml,
|
||||
world_size,
|
||||
test_config,
|
||||
)
|
||||
@ -1,21 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from openai.types.responses import ResponseFunctionToolCall, ResponseReasoningItem
|
||||
from openai.types.responses.response_output_item import McpCall
|
||||
from openai_harmony import Author, Message, Role, TextContent
|
||||
|
||||
from tests.entrypoints.openai.utils import verify_harmony_messages
|
||||
from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
auto_drop_analysis_messages,
|
||||
get_encoding,
|
||||
has_custom_tools,
|
||||
parse_chat_input_to_harmony_message,
|
||||
parse_chat_output,
|
||||
parse_input_to_harmony_message,
|
||||
parse_output_message,
|
||||
)
|
||||
|
||||
|
||||
class TestParseInputToHarmonyMessage:
|
||||
"""Tests for parse_input_to_harmony_message function."""
|
||||
class TestCommonParseInputToHarmonyMessage:
|
||||
"""
|
||||
Tests for scenarios that are common to both Chat Completion
|
||||
parse_chat_input_to_harmony_message and Responsees API
|
||||
parse_input_to_harmony_message functions.
|
||||
"""
|
||||
|
||||
def test_assistant_message_with_tool_calls(self):
|
||||
@pytest.fixture(
|
||||
params=[parse_chat_input_to_harmony_message, parse_input_to_harmony_message]
|
||||
)
|
||||
def parse_function(self, request):
|
||||
return request.param
|
||||
|
||||
def test_assistant_message_with_tool_calls(self, parse_function):
|
||||
"""Test parsing assistant message with tool calls."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
@ -35,7 +51,7 @@ class TestParseInputToHarmonyMessage:
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 2
|
||||
|
||||
@ -53,7 +69,7 @@ class TestParseInputToHarmonyMessage:
|
||||
assert messages[1].recipient == "functions.search_web"
|
||||
assert messages[1].content_type == "json"
|
||||
|
||||
def test_assistant_message_with_empty_tool_call_arguments(self):
|
||||
def test_assistant_message_with_empty_tool_call_arguments(self, parse_function):
|
||||
"""Test parsing assistant message with tool call having None arguments."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
@ -67,12 +83,152 @@ class TestParseInputToHarmonyMessage:
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].content[0].text == ""
|
||||
assert messages[0].recipient == "functions.get_current_time"
|
||||
|
||||
def test_system_message(self, parse_function):
|
||||
"""Test parsing system message."""
|
||||
chat_msg = {
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant",
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
# System messages are converted using Message.from_dict
|
||||
# which should preserve the role
|
||||
assert messages[0].author.role == Role.SYSTEM
|
||||
|
||||
def test_developer_message(self, parse_function):
|
||||
"""Test parsing developer message."""
|
||||
chat_msg = {
|
||||
"role": "developer",
|
||||
"content": "Use concise language",
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.DEVELOPER
|
||||
|
||||
def test_user_message_with_string_content(self, parse_function):
|
||||
"""Test parsing user message with string content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": "What's the weather in San Francisco?",
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert messages[0].content[0].text == "What's the weather in San Francisco?"
|
||||
|
||||
def test_user_message_with_array_content(self, parse_function):
|
||||
"""Test parsing user message with array content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "What's in this image? "},
|
||||
{"text": "Please describe it."},
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert len(messages[0].content) == 2
|
||||
assert messages[0].content[0].text == "What's in this image? "
|
||||
assert messages[0].content[1].text == "Please describe it."
|
||||
|
||||
def test_assistant_message_with_string_content(self, parse_function):
|
||||
"""Test parsing assistant message with string content (no tool calls)."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you today?",
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.ASSISTANT
|
||||
assert messages[0].content[0].text == "Hello! How can I help you today?"
|
||||
|
||||
def test_pydantic_model_input(self, parse_function):
|
||||
"""Test parsing Pydantic model input (has model_dump method)."""
|
||||
|
||||
class MockPydanticModel:
|
||||
def model_dump(self, exclude_none=True):
|
||||
return {
|
||||
"role": "user",
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
chat_msg = MockPydanticModel()
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert messages[0].content[0].text == "Test message"
|
||||
|
||||
def test_tool_call_with_missing_function_fields(self, parse_function):
|
||||
"""Test parsing tool call with missing name or arguments."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {} # Missing both name and arguments
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].recipient == "functions."
|
||||
assert messages[0].content[0].text == ""
|
||||
|
||||
def test_array_content_with_missing_text(self, parse_function):
|
||||
"""Test parsing array content where text field is missing."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{}, # Missing text field
|
||||
{"text": "actual text"},
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert len(messages[0].content) == 2
|
||||
assert messages[0].content[0].text == ""
|
||||
assert messages[0].content[1].text == "actual text"
|
||||
|
||||
|
||||
class TestParseInputToHarmonyMessage:
|
||||
"""
|
||||
Tests for scenarios that are specific to the Responses API
|
||||
parse_input_to_harmony_message function.
|
||||
"""
|
||||
|
||||
def test_message_with_empty_content(self):
|
||||
"""Test parsing message with empty string content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].content[0].text == ""
|
||||
|
||||
def test_tool_message_with_string_content(self):
|
||||
"""Test parsing tool message with string content."""
|
||||
chat_msg = {
|
||||
@ -111,6 +267,7 @@ class TestParseInputToHarmonyMessage:
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.TOOL
|
||||
assert messages[0].author.name == "functions.search_results"
|
||||
assert messages[0].content[0].text == "Result 1: Result 2: Result 3"
|
||||
|
||||
def test_tool_message_with_empty_content(self):
|
||||
@ -124,140 +281,564 @@ class TestParseInputToHarmonyMessage:
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.TOOL
|
||||
assert messages[0].author.name == "functions.empty_tool"
|
||||
assert messages[0].content[0].text == ""
|
||||
|
||||
def test_system_message(self):
|
||||
"""Test parsing system message."""
|
||||
chat_msg = {
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
class TestParseChatInputToHarmonyMessage:
|
||||
"""
|
||||
Tests for scenarios that are specific to the Chat Completion API
|
||||
parse_chat_input_to_harmony_message function.
|
||||
"""
|
||||
|
||||
assert len(messages) == 1
|
||||
# System messages are converted using Message.from_dict
|
||||
# which should preserve the role
|
||||
assert messages[0].author.role == Role.SYSTEM
|
||||
|
||||
def test_developer_message(self):
|
||||
"""Test parsing developer message."""
|
||||
chat_msg = {
|
||||
"role": "developer",
|
||||
"content": "Use concise language",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.DEVELOPER
|
||||
|
||||
def test_user_message_with_string_content(self):
|
||||
"""Test parsing user message with string content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": "What's the weather in San Francisco?",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert messages[0].content[0].text == "What's the weather in San Francisco?"
|
||||
|
||||
def test_user_message_with_array_content(self):
|
||||
"""Test parsing user message with array content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "What's in this image? "},
|
||||
{"text": "Please describe it."},
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert len(messages[0].content) == 2
|
||||
assert messages[0].content[0].text == "What's in this image? "
|
||||
assert messages[0].content[1].text == "Please describe it."
|
||||
|
||||
def test_assistant_message_with_string_content(self):
|
||||
"""Test parsing assistant message with string content (no tool calls)."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you today?",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.ASSISTANT
|
||||
assert messages[0].content[0].text == "Hello! How can I help you today?"
|
||||
|
||||
def test_pydantic_model_input(self):
|
||||
"""Test parsing Pydantic model input (has model_dump method)."""
|
||||
|
||||
class MockPydanticModel:
|
||||
def model_dump(self, exclude_none=True):
|
||||
return {
|
||||
"role": "user",
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
chat_msg = MockPydanticModel()
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert messages[0].content[0].text == "Test message"
|
||||
|
||||
def test_message_with_empty_content(self):
|
||||
"""Test parsing message with empty string content."""
|
||||
def test_user_message_with_empty_content(self):
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].content[0].text == ""
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_tool_call_with_missing_function_fields(self):
|
||||
"""Test parsing tool call with missing name or arguments."""
|
||||
def test_user_message_with_none_content(self):
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": None,
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_empty_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 0
|
||||
|
||||
def test_assistant_message_with_none_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 0
|
||||
|
||||
def test_assistant_message_with_content_but_empty_reasoning(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": "The answer is 4.",
|
||||
"reasoning": "",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "final",
|
||||
"content": "The answer is 4.",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_reasoning_but_empty_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"reasoning": "I'm thinking about the user's question.",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": "I'm thinking about the user's question.",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_reasoning_but_none_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"reasoning": "I'm thinking about the user's question.",
|
||||
"content": None,
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": "I'm thinking about the user's question.",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_tool_calls_but_no_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {} # Missing both name and arguments
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}',
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].recipient == "functions."
|
||||
assert messages[0].content[0].text == ""
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": '{"location": "San Francisco"}',
|
||||
"content_type": "json",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_array_content_with_missing_text(self):
|
||||
"""Test parsing array content where text field is missing."""
|
||||
def test_assistant_message_with_tool_calls_and_content(self):
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}',
|
||||
}
|
||||
}
|
||||
],
|
||||
"content": "I'll call the tool.",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"content": "I'll call the tool.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": '{"location": "San Francisco"}',
|
||||
"content_type": "json",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_tool_calls_and_reasoning(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}',
|
||||
}
|
||||
}
|
||||
],
|
||||
"reasoning": "I should use the get_weather tool.",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": "I should use the get_weather tool.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": '{"location": "San Francisco"}',
|
||||
"content_type": "json",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_tool_calls_and_reasoning_and_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}',
|
||||
}
|
||||
}
|
||||
],
|
||||
"reasoning": "I should use the get_weather tool.",
|
||||
"content": "I'll call the tool.",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"content": "I'll call the tool.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": "I should use the get_weather tool.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": '{"location": "San Francisco"}',
|
||||
"content_type": "json",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_tool_message_with_string_content(self):
|
||||
tool_id_names = {
|
||||
"call_123": "get_weather",
|
||||
}
|
||||
chat_msg = {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"content": "The weather in San Francisco is sunny, 72°F",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(
|
||||
chat_msg, tool_id_names=tool_id_names
|
||||
)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "functions.get_weather",
|
||||
"content": "The weather in San Francisco is sunny, 72°F",
|
||||
"channel": "commentary",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_tool_message_with_array_content(self):
|
||||
tool_id_names = {
|
||||
"call_123": "search_results",
|
||||
}
|
||||
chat_msg = {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"content": [
|
||||
{}, # Missing text field
|
||||
{"text": "actual text"},
|
||||
{"type": "text", "text": "Result 1: "},
|
||||
{"type": "text", "text": "Result 2: "},
|
||||
{
|
||||
"type": "image",
|
||||
"url": "http://example.com/img.png",
|
||||
}, # Should be ignored
|
||||
{"type": "text", "text": "Result 3"},
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
messages = parse_chat_input_to_harmony_message(
|
||||
chat_msg, tool_id_names=tool_id_names
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert len(messages[0].content) == 2
|
||||
assert messages[0].content[0].text == ""
|
||||
assert messages[0].content[1].text == "actual text"
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "functions.search_results",
|
||||
"content": "Result 1: Result 2: Result 3",
|
||||
"channel": "commentary",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_tool_message_with_empty_content(self):
|
||||
tool_id_names = {
|
||||
"call_123": "empty_tool",
|
||||
}
|
||||
chat_msg = {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(
|
||||
chat_msg, tool_id_names=tool_id_names
|
||||
)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "functions.empty_tool",
|
||||
"content": "",
|
||||
"channel": "commentary",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_tool_message_with_none_content(self):
|
||||
tool_id_names = {
|
||||
"call_123": "empty_tool",
|
||||
}
|
||||
chat_msg = {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"content": None,
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(
|
||||
chat_msg, tool_id_names=tool_id_names
|
||||
)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "functions.empty_tool",
|
||||
"content": "",
|
||||
"channel": "commentary",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestAutoDropAnalysisMessages:
|
||||
def test_no_analysis_messages(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "The answer is 4."
|
||||
).with_channel("final"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
assert cleaned_messages == messages
|
||||
|
||||
def test_only_analysis_message(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking about the user's question."
|
||||
).with_channel("analysis"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
assert cleaned_messages == messages
|
||||
|
||||
def test_multiple_analysis_messages_without_final_message(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking about the user's question."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking more."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking even more."
|
||||
).with_channel("analysis"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
assert cleaned_messages == messages
|
||||
|
||||
def test_only_final_message(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "The answer is 4."
|
||||
).with_channel("final"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
assert cleaned_messages == messages
|
||||
|
||||
def test_drops_one_analysis_messages_before_final_message(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking about the user's question."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "The answer is 4."
|
||||
).with_channel("final"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I should think harder."
|
||||
).with_channel("analysis"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
# Should have dropped the first analysis message
|
||||
assert cleaned_messages == messages[1:]
|
||||
|
||||
def test_drops_all_analysis_messages_before_final_message(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking about the user's question."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking more."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking even more."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "The answer is 4."
|
||||
).with_channel("final"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I should think harder."
|
||||
).with_channel("analysis"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
# Should have dropped the first 3 analysis messages
|
||||
assert cleaned_messages == messages[3:]
|
||||
|
||||
def test_multiple_analysis_messages_with_multiple_final_messages(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking about the user's question."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking more."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking even more."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "The answer is 4."
|
||||
).with_channel("final"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I should think harder."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "The answer is 5."
|
||||
).with_channel("final"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
# Should have dropped all those analysis messages
|
||||
assert len(cleaned_messages) == 2
|
||||
assert cleaned_messages[0].content[0].text == "The answer is 4."
|
||||
assert cleaned_messages[1].content[0].text == "The answer is 5."
|
||||
|
||||
def test_drops_non_assistant_analysis_messages(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.TOOL, "The tool thinks we should think harder."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "The answer is 4."
|
||||
).with_channel("final"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
# Should have dropped the analysis message
|
||||
assert cleaned_messages == messages[1:]
|
||||
|
||||
|
||||
class TestParseChatOutput:
|
||||
def test_parse_chat_output_interrupted_first_message(self) -> None:
|
||||
harmony_str = "<|channel|>final<|message|>I'm in the middle of answering"
|
||||
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
reasoning, final_content, _ = parse_chat_output(token_ids)
|
||||
assert reasoning is None
|
||||
assert final_content == "I'm in the middle of answering"
|
||||
|
||||
def test_parse_chat_output_interrupted_reasoning_first_message(self) -> None:
|
||||
harmony_str = "<|channel|>analysis<|message|>I'm in the middle of thinking"
|
||||
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
reasoning, final_content, _ = parse_chat_output(token_ids)
|
||||
assert reasoning == "I'm in the middle of thinking"
|
||||
assert final_content is None
|
||||
|
||||
def test_parse_chat_output_complete_reasoning_interrupted_content(self) -> None:
|
||||
harmony_str = (
|
||||
"<|channel|>analysis<|message|>I'm thinking.<|end|>"
|
||||
"<|start|>assistant<|channel|>final"
|
||||
"<|message|>I'm in the middle of answering"
|
||||
)
|
||||
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
reasoning, final_content, _ = parse_chat_output(token_ids)
|
||||
assert reasoning == "I'm thinking."
|
||||
assert final_content == "I'm in the middle of answering"
|
||||
|
||||
def test_parse_chat_output_complete_content(self) -> None:
|
||||
harmony_str = "<|channel|>final<|message|>The answer is 4.<|end|>"
|
||||
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
reasoning, final_content, _ = parse_chat_output(token_ids)
|
||||
assert reasoning is None
|
||||
assert final_content == "The answer is 4."
|
||||
|
||||
def test_parse_chat_output_complete_commentary(self) -> None:
|
||||
harmony_str = (
|
||||
"<|channel|>commentary<|message|>I need to call some tools.<|end|>"
|
||||
)
|
||||
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
reasoning, final_content, _ = parse_chat_output(token_ids)
|
||||
assert reasoning is None
|
||||
assert final_content == "I need to call some tools."
|
||||
|
||||
def test_parse_chat_output_complete_reasoning(self) -> None:
|
||||
harmony_str = (
|
||||
"<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
|
||||
)
|
||||
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
reasoning, final_content, _ = parse_chat_output(token_ids)
|
||||
assert reasoning == "I've thought hard about this."
|
||||
assert final_content is None
|
||||
|
||||
def test_parse_chat_output_complete_reasoning_and_content(self) -> None:
|
||||
harmony_str = (
|
||||
"<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
|
||||
"<|start|>assistant<|channel|>final<|message|>The answer is 4.<|end|>"
|
||||
)
|
||||
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
reasoning, final_content, _ = parse_chat_output(token_ids)
|
||||
assert reasoning == "I've thought hard about this."
|
||||
assert final_content == "The answer is 4."
|
||||
|
||||
|
||||
class TestParseOutputMessage:
|
||||
|
||||
@ -79,9 +79,12 @@ async def test_anthropic_streaming(client: anthropic.AsyncAnthropic):
|
||||
|
||||
assert chunk_count > 0
|
||||
assert first_chunk is not None, "message_start chunk was never observed"
|
||||
assert first_chunk.usage is not None, "first chunk should include usage stats"
|
||||
assert first_chunk.usage["output_tokens"] == 0
|
||||
assert first_chunk.usage["input_tokens"] > 5
|
||||
assert first_chunk.message is not None, "first chunk should include message"
|
||||
assert first_chunk.message.usage is not None, (
|
||||
"first chunk should include usage stats"
|
||||
)
|
||||
assert first_chunk.message.usage.output_tokens == 0
|
||||
assert first_chunk.message.usage.input_tokens > 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -11,13 +11,25 @@ import pytest_asyncio
|
||||
from openai import OpenAI
|
||||
|
||||
from vllm.config.multimodal import MultiModalConfig
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
RequestResponseMetadata,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from .utils import (
|
||||
accumulate_streaming_response,
|
||||
verify_chat_response,
|
||||
verify_harmony_messages,
|
||||
)
|
||||
|
||||
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
|
||||
|
||||
@ -728,3 +740,635 @@ async def test_serving_chat_data_parallel_rank_extraction():
|
||||
# Verify that data_parallel_rank defaults to None
|
||||
assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
|
||||
assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None
|
||||
|
||||
|
||||
class TestServingChatWithHarmony:
|
||||
"""
|
||||
These tests ensure Chat Completion requests are being properly converted into
|
||||
Harmony messages and Harmony response messages back into Chat Completion responses.
|
||||
These tests are not exhaustive, but each one was created to cover a specific case
|
||||
that we got wrong but is now fixed.
|
||||
|
||||
Any changes to the tests and their expectations may result in changes to the
|
||||
accuracy of model prompting and responses generated. It is suggested to run
|
||||
an evaluation or benchmarking suite (such as bfcl multi_turn) to understand
|
||||
any impact of changes in how we prompt Harmony models.
|
||||
"""
|
||||
|
||||
@pytest.fixture(params=[False, True], ids=["non_streaming", "streaming"])
|
||||
def stream(self, request) -> bool:
|
||||
"""Parameterize tests to run in both non-streaming and streaming modes."""
|
||||
return request.param
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_engine(self) -> AsyncLLM:
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.input_processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
return mock_engine
|
||||
|
||||
@pytest.fixture()
|
||||
def serving_chat(self, mock_engine) -> OpenAIServingChat:
|
||||
chat = _build_serving_chat(mock_engine)
|
||||
chat.use_harmony = True
|
||||
chat.tool_parser = ToolParserManager.get_tool_parser("openai")
|
||||
return chat
|
||||
|
||||
def mock_request_output_from_req_and_token_ids(
|
||||
self, req: ChatCompletionRequest, token_ids: list[int], finished: bool = False
|
||||
) -> RequestOutput:
|
||||
# Our tests don't use most fields, so just get the token ids correct
|
||||
completion_output = CompletionOutput(
|
||||
index=0,
|
||||
text="",
|
||||
token_ids=token_ids,
|
||||
cumulative_logprob=0.0,
|
||||
logprobs=None,
|
||||
)
|
||||
return RequestOutput(
|
||||
request_id=req.request_id,
|
||||
prompt=[],
|
||||
prompt_token_ids=[],
|
||||
prompt_logprobs=None,
|
||||
outputs=[completion_output],
|
||||
finished=finished,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def weather_tools(self) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string"},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def weather_messages_start(self) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Paris today?",
|
||||
},
|
||||
]
|
||||
|
||||
async def generate_response_from_harmony_str(
|
||||
self,
|
||||
serving_chat: OpenAIServingChat,
|
||||
req: ChatCompletionRequest,
|
||||
harmony_str: str,
|
||||
stream: bool = False,
|
||||
) -> ChatCompletionResponse:
|
||||
harmony_token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
|
||||
async def result_generator():
|
||||
if stream:
|
||||
for token_id in harmony_token_ids:
|
||||
yield self.mock_request_output_from_req_and_token_ids(
|
||||
req, [token_id]
|
||||
)
|
||||
yield self.mock_request_output_from_req_and_token_ids(
|
||||
req, [], finished=True
|
||||
)
|
||||
else:
|
||||
yield self.mock_request_output_from_req_and_token_ids(
|
||||
req, harmony_token_ids, finished=True
|
||||
)
|
||||
|
||||
generator_func = (
|
||||
serving_chat.chat_completion_stream_generator
|
||||
if stream
|
||||
else serving_chat.chat_completion_full_generator
|
||||
)
|
||||
|
||||
result = generator_func(
|
||||
request=req,
|
||||
result_generator=result_generator(),
|
||||
request_id=req.request_id,
|
||||
model_name=req.model,
|
||||
conversation=[],
|
||||
tokenizer=get_tokenizer(req.model),
|
||||
request_metadata=RequestResponseMetadata(
|
||||
request_id=req.request_id,
|
||||
model_name=req.model,
|
||||
),
|
||||
)
|
||||
|
||||
if stream:
|
||||
return await accumulate_streaming_response(result)
|
||||
return await result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_chat(self, serving_chat, stream):
|
||||
messages = [{"role": "user", "content": "what is 1+1?"}]
|
||||
|
||||
# Test the Harmony messages for the first turn's input
|
||||
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
|
||||
verify_harmony_messages(
|
||||
input_messages,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user", "content": messages[0]["content"]},
|
||||
],
|
||||
)
|
||||
|
||||
# Test the Chat Completion response for the first turn's output
|
||||
reasoning_str = "We need to think really hard about this."
|
||||
final_str = "The answer is 2."
|
||||
response_str = (
|
||||
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
|
||||
f"<|start|>assistant<|channel|>final<|message|>{final_str}<|end|>"
|
||||
)
|
||||
response = await self.generate_response_from_harmony_str(
|
||||
serving_chat, req, response_str, stream=stream
|
||||
)
|
||||
verify_chat_response(response, content=final_str, reasoning=reasoning_str)
|
||||
|
||||
# Add the output messages from the first turn as input to the second turn
|
||||
for choice in response.choices:
|
||||
messages.append(choice.message.model_dump(exclude_none=True))
|
||||
|
||||
# Test the Harmony messages for the second turn's input
|
||||
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
|
||||
verify_harmony_messages(
|
||||
input_messages_2,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user"},
|
||||
# The analysis message should be dropped on subsequent inputs because
|
||||
# of the subsequent assistant message to the final channel.
|
||||
{"role": "assistant", "channel": "final", "content": final_str},
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_response_with_content(
|
||||
self, serving_chat, stream, weather_tools, weather_messages_start
|
||||
):
|
||||
tools = weather_tools
|
||||
messages = list(weather_messages_start)
|
||||
|
||||
# Test the Harmony messages for the first turn's input
|
||||
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
|
||||
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
|
||||
verify_harmony_messages(
|
||||
input_messages,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer", "tool_definitions": ["get_weather"]},
|
||||
{"role": "user", "content": messages[0]["content"]},
|
||||
],
|
||||
)
|
||||
|
||||
# Test the Chat Completion response for the first turn's output
|
||||
commentary_str = "We'll call get_weather."
|
||||
tool_args_str = '{"location": "Paris"}'
|
||||
response_str = (
|
||||
f"<|channel|>commentary<|message|>{commentary_str}<|end|>"
|
||||
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
|
||||
f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
|
||||
)
|
||||
response = await self.generate_response_from_harmony_str(
|
||||
serving_chat, req, response_str, stream=stream
|
||||
)
|
||||
verify_chat_response(
|
||||
response,
|
||||
content=commentary_str,
|
||||
tool_calls=[("get_weather", tool_args_str)],
|
||||
)
|
||||
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
|
||||
# Add the output messages from the first turn as input to the second turn
|
||||
for choice in response.choices:
|
||||
messages.append(choice.message.model_dump(exclude_none=True))
|
||||
|
||||
# Add our tool output message
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": "20 degrees Celsius",
|
||||
},
|
||||
)
|
||||
|
||||
# Test the Harmony messages for the second turn's input
|
||||
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
|
||||
verify_harmony_messages(
|
||||
input_messages_2,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"content": commentary_str,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": tool_args_str,
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"author_name": "functions.get_weather",
|
||||
"channel": "commentary",
|
||||
"recipient": "assistant",
|
||||
"content": "20 degrees Celsius",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tools_and_reasoning(
|
||||
self, serving_chat, stream, weather_tools, weather_messages_start
|
||||
):
|
||||
tools = weather_tools
|
||||
messages = list(weather_messages_start)
|
||||
|
||||
# Test the Harmony messages for the first turn's input
|
||||
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
|
||||
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
|
||||
verify_harmony_messages(
|
||||
input_messages,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer", "tool_definitions": ["get_weather"]},
|
||||
{"role": "user", "content": messages[0]["content"]},
|
||||
],
|
||||
)
|
||||
|
||||
# Test the Chat Completion response for the first turn's output
|
||||
reasoning_str = "I'll call get_weather."
|
||||
tool_args_str = '{"location": "Paris"}'
|
||||
response_str = (
|
||||
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
|
||||
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
|
||||
f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
|
||||
)
|
||||
response = await self.generate_response_from_harmony_str(
|
||||
serving_chat, req, response_str, stream=stream
|
||||
)
|
||||
verify_chat_response(
|
||||
response,
|
||||
reasoning=reasoning_str,
|
||||
tool_calls=[("get_weather", tool_args_str)],
|
||||
)
|
||||
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
|
||||
# Add the output messages from the first turn as input to the second turn
|
||||
for choice in response.choices:
|
||||
messages.append(choice.message.model_dump(exclude_none=True))
|
||||
|
||||
# Add our tool output message
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": "20 degrees Celsius",
|
||||
},
|
||||
)
|
||||
|
||||
# Test the Harmony messages for the second turn's input
|
||||
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
|
||||
verify_harmony_messages(
|
||||
input_messages_2,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": reasoning_str,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": tool_args_str,
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"author_name": "functions.get_weather",
|
||||
"channel": "commentary",
|
||||
"recipient": "assistant",
|
||||
"content": "20 degrees Celsius",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_turn_tools_and_reasoning(
|
||||
self, serving_chat, stream, weather_tools, weather_messages_start
|
||||
):
|
||||
tools = weather_tools
|
||||
messages = list(weather_messages_start)
|
||||
|
||||
# Test the Harmony messages for the first turn's input
|
||||
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
|
||||
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
|
||||
verify_harmony_messages(
|
||||
input_messages,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer", "tool_definitions": ["get_weather"]},
|
||||
{"role": "user", "content": messages[0]["content"]},
|
||||
],
|
||||
)
|
||||
|
||||
# Test the Chat Completion response for the first turn's output
|
||||
reasoning_str = "I'll call get_weather."
|
||||
paris_tool_args_str = '{"location": "Paris"}'
|
||||
response_str = (
|
||||
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
|
||||
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
|
||||
f"<|constrain|>json<|message|>{paris_tool_args_str}<|call|>"
|
||||
)
|
||||
response = await self.generate_response_from_harmony_str(
|
||||
serving_chat, req, response_str, stream=stream
|
||||
)
|
||||
verify_chat_response(
|
||||
response,
|
||||
reasoning=reasoning_str,
|
||||
tool_calls=[("get_weather", paris_tool_args_str)],
|
||||
)
|
||||
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
|
||||
# Add the output messages from the first turn as input to the second turn
|
||||
for choice in response.choices:
|
||||
messages.append(choice.message.model_dump(exclude_none=True))
|
||||
|
||||
# Add our tool output message
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": "20 degrees Celsius",
|
||||
},
|
||||
)
|
||||
|
||||
# Test the Harmony messages for the second turn's input
|
||||
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
|
||||
verify_harmony_messages(
|
||||
input_messages_2,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": reasoning_str,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": paris_tool_args_str,
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"author_name": "functions.get_weather",
|
||||
"channel": "commentary",
|
||||
"recipient": "assistant",
|
||||
"content": "20 degrees Celsius",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Test the Chat Completion response for the second turn's output
|
||||
paris_weather_str = "The weather in Paris today is 20 degrees Celsius."
|
||||
response_str = f"<|channel|>final<|message|>{paris_weather_str}<|end|>"
|
||||
response_2 = await self.generate_response_from_harmony_str(
|
||||
serving_chat, req_2, response_str, stream=stream
|
||||
)
|
||||
verify_chat_response(response_2, content=paris_weather_str)
|
||||
|
||||
# Add the output messages from the second turn as input to the third turn
|
||||
for choice in response_2.choices:
|
||||
messages.append(choice.message.model_dump(exclude_none=True))
|
||||
|
||||
# Add a new user message for the third turn
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Boston today?",
|
||||
},
|
||||
)
|
||||
|
||||
# Test the Harmony messages for the third turn's input
|
||||
req_3 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages_3, _, _ = serving_chat._make_request_with_harmony(req_3)
|
||||
verify_harmony_messages(
|
||||
input_messages_3,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": paris_tool_args_str,
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"author_name": "functions.get_weather",
|
||||
"channel": "commentary",
|
||||
"recipient": "assistant",
|
||||
"content": "20 degrees Celsius",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "final",
|
||||
"content": paris_weather_str,
|
||||
},
|
||||
{"role": "user", "content": messages[-1]["content"]},
|
||||
],
|
||||
)
|
||||
|
||||
# Test the Chat Completion response for the third turn's output
|
||||
reasoning_str = "I'll call get_weather."
|
||||
boston_tool_args_str = '{"location": "Boston"}'
|
||||
response_str = (
|
||||
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
|
||||
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
|
||||
f"<|constrain|>json<|message|>{boston_tool_args_str}<|call|>"
|
||||
)
|
||||
response_3 = await self.generate_response_from_harmony_str(
|
||||
serving_chat, req, response_str, stream=stream
|
||||
)
|
||||
verify_chat_response(
|
||||
response_3,
|
||||
reasoning=reasoning_str,
|
||||
tool_calls=[("get_weather", boston_tool_args_str)],
|
||||
)
|
||||
|
||||
tool_call = response_3.choices[0].message.tool_calls[0]
|
||||
|
||||
# Add the output messages from the third turn as input to the fourth turn
|
||||
for choice in response_3.choices:
|
||||
messages.append(choice.message.model_dump(exclude_none=True))
|
||||
|
||||
# Add our tool output message
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": "10 degrees Celsius",
|
||||
},
|
||||
)
|
||||
|
||||
# Test the Harmony messages for the fourth turn's input
|
||||
req_4 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages_4, _, _ = serving_chat._make_request_with_harmony(req_4)
|
||||
verify_harmony_messages(
|
||||
input_messages_4,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user"},
|
||||
{"role": "assistant"},
|
||||
{"role": "tool"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "final",
|
||||
},
|
||||
{"role": "user"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": reasoning_str,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": boston_tool_args_str,
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"author_name": "functions.get_weather",
|
||||
"channel": "commentary",
|
||||
"recipient": "assistant",
|
||||
"content": "10 degrees Celsius",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_tool_reasoning(self, serving_chat):
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's 2+2?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
|
||||
"content": "4",
|
||||
},
|
||||
]
|
||||
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
|
||||
|
||||
verify_harmony_messages(
|
||||
input_messages,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user", "content": messages[0]["content"]},
|
||||
# The reasoning that would have resulted in an analysis message is
|
||||
# dropped because of a later assistant message to the final channel.
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "final",
|
||||
"content": messages[1]["content"],
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_tool_reasoning_empty_content(self, serving_chat):
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's 2+2?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
|
||||
"content": "",
|
||||
},
|
||||
]
|
||||
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
|
||||
|
||||
verify_harmony_messages(
|
||||
input_messages,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user", "content": messages[0]["content"]},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": messages[1]["reasoning"],
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_tool_reasoning_empty_content_list(self, serving_chat):
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's 2+2?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
|
||||
"content": [],
|
||||
},
|
||||
]
|
||||
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
|
||||
|
||||
verify_harmony_messages(
|
||||
input_messages,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user", "content": messages[0]["content"]},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": messages[1]["reasoning"],
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@ -10,7 +10,7 @@ import pytest
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
190
tests/entrypoints/openai/utils.py
Normal file
190
tests/entrypoints/openai/utils.py
Normal file
@ -0,0 +1,190 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatMessage,
|
||||
UsageInfo,
|
||||
)
|
||||
|
||||
|
||||
async def accumulate_streaming_response(
|
||||
stream_generator: AsyncGenerator[str, None],
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Accumulate streaming SSE chunks into a complete ChatCompletionResponse.
|
||||
|
||||
This helper parses the SSE format and builds up the complete response
|
||||
by combining all the delta chunks.
|
||||
"""
|
||||
accumulated_content = ""
|
||||
accumulated_reasoning = None
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
role = None
|
||||
finish_reason = None
|
||||
response_id = None
|
||||
created = None
|
||||
model = None
|
||||
index = 0
|
||||
|
||||
async for chunk_str in stream_generator:
|
||||
# Skip empty lines and [DONE] marker
|
||||
if not chunk_str.strip() or chunk_str.strip() == "data: [DONE]":
|
||||
continue
|
||||
|
||||
# Parse SSE format: "data: {json}\n\n"
|
||||
if chunk_str.startswith("data: "):
|
||||
json_str = chunk_str[6:].strip()
|
||||
try:
|
||||
chunk_data = json.loads(json_str)
|
||||
# print(f"DEBUG: Parsed chunk_data: {chunk_data}")
|
||||
chunk = ChatCompletionStreamResponse(**chunk_data)
|
||||
|
||||
# Store metadata from first chunk
|
||||
if response_id is None:
|
||||
response_id = chunk.id
|
||||
created = chunk.created
|
||||
model = chunk.model
|
||||
|
||||
# Process each choice in the chunk
|
||||
for choice in chunk.choices:
|
||||
if choice.delta.role:
|
||||
role = choice.delta.role
|
||||
if choice.delta.content:
|
||||
accumulated_content += choice.delta.content
|
||||
if choice.delta.reasoning:
|
||||
if accumulated_reasoning is None:
|
||||
accumulated_reasoning = ""
|
||||
accumulated_reasoning += choice.delta.reasoning
|
||||
if choice.delta.tool_calls:
|
||||
# Accumulate tool calls
|
||||
for tool_call_delta in choice.delta.tool_calls:
|
||||
# Find or create the tool call at this index
|
||||
while len(accumulated_tool_calls) <= tool_call_delta.index:
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
)
|
||||
|
||||
if tool_call_delta.id:
|
||||
accumulated_tool_calls[tool_call_delta.index]["id"] = (
|
||||
tool_call_delta.id
|
||||
)
|
||||
if tool_call_delta.function:
|
||||
if tool_call_delta.function.name:
|
||||
accumulated_tool_calls[tool_call_delta.index][
|
||||
"function"
|
||||
]["name"] += tool_call_delta.function.name
|
||||
if tool_call_delta.function.arguments:
|
||||
accumulated_tool_calls[tool_call_delta.index][
|
||||
"function"
|
||||
]["arguments"] += tool_call_delta.function.arguments
|
||||
|
||||
if choice.finish_reason:
|
||||
finish_reason = choice.finish_reason
|
||||
if choice.index is not None:
|
||||
index = choice.index
|
||||
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Build the final message
|
||||
message_kwargs = {
|
||||
"role": role or "assistant",
|
||||
"content": accumulated_content if accumulated_content else None,
|
||||
"reasoning": accumulated_reasoning,
|
||||
}
|
||||
|
||||
# Only include tool_calls if there are any
|
||||
if accumulated_tool_calls:
|
||||
message_kwargs["tool_calls"] = [
|
||||
{"id": tc["id"], "type": tc["type"], "function": tc["function"]}
|
||||
for tc in accumulated_tool_calls
|
||||
]
|
||||
|
||||
message = ChatMessage(**message_kwargs)
|
||||
|
||||
# Build the final response
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=index,
|
||||
message=message,
|
||||
finish_reason=finish_reason or "stop",
|
||||
)
|
||||
|
||||
# Create usage info (with dummy values for tests)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
)
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
id=response_id or "chatcmpl-test",
|
||||
object="chat.completion",
|
||||
created=created or 0,
|
||||
model=model or "test-model",
|
||||
choices=[choice],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def verify_harmony_messages(
|
||||
messages: list[Any], expected_messages: list[dict[str, Any]]
|
||||
):
|
||||
assert len(messages) == len(expected_messages)
|
||||
for msg, expected in zip(messages, expected_messages):
|
||||
if "role" in expected:
|
||||
assert msg.author.role == expected["role"]
|
||||
if "author_name" in expected:
|
||||
assert msg.author.name == expected["author_name"]
|
||||
if "channel" in expected:
|
||||
assert msg.channel == expected["channel"]
|
||||
if "recipient" in expected:
|
||||
assert msg.recipient == expected["recipient"]
|
||||
if "content" in expected:
|
||||
assert msg.content[0].text == expected["content"]
|
||||
if "content_type" in expected:
|
||||
assert msg.content_type == expected["content_type"]
|
||||
if "tool_definitions" in expected:
|
||||
# Check that the tool definitions match the expected list of tool names
|
||||
actual_tools = [t.name for t in msg.content[0].tools["functions"].tools]
|
||||
assert actual_tools == expected["tool_definitions"]
|
||||
|
||||
|
||||
def verify_chat_response(
|
||||
response: ChatCompletionResponse,
|
||||
content: str | None = None,
|
||||
reasoning: str | None = None,
|
||||
tool_calls: list[tuple[str, str]] | None = None,
|
||||
):
|
||||
assert len(response.choices) == 1
|
||||
message = response.choices[0].message
|
||||
|
||||
if content is not None:
|
||||
assert message.content == content
|
||||
else:
|
||||
assert not message.content
|
||||
|
||||
if reasoning is not None:
|
||||
assert message.reasoning == reasoning
|
||||
else:
|
||||
assert not message.reasoning
|
||||
|
||||
if tool_calls:
|
||||
assert message.tool_calls is not None
|
||||
assert len(message.tool_calls) == len(tool_calls)
|
||||
for tc, (expected_name, expected_args) in zip(message.tool_calls, tool_calls):
|
||||
assert tc.function.name == expected_name
|
||||
assert tc.function.arguments == expected_args
|
||||
else:
|
||||
assert not message.tool_calls
|
||||
@ -29,7 +29,8 @@ from vllm.multimodal.utils import (
|
||||
encode_image_base64,
|
||||
encode_video_base64,
|
||||
)
|
||||
from vllm.tokenizers import MistralTokenizer, get_tokenizer
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils.serial_utils import tensor2base64
|
||||
|
||||
from ..models.registry import HF_EXAMPLE_MODELS
|
||||
@ -796,9 +797,13 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
|
||||
"content": "<|image_1|>\nWhat's in this image?",
|
||||
}
|
||||
]
|
||||
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert mm_data["image"] is None
|
||||
assert isinstance(mm_data["image"], list)
|
||||
assert len(mm_data["image"]) == 1
|
||||
assert mm_data["image"][0] is None
|
||||
|
||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
|
||||
|
||||
|
||||
@ -825,10 +830,11 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
|
||||
# Should have audio in mm_data as None (UUID provided)
|
||||
assert mm_data is not None
|
||||
assert "audio" in mm_data
|
||||
assert mm_data["audio"] is None
|
||||
assert isinstance(mm_data["audio"], list)
|
||||
assert len(mm_data["audio"]) == 1
|
||||
assert mm_data["audio"][0] is None
|
||||
|
||||
# UUID should be recorded
|
||||
assert mm_uuids is not None
|
||||
assert "audio" in mm_uuids
|
||||
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[uuid])
|
||||
|
||||
|
||||
@ -1121,10 +1127,105 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
|
||||
mm_data = await mm_future
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert mm_data["image"] is None
|
||||
assert isinstance(mm_data["image"], list)
|
||||
assert len(mm_data["image"]) == 1
|
||||
assert mm_data["image"][0] is None
|
||||
|
||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
|
||||
|
||||
|
||||
def test_parse_chat_messages_empty_dict_image_embeds(
|
||||
phi3v_model_config_image_embeds,
|
||||
):
|
||||
"""Test that empty dictionary for image_embeds is handled without errors."""
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_embeds", "image_embeds": {}},
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
],
|
||||
}
|
||||
],
|
||||
phi3v_model_config_image_embeds,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
# Verify conversation structure
|
||||
assert conversation == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<|image_1|>\nWhat's in this image?",
|
||||
}
|
||||
]
|
||||
|
||||
# Verify mm_data contains an empty dictionary of embeddings
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert isinstance(mm_data["image"], dict)
|
||||
assert len(mm_data["image"]) == 0
|
||||
|
||||
# Verify UUIDs (None since we didn't provide any)
|
||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[None])
|
||||
|
||||
|
||||
def test_parse_chat_messages_multiple_dict_image_embeds(
|
||||
phi3v_model_config_image_embeds,
|
||||
):
|
||||
"""Test that multiple dictionaries for image_embeds is handled without errors."""
|
||||
# Create two sample image embedding tensors
|
||||
batch_size = 2
|
||||
image_embedding_1 = torch.randn(batch_size, 256, 1024)
|
||||
image_embedding_2 = torch.randn(batch_size, 3)
|
||||
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_embeds",
|
||||
"image_embeds": {
|
||||
"image_embedding_1": tensor2base64(p),
|
||||
"image_embedding_2": tensor2base64(i),
|
||||
},
|
||||
}
|
||||
for p, i in zip(image_embedding_1, image_embedding_2)
|
||||
]
|
||||
+ [
|
||||
{"type": "text", "text": "Describe these two images."},
|
||||
],
|
||||
}
|
||||
],
|
||||
phi3v_model_config_image_embeds,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
# Verify conversation structure
|
||||
assert conversation == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
|
||||
}
|
||||
]
|
||||
|
||||
# Verify mm_data contains a dictionary of multi-embeddings
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert isinstance(mm_data["image"], dict)
|
||||
assert len(mm_data["image"]) == batch_size
|
||||
|
||||
# Verify each embedding has the correct shape
|
||||
assert isinstance(mm_data["image"]["image_embedding_1"], torch.Tensor)
|
||||
assert mm_data["image"]["image_embedding_1"].shape == image_embedding_1.shape
|
||||
assert isinstance(mm_data["image"]["image_embedding_2"], torch.Tensor)
|
||||
assert mm_data["image"]["image_embedding_2"].shape == image_embedding_2.shape
|
||||
|
||||
# Verify UUIDs (None since we didn't provide any)
|
||||
_assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None, None])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_multiple_images_async(
|
||||
phi3v_model_config,
|
||||
|
||||
@ -32,8 +32,8 @@ def cal_diff(
|
||||
|
||||
|
||||
CUTLASS_MLA_UNSUPPORTED_REASON = (
|
||||
"Cutlass MLA Requires compute capability of 10 or above."
|
||||
if not current_platform.is_device_capability(100)
|
||||
"Cutlass MLA Requires compute capability of 100 or above."
|
||||
if not current_platform.is_device_capability_family(100)
|
||||
else "Cutlass MLA is supported"
|
||||
)
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ from tests.kernels.quantization.nvfp4_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
if not current_platform.is_device_capability(100):
|
||||
if not current_platform.is_device_capability_family(100):
|
||||
pytest.skip(
|
||||
"This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
|
||||
)
|
||||
@ -443,7 +443,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
||||
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||
rtol, atol = 1e-1, 2e-1
|
||||
rtol, atol = 3e-1, 4e-1
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||
rtol, atol = 4e-2, 6e-2
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
|
||||
|
||||
@ -7,6 +7,7 @@ import torch
|
||||
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import next_power_of_2
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
@ -22,6 +23,10 @@ QDTYPES = (
|
||||
# one value small enough to test the schema op check
|
||||
NUM_BLOCKS = [32768, 2048]
|
||||
|
||||
# 0: use 2D kernel for decode
|
||||
# 8: use 3D kernel for decode
|
||||
SEQ_THRESHOLD_3D_VALUES = [0, 8]
|
||||
|
||||
|
||||
def ref_paged_attn(
|
||||
query: torch.Tensor,
|
||||
@ -92,6 +97,7 @@ def ref_paged_attn(
|
||||
@pytest.mark.parametrize("soft_cap", [None, 50.0])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
||||
@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES)
|
||||
@torch.inference_mode()
|
||||
def test_triton_unified_attn(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
@ -103,6 +109,7 @@ def test_triton_unified_attn(
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
q_dtype: torch.dtype | None,
|
||||
seq_threshold_3D: int,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
@ -152,6 +159,21 @@ def test_triton_unified_attn(
|
||||
k_descale = torch.rand(scale_shape, dtype=torch.float32)
|
||||
v_descale = torch.rand(scale_shape, dtype=torch.float32)
|
||||
|
||||
num_par_softmax_segments = 16
|
||||
head_size_padded = next_power_of_2(head_size)
|
||||
softmax_segm_output = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
softmax_segm_max = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
softmax_segm_expsum = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
unified_attention(
|
||||
q=maybe_quantized_query,
|
||||
k=maybe_quantized_key_cache,
|
||||
@ -169,6 +191,11 @@ def test_triton_unified_attn(
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
seq_threshold_3D=seq_threshold_3D,
|
||||
num_par_softmax_segments=num_par_softmax_segments,
|
||||
softmax_segm_output=softmax_segm_output,
|
||||
softmax_segm_max=softmax_segm_max,
|
||||
softmax_segm_expsum=softmax_segm_expsum,
|
||||
)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
|
||||
@ -116,7 +116,6 @@ def test_mrope(
|
||||
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=config.rope_parameters,
|
||||
@ -185,7 +184,6 @@ def test_mrope_torch_compile_tracing(
|
||||
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=config.rope_parameters,
|
||||
|
||||
@ -83,8 +83,12 @@ def test_rotary_embedding(
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
|
||||
rope = get_rope(head_size, rotary_dim, max_position, is_neox_style, rope_parameters)
|
||||
rope_parameters = {
|
||||
"rope_type": "default",
|
||||
"rope_theta": rope_theta,
|
||||
"partial_rotary_factor": rotary_dim / head_size,
|
||||
}
|
||||
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
|
||||
rope = rope.to(dtype=dtype, device=torch.get_default_device())
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
@ -150,9 +154,9 @@ def test_rope_module_cache():
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters["rope_theta"] = rope_theta
|
||||
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
|
||||
rope = get_rope(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
@ -177,9 +181,9 @@ def test_rope_module_cache():
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters["rope_theta"] = rope_theta
|
||||
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
|
||||
rope = get_rope(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
|
||||
@ -27,7 +27,7 @@ BLOCK_SIZE = [128, 128]
|
||||
@pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert
|
||||
@pytest.mark.parametrize("topk", [2, 4])
|
||||
def test_batched_deepgemm_vs_triton(
|
||||
E: int, T: int, K: int, N: int, topk: int, monkeypatch
|
||||
E: int, T: int, K: int, N: int, topk: int, monkeypatch, workspace_init
|
||||
):
|
||||
"""Compare BatchedDeepGemmExperts to BatchedTritonExperts."""
|
||||
|
||||
|
||||
@ -248,6 +248,7 @@ def test_fused_moe_batched_experts(
|
||||
per_act_token_quant: bool,
|
||||
block_shape: list[int] | None,
|
||||
input_scales: bool,
|
||||
workspace_init,
|
||||
):
|
||||
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
|
||||
and those tests will be skipped on unsupported hardware."""
|
||||
|
||||
@ -137,7 +137,7 @@ def setup_cuda():
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_fused_moe(
|
||||
M, N, K, E, topk, block_size, dtype, seed, monkeypatch
|
||||
M, N, K, E, topk, block_size, dtype, seed, monkeypatch, workspace_init
|
||||
):
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test; topk={topk} > E={E}")
|
||||
|
||||
@ -274,6 +274,7 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
ep_size: int | None = None,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
@ -329,6 +330,7 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
@ -385,9 +387,19 @@ def test_cutlass_moe_8_bit_EP(
|
||||
per_out_channel: bool,
|
||||
ep_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
test_cutlass_moe_8_bit_no_graph(
|
||||
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
e,
|
||||
topk,
|
||||
per_act_token,
|
||||
per_out_channel,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
ep_size,
|
||||
)
|
||||
|
||||
|
||||
@ -419,9 +431,19 @@ def test_cutlass_moe_8_bit_EP_large(
|
||||
per_out_channel: bool,
|
||||
ep_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
test_cutlass_moe_8_bit_no_graph(
|
||||
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
e,
|
||||
topk,
|
||||
per_act_token,
|
||||
per_out_channel,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
ep_size,
|
||||
)
|
||||
|
||||
|
||||
@ -445,6 +467,7 @@ def test_run_cutlass_moe_fp8(
|
||||
per_act_token: bool,
|
||||
per_out_channel: bool,
|
||||
ep_size: int,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
|
||||
@ -29,6 +29,7 @@ from vllm.utils.deep_gemm import (
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
@ -363,6 +364,9 @@ def _test_deepep_deepgemm_moe(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
):
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
w1 = w1.to(device=torch.cuda.current_device())
|
||||
@ -445,6 +449,7 @@ def test_ht_deepep_deepgemm_moe(
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
disable_deepgemm_ue8m0,
|
||||
workspace_init,
|
||||
):
|
||||
"""
|
||||
Tests for High-Throughput DeepEP + DeepGemm integration.
|
||||
@ -518,6 +523,7 @@ def test_ll_deepep_deepgemm_moe(
|
||||
block_size: list[int],
|
||||
world_dp_size: tuple[int, int],
|
||||
disable_deepgemm_ue8m0,
|
||||
workspace_init,
|
||||
):
|
||||
"""
|
||||
Tests for Low-Latency DeepEP + DeepGemm integration.
|
||||
|
||||
@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_deep_ep
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
@ -342,6 +343,9 @@ def _deep_ep_moe(
|
||||
use_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
if not low_latency_mode:
|
||||
assert not use_fp8_dispatch, (
|
||||
"FP8 dispatch interface is available only in low-latency mode"
|
||||
@ -437,6 +441,7 @@ def test_deep_ep_moe(
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
per_act_token_quant: bool,
|
||||
workspace_init,
|
||||
):
|
||||
low_latency_mode = False
|
||||
use_fp8_dispatch = False
|
||||
@ -492,6 +497,7 @@ def test_low_latency_deep_ep_moe(
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
use_fp8_dispatch: bool,
|
||||
workspace_init,
|
||||
):
|
||||
low_latency_mode = True
|
||||
|
||||
|
||||
@ -143,7 +143,7 @@ NUM_EXPERTS = [32]
|
||||
@pytest.mark.parametrize("topk", TOPKS)
|
||||
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
||||
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
|
||||
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
|
||||
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_init):
|
||||
with monkeypatch.context() as mp:
|
||||
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
|
||||
|
||||
@ -206,6 +206,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
topk: int,
|
||||
activation: str,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
|
||||
@ -51,7 +51,14 @@ MNK_FACTORS = [
|
||||
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_fp4_moe_no_graph(
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, activation: str
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
activation: str,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
|
||||
@ -269,7 +269,7 @@ class Case:
|
||||
)
|
||||
@pytest.mark.parametrize("num_token", [2])
|
||||
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
|
||||
def test_equiv(num_token, a_dtype, w_dtype, tp):
|
||||
def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
|
||||
from triton_kernels.tensor_details import layout
|
||||
|
||||
if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"):
|
||||
|
||||
@ -16,6 +16,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from .modular_kernel_tools.common import (
|
||||
Config,
|
||||
@ -77,6 +78,10 @@ def rank_worker(
|
||||
weights: WeightTensors,
|
||||
verbose: bool,
|
||||
):
|
||||
# Initialize workspace manager in child process
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
@ -300,6 +305,7 @@ def test_modular_kernel_combinations_singlegpu(
|
||||
chunk_size: int | None,
|
||||
world_size: int,
|
||||
pytestconfig,
|
||||
workspace_init,
|
||||
):
|
||||
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
|
||||
and those tests will be skipped on unsupported hardware."""
|
||||
|
||||
@ -209,6 +209,7 @@ def test_oai_triton_moe(
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
unfused: bool,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
(
|
||||
|
||||
@ -231,6 +231,7 @@ def test_fused_moe(
|
||||
padding: bool,
|
||||
chunk_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
|
||||
@ -40,7 +40,7 @@ MNK_FACTORS = [
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@torch.inference_mode()
|
||||
def test_cutlass_fp4_moe_no_graph(
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, workspace_init
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
|
||||
@ -17,7 +17,7 @@ QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
|
||||
) >= version.parse("0.8.99")
|
||||
|
||||
TRTLLM_GEN_MXFP4_AVAILABLE = (
|
||||
current_platform.is_cuda() and current_platform.is_device_capability(100)
|
||||
current_platform.is_cuda() and current_platform.is_device_capability_family(100)
|
||||
)
|
||||
|
||||
HOPPER_MXFP4_BF16_AVAILABLE = (
|
||||
@ -799,7 +799,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
|
||||
@pytest.mark.skipif(
|
||||
not (
|
||||
current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100)
|
||||
and current_platform.is_device_capability_family(100)
|
||||
and has_flashinfer()
|
||||
),
|
||||
reason="NVIDIA GPU sm100 and flashinfer are required for this test",
|
||||
|
||||
@ -46,6 +46,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
@ -181,6 +182,7 @@ def test_fused_moe_batched_experts(
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
@ -863,6 +865,9 @@ def _pplx_test_loop(
|
||||
make_weights: bool,
|
||||
test_fn: Callable,
|
||||
):
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
def format_result(msg, ex=None):
|
||||
if ex is not None:
|
||||
x = str(ex)
|
||||
|
||||
@ -30,16 +30,11 @@ def ref_dynamic_per_token_quant(
|
||||
if quant_dtype == torch.int8
|
||||
else torch.finfo(quant_dtype)
|
||||
)
|
||||
qtype_traits_max = (
|
||||
ROCM_FP8FNUZ_MAX
|
||||
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
|
||||
else qtype_traits.max
|
||||
)
|
||||
qtype_traits_min = (
|
||||
-ROCM_FP8FNUZ_MAX
|
||||
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
|
||||
else qtype_traits.min
|
||||
use_fp8fnuz = (
|
||||
current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype()
|
||||
)
|
||||
qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max
|
||||
qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min
|
||||
qtype_max = as_float32_tensor(qtype_traits_max)
|
||||
s_1 = as_float32_tensor(1.0)
|
||||
s_512 = as_float32_tensor(512.0)
|
||||
|
||||
@ -18,7 +18,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
|
||||
IS_SUPPORTED_BY_GPU = (
|
||||
current_platform.is_cuda() and current_platform.get_device_capability()[0] >= 9
|
||||
)
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -62,7 +62,7 @@ def test_quantfp8_group_functionality(
|
||||
assert scales_col.stride(1) == batch_size
|
||||
|
||||
# Test column-major scales consistency
|
||||
assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
|
||||
torch.testing.assert_close(scales_col, scales_native, rtol=1e-9, atol=1e-8)
|
||||
|
||||
# 3. Test CUDA implementation (only for divisible dimensions)
|
||||
if is_divisible:
|
||||
@ -71,7 +71,7 @@ def test_quantfp8_group_functionality(
|
||||
assert scales_cuda.shape == (batch_size, expected_num_groups)
|
||||
|
||||
# Verify CUDA/native consistency
|
||||
assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8)
|
||||
torch.testing.assert_close(scales_cuda, scales_native, rtol=2e-7, atol=2e-8)
|
||||
|
||||
# Quantized values should mostly match
|
||||
diff_count = (x_quant_cuda != x_quant_native).sum().item()
|
||||
|
||||
@ -0,0 +1,91 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for ScaledMM kernel selection logic (CPU-only)
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_scaled_mm_kernel_selection.py`.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from abc import ABC
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
|
||||
AiterScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
|
||||
CPUScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearKernel,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def test_is_supported_is_abstract():
|
||||
"""Test that is_supported() is properly defined as abstract."""
|
||||
assert issubclass(ScaledMMLinearKernel, ABC)
|
||||
assert hasattr(ScaledMMLinearKernel, "is_supported")
|
||||
|
||||
|
||||
def test_cpu_kernel_implements_is_supported():
|
||||
"""Test that CPUScaledMMLinearKernel implements is_supported() method."""
|
||||
assert hasattr(CPUScaledMMLinearKernel, "is_supported"), (
|
||||
"CPUScaledMMLinearKernel missing is_supported() method"
|
||||
)
|
||||
# Verify it's a classmethod by checking if it can be called with the class
|
||||
# and by checking the method type
|
||||
assert inspect.ismethod(CPUScaledMMLinearKernel.is_supported) or inspect.isfunction(
|
||||
CPUScaledMMLinearKernel.is_supported
|
||||
), "CPUScaledMMLinearKernel.is_supported() should be a classmethod"
|
||||
# Verify it can be called as a classmethod
|
||||
result, reason = CPUScaledMMLinearKernel.is_supported()
|
||||
assert isinstance(result, bool), "is_supported() should return a bool"
|
||||
assert reason is None or isinstance(reason, str), "reason should be str or None"
|
||||
|
||||
|
||||
def test_aiter_kernel_implements_is_supported():
|
||||
"""Test that AiterScaledMMLinearKernel implements is_supported() method."""
|
||||
assert hasattr(AiterScaledMMLinearKernel, "is_supported"), (
|
||||
"AiterScaledMMLinearKernel missing is_supported() method"
|
||||
)
|
||||
# Verify it's a classmethod by checking if it can be called with the class
|
||||
# and by checking the method type
|
||||
assert inspect.ismethod(
|
||||
AiterScaledMMLinearKernel.is_supported
|
||||
) or inspect.isfunction(AiterScaledMMLinearKernel.is_supported), (
|
||||
"AiterScaledMMLinearKernel.is_supported() should be a classmethod"
|
||||
)
|
||||
# Verify it can be called as a classmethod
|
||||
# (will return False on CPU, which is expected)
|
||||
result, reason = AiterScaledMMLinearKernel.is_supported()
|
||||
assert isinstance(result, bool), "is_supported() should return a bool"
|
||||
assert reason is None or isinstance(reason, str), "reason should be str or None"
|
||||
# On CPU, it should return False with a reason about requiring ROCm
|
||||
# This validates the method works correctly even on non-ROCm platforms
|
||||
|
||||
|
||||
def test_cpu_kernel_accepts_all_configs():
|
||||
"""Test that CPUScaledMMLinearKernel accepts all config combinations."""
|
||||
configs = [
|
||||
ScaledMMLinearLayerConfig(
|
||||
is_channelwise=False,
|
||||
is_static_input_scheme=True,
|
||||
input_symmetric=True,
|
||||
),
|
||||
ScaledMMLinearLayerConfig(
|
||||
is_channelwise=True,
|
||||
is_static_input_scheme=False,
|
||||
input_symmetric=False,
|
||||
),
|
||||
]
|
||||
|
||||
for config in configs:
|
||||
can_impl, reason = CPUScaledMMLinearKernel.can_implement(config)
|
||||
assert can_impl, (
|
||||
f"CPUScaledMMLinearKernel should accept config {config}: {reason}"
|
||||
)
|
||||
@ -10,7 +10,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||
MistralToolParser,
|
||||
)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ from mistral_common.audio import Audio
|
||||
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
|
||||
from mistral_common.protocol.instruct.messages import UserMessage
|
||||
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
from ....conftest import AudioTestAssets
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
@ -1,150 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import librosa
|
||||
import pytest
|
||||
from transformers import AutoModelForSpeechSeq2Seq
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ....conftest import VllmRunner
|
||||
from ....conftest import HfRunner, PromptAudioInput, VllmRunner
|
||||
from ....utils import create_new_process_for_each_test, multi_gpu_test
|
||||
from ...registry import HF_EXAMPLE_MODELS
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
PROMPTS = [
|
||||
{
|
||||
"prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
|
||||
"multi_modal_data": {
|
||||
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
|
||||
},
|
||||
},
|
||||
{ # Test explicit encoder/decoder prompt
|
||||
"encoder_prompt": {
|
||||
"prompt": "",
|
||||
"multi_modal_data": {
|
||||
"audio": AudioAsset("winning_call").audio_and_sample_rate,
|
||||
},
|
||||
},
|
||||
"decoder_prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
|
||||
},
|
||||
]
|
||||
VLLM_PROMPT = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
|
||||
HF_PROMPT = ""
|
||||
# Whisper expects 16kHz audio
|
||||
WHISPER_SAMPLE_RATE = 16000
|
||||
|
||||
EXPECTED = {
|
||||
"openai/whisper-tiny": [
|
||||
" He has birth words I spoke in the original corner of that. And a"
|
||||
" little piece of black coat poetry. Mary had a little sandwich,"
|
||||
" sweet, with white and snow. And everyone had it very went the last"
|
||||
" would sure to go.",
|
||||
" >> And the old one, fit John the way to Edgar Martinez. >> One more"
|
||||
" to line down the field line for our base camp. Here comes joy. Here"
|
||||
" is June and the third base. They're going to wave him in. The throw"
|
||||
" to the plate will be late. The Mariners are going to play for the"
|
||||
" American League Championship. I don't believe it. It just continues"
|
||||
" by all five.",
|
||||
],
|
||||
"openai/whisper-small": [
|
||||
" The first words I spoke in the original pornograph. A little piece"
|
||||
" of practical poetry. Mary had a little lamb, its fleece was quite a"
|
||||
" slow, and everywhere that Mary went the lamb was sure to go.",
|
||||
" And the old one pitch on the way to Edgar Martinez one month. Here"
|
||||
" comes joy. Here is Junior to third base. They're gonna wave him"
|
||||
" in. The throw to the plate will be late. The Mariners are going to"
|
||||
" play for the American League Championship. I don't believe it. It"
|
||||
" just continues. My, oh my.",
|
||||
],
|
||||
"openai/whisper-medium": [
|
||||
" The first words I spoke in the original phonograph, a little piece"
|
||||
" of practical poetry. Mary had a little lamb, its fleece was quite as"
|
||||
" slow, and everywhere that Mary went the lamb was sure to go.",
|
||||
" And the 0-1 pitch on the way to Edgar Martinez swung on the line"
|
||||
" down the left field line for Obeyshev. Here comes Joy. Here is"
|
||||
" Jorgen at third base. They're going to wave him in. The throw to the"
|
||||
" plate will be late. The Mariners are going to play for the American"
|
||||
" League Championship. I don't believe it. It just continues. My, oh"
|
||||
" my.",
|
||||
],
|
||||
"openai/whisper-large-v3": [
|
||||
" The first words I spoke in the original phonograph, a little piece"
|
||||
" of practical poetry. Mary had a little lamb, its feet were quite as"
|
||||
" slow, and everywhere that Mary went, the lamb was sure to go.",
|
||||
" And the 0-1 pitch on the way to Edgar Martinez. Swung on the line."
|
||||
" Now the left field line for a base hit. Here comes Joy. Here is"
|
||||
" Junior to third base. They're going to wave him in. The throw to the"
|
||||
" plate will be late. The Mariners are going to play for the American"
|
||||
" League Championship. I don't believe it. It just continues. My, oh,"
|
||||
" my.",
|
||||
],
|
||||
"openai/whisper-large-v3-turbo": [
|
||||
" The first words I spoke in the original phonograph, a little piece"
|
||||
" of practical poetry. Mary had a little lamb, its streets were quite"
|
||||
" as slow, and everywhere that Mary went the lamb was sure to go.",
|
||||
" And the 0-1 pitch on the way to Edgar Martinez. Swung on the line"
|
||||
" down the left field line for a base hit. Here comes Joy. Here is"
|
||||
" Junior to third base. They're going to wave him in. The throw to the"
|
||||
" plate will be late. The Mariners are going to play for the American"
|
||||
" League Championship. I don't believe it. It just continues. My, oh,"
|
||||
" my.",
|
||||
],
|
||||
}
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def use_spawn_for_whisper(monkeypatch):
|
||||
"""Whisper has issues with forked workers, use spawn instead."""
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
|
||||
|
||||
def run_test(
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
inputs: Sequence[tuple[list[str], list[str], PromptAudioInput]],
|
||||
model: str,
|
||||
*,
|
||||
max_model_len: int,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: str | None = None,
|
||||
dtype: str = "half",
|
||||
enforce_eager: bool = True,
|
||||
) -> None:
|
||||
prompt_list = PROMPTS * 10
|
||||
expected_list = EXPECTED[model] * 10
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the audio fixtures for the test are from AudioAsset.
|
||||
For huggingface runner, we provide the audio as input.
|
||||
For vllm runner, we provide MultiModalDataDict objects
|
||||
and corresponding MultiModalConfig as input.
|
||||
"""
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_model_len=448,
|
||||
max_model_len=max_model_len,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
# TODO (NickLucche) figure out output differences with non-eager and re-enable
|
||||
enforce_eager=True,
|
||||
limit_mm_per_prompt={"audio": 2},
|
||||
enforce_eager=enforce_eager,
|
||||
disable_custom_all_reduce=True,
|
||||
) as vllm_model:
|
||||
llm = vllm_model.llm
|
||||
vllm_outputs_per_case = [
|
||||
vllm_model.generate_greedy_logprobs(
|
||||
vllm_prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
audios=audios,
|
||||
)
|
||||
for vllm_prompts, _, audios in inputs
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
top_p=1.0,
|
||||
max_tokens=200,
|
||||
with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSpeechSeq2Seq) as hf_model:
|
||||
hf_outputs_per_case = [
|
||||
hf_model.generate_greedy_logprobs_limit(
|
||||
hf_prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
audios=audios,
|
||||
)
|
||||
for _, hf_prompts, audios in inputs
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
outputs = llm.generate(prompt_list, sampling_params)
|
||||
|
||||
for output, expected in zip(outputs, expected_list):
|
||||
print(output.outputs[0].text)
|
||||
assert output.outputs[0].text == expected
|
||||
@pytest.fixture
|
||||
def input_audios() -> list[tuple[list[str], list[str], list[tuple[Any, int]]]]:
|
||||
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
|
||||
inputs = []
|
||||
for asset in audio_assets:
|
||||
audio, orig_sr = asset.audio_and_sample_rate
|
||||
# Resample to Whisper's expected sample rate (16kHz)
|
||||
if orig_sr != WHISPER_SAMPLE_RATE:
|
||||
audio = librosa.resample(
|
||||
audio, orig_sr=orig_sr, target_sr=WHISPER_SAMPLE_RATE
|
||||
)
|
||||
# vLLM prompts, HF prompts, audio inputs
|
||||
inputs.append(([VLLM_PROMPT], [HF_PROMPT], [(audio, WHISPER_SAMPLE_RATE)]))
|
||||
return inputs
|
||||
|
||||
|
||||
def check_model_available(model: str) -> None:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@create_new_process_for_each_test()
|
||||
def test_models(vllm_runner, model, dtype) -> None:
|
||||
run_test(
|
||||
vllm_runner,
|
||||
model,
|
||||
tensor_parallel_size=1,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.cpu_model
|
||||
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_models_cpu(vllm_runner, model, dtype) -> None:
|
||||
# @create_new_process_for_each_test() does not work for some runners
|
||||
# TODO: to fix cpu privilege issues in run-cpu-test-arm.sh
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
num_logprobs: int,
|
||||
input_audios,
|
||||
enforce_eager: bool,
|
||||
) -> None:
|
||||
check_model_available(model)
|
||||
if current_platform.is_cpu() and not enforce_eager:
|
||||
pytest.skip("Skipping test for CPU with non-eager mode")
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
input_audios,
|
||||
model,
|
||||
tensor_parallel_size=1,
|
||||
dtype=dtype,
|
||||
max_model_len=448,
|
||||
max_tokens=200,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
enforce_eager=enforce_eager,
|
||||
)
|
||||
|
||||
|
||||
@ -152,15 +148,31 @@ def test_models_cpu(vllm_runner, model, dtype) -> None:
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [200])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_models_distributed(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
model,
|
||||
distributed_executor_backend,
|
||||
model: str,
|
||||
distributed_executor_backend: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
input_audios,
|
||||
) -> None:
|
||||
check_model_available(model)
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
input_audios,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_model_len=448,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=False,
|
||||
)
|
||||
|
||||
@ -22,11 +22,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
|
||||
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
|
||||
from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
|
||||
from vllm.tokenizers import (
|
||||
MistralTokenizer,
|
||||
TokenizerLike,
|
||||
cached_tokenizer_from_config,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
from ....multimodal.utils import random_audio, random_image, random_video
|
||||
from ...registry import (
|
||||
|
||||
42
tests/models/multimodal/processing/test_gemma3.py
Normal file
42
tests/models/multimodal/processing/test_gemma3.py
Normal file
@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from ....conftest import ImageTestAssets
|
||||
from ...utils import build_model_context
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["google/gemma-3-4b-it"])
|
||||
def test_get_image_size_with_most_features(
|
||||
image_assets: ImageTestAssets, model_id: str
|
||||
):
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
mm_processor_kwargs={"do_pan_and_scan": True},
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
|
||||
hf_processor_mm_kwargs: dict[str, object] = {}
|
||||
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
max_image_size = processor.info.get_image_size_with_most_features()
|
||||
max_tokens = processor.info.get_num_image_tokens(
|
||||
image_width=max_image_size.width,
|
||||
image_height=max_image_size.height,
|
||||
processor=hf_processor,
|
||||
)
|
||||
|
||||
prompt = "<start_of_image>"
|
||||
image_seq_length = hf_processor.image_seq_length
|
||||
|
||||
for asset in image_assets:
|
||||
mm_data = {"image": [asset.pil_image]}
|
||||
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
|
||||
mm_kwargs_data = processed_inputs["mm_kwargs"].get_data()
|
||||
num_patches_tensor = mm_kwargs_data["num_patches"]
|
||||
tokens = int(num_patches_tensor.item()) * image_seq_length
|
||||
assert tokens <= max_tokens
|
||||
@ -53,3 +53,38 @@ def test_processor_override(
|
||||
assert img_tok_count == expected_toks_per_img * num_imgs
|
||||
assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
|
||||
assert pixel_shape[1] == expected_pixels_shape[1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
|
||||
@pytest.mark.parametrize("max_pixels", [1280 * 28 * 28, 1283 * 28 * 28])
|
||||
def test_get_image_size_with_most_features(
|
||||
image_assets: ImageTestAssets,
|
||||
model_id: str,
|
||||
max_pixels: int,
|
||||
):
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
mm_processor_kwargs={"max_pixels": max_pixels},
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
|
||||
hf_processor_mm_kwargs: dict[str, object] = {}
|
||||
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
merge_size = processor.info.get_hf_config().vision_config.spatial_merge_size
|
||||
|
||||
max_image_size = processor.info.get_image_size_with_most_features()
|
||||
max_tokens = processor.info.get_num_image_tokens(
|
||||
image_width=max_image_size.width,
|
||||
image_height=max_image_size.height,
|
||||
image_processor=hf_processor.image_processor,
|
||||
)
|
||||
|
||||
prompt = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
for asset in image_assets:
|
||||
mm_data = {"image": [asset.pil_image]}
|
||||
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
|
||||
grid_thw = processed_inputs["mm_kwargs"].get_data()["image_grid_thw"].tolist()
|
||||
t, h, w = grid_thw[0]
|
||||
tokens = (t * h * w) // (merge_size**2)
|
||||
assert tokens < max_tokens
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any, TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
|
||||
@ -35,6 +36,7 @@ from vllm.tokenizers import cached_tokenizer_from_config
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
from ....utils import create_new_process_for_each_test
|
||||
from ...registry import HF_EXAMPLE_MODELS
|
||||
from ...utils import dummy_hf_overrides
|
||||
from .test_common import get_model_ids_to_test, get_text_token_prompts
|
||||
@ -136,6 +138,7 @@ def create_batched_mm_kwargs(
|
||||
)
|
||||
|
||||
|
||||
# TODO(Isotr0py): Don't initalize model during test
|
||||
@contextmanager
|
||||
def initialize_dummy_model(
|
||||
model_cls: type[nn.Module],
|
||||
@ -150,16 +153,21 @@ def initialize_dummy_model(
|
||||
backend="nccl",
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=1)
|
||||
|
||||
current_device = torch.get_default_device()
|
||||
vllm_config = VllmConfig(model_config=model_config)
|
||||
with set_current_vllm_config(vllm_config=vllm_config):
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
torch.set_default_device(current_platform.device_type)
|
||||
model = model_cls(vllm_config=vllm_config)
|
||||
torch.set_default_device(current_device)
|
||||
yield model
|
||||
|
||||
del model
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("model_id", get_model_ids_to_test())
|
||||
def test_model_tensor_schema(model_id: str):
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
|
||||
@ -173,10 +173,7 @@ class _HfExamplesInfo:
|
||||
|
||||
_TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
# [Decoder-only]
|
||||
"AfmoeForCausalLM": _HfExamplesInfo(
|
||||
"arcee-ai/Trinity-Nano",
|
||||
is_available_online=False,
|
||||
),
|
||||
"AfmoeForCausalLM": _HfExamplesInfo("arcee-ai/Trinity-Nano-Preview"),
|
||||
"ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-Instruct-2509"),
|
||||
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True),
|
||||
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True),
|
||||
@ -359,7 +356,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
),
|
||||
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
|
||||
"MistralLarge3ForCausalLM": _HfExamplesInfo(
|
||||
"mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4", is_available_online=False
|
||||
"mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4"
|
||||
),
|
||||
"MixtralForCausalLM": _HfExamplesInfo(
|
||||
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
@ -638,7 +635,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
),
|
||||
"HunYuanVLForConditionalGeneration": _HfExamplesInfo(
|
||||
"tencent/HunyuanOCR",
|
||||
is_available_online=False,
|
||||
hf_overrides={"num_experts": 0},
|
||||
),
|
||||
"Idefics3ForConditionalGeneration": _HfExamplesInfo(
|
||||
"HuggingFaceM4/Idefics3-8B-Llama3",
|
||||
@ -677,8 +674,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/discussions/31",
|
||||
),
|
||||
"LightOnOCRForConditionalGeneration": _HfExamplesInfo(
|
||||
"lightonai/LightOnOCR-1B",
|
||||
is_available_online=False,
|
||||
"lightonai/LightOnOCR-1B-1025"
|
||||
),
|
||||
"Llama4ForConditionalGeneration": _HfExamplesInfo(
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
@ -782,8 +778,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"ministral-3": "mistralai/Ministral-3-3B-Instruct-2512",
|
||||
},
|
||||
tokenizer_mode="mistral",
|
||||
# TODO: revert once Mistral-Large-3 and Ministral-3 are publicly available.
|
||||
is_available_online=False,
|
||||
),
|
||||
"QwenVLForConditionalGeneration": _HfExamplesInfo(
|
||||
"Qwen/Qwen-VL",
|
||||
@ -846,7 +840,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
is_available_online=False,
|
||||
),
|
||||
# [Encoder-decoder]
|
||||
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"),
|
||||
"WhisperForConditionalGeneration": _HfExamplesInfo(
|
||||
"openai/whisper-large-v3-turbo",
|
||||
extras={"v3": "openai/whisper-large-v3"},
|
||||
),
|
||||
# [Cross-encoder]
|
||||
"JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"),
|
||||
}
|
||||
@ -889,6 +886,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
"EagleMistralLarge3ForCausalLM": _HfExamplesInfo(
|
||||
"mistralai/Mistral-Large-3-675B-Instruct-2512",
|
||||
speculative_model="mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle",
|
||||
# TODO: revert once figuring out OOM in CI
|
||||
is_available_online=False,
|
||||
),
|
||||
"LlamaForCausalLMEagle3": _HfExamplesInfo(
|
||||
|
||||
@ -10,9 +10,9 @@ import pytest
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_device_capability(100):
|
||||
if not current_platform.is_device_capability_family(100):
|
||||
pytest.skip(
|
||||
"This test only runs on Blackwell GPUs (SM100).", allow_module_level=True
|
||||
"This test only runs on Blackwell GPUs (SM10x).", allow_module_level=True
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction_mistral
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
parser_name = "mistral"
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
|
||||
class StreamingReasoningReconstructor:
|
||||
|
||||
@ -7,7 +7,7 @@ from vllm.config import ModelConfig
|
||||
from vllm.inputs import zip_enc_dec_prompts
|
||||
from vllm.inputs.parse import parse_raw_prompts
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.tokenizers import init_tokenizer_from_config
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
@ -108,7 +108,7 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
|
||||
)
|
||||
def test_preprocessor_always_mm_code_path(model_id, prompt):
|
||||
model_config = ModelConfig(model=model_id)
|
||||
tokenizer = init_tokenizer_from_config(model_config)
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
input_preprocessor = InputPreprocessor(model_config, tokenizer)
|
||||
|
||||
# HF processor adds sep token
|
||||
|
||||
@ -3,38 +3,39 @@
|
||||
from typing import _get_protocol_attrs # type: ignore
|
||||
|
||||
import pytest
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers import (
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerBase,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
|
||||
def _get_missing_attrs(obj: object, target: type):
|
||||
return [k for k in _get_protocol_attrs(target) if not hasattr(obj, k)]
|
||||
|
||||
|
||||
def _assert_tokenizer_like(tokenizer: object):
|
||||
missing_attrs = _get_missing_attrs(tokenizer, TokenizerLike)
|
||||
assert not missing_attrs, f"Missing attrs: {missing_attrs}"
|
||||
|
||||
|
||||
def test_tokenizer_like_protocol():
|
||||
assert not (
|
||||
missing_attrs := _get_missing_attrs(
|
||||
get_tokenizer("gpt2", use_fast=False),
|
||||
TokenizerLike,
|
||||
)
|
||||
), f"Missing attrs: {missing_attrs}"
|
||||
tokenizer = get_tokenizer("gpt2", use_fast=False)
|
||||
assert isinstance(tokenizer, PreTrainedTokenizer)
|
||||
_assert_tokenizer_like(tokenizer)
|
||||
|
||||
assert not (
|
||||
missing_attrs := _get_missing_attrs(
|
||||
get_tokenizer("gpt2", use_fast=True),
|
||||
TokenizerLike,
|
||||
)
|
||||
), f"Missing attrs: {missing_attrs}"
|
||||
tokenizer = get_tokenizer("gpt2", use_fast=True)
|
||||
assert isinstance(tokenizer, PreTrainedTokenizerFast)
|
||||
_assert_tokenizer_like(tokenizer)
|
||||
|
||||
assert not (
|
||||
missing_attrs := _get_missing_attrs(
|
||||
get_tokenizer(
|
||||
"mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral"
|
||||
),
|
||||
TokenizerLike,
|
||||
)
|
||||
), f"Missing attrs: {missing_attrs}"
|
||||
tokenizer = get_tokenizer(
|
||||
"mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral"
|
||||
)
|
||||
assert isinstance(tokenizer, MistralTokenizer)
|
||||
_assert_tokenizer_like(tokenizer)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])
|
||||
|
||||
@ -8,7 +8,7 @@ import pytest
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.detokenizer import (
|
||||
FastIncrementalDetokenizer,
|
||||
|
||||
@ -2,7 +2,14 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from pathlib import Path
|
||||
|
||||
from vllm.tokenizers import TokenizerLike, TokenizerRegistry, get_tokenizer
|
||||
import pytest
|
||||
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.registry import (
|
||||
TokenizerRegistry,
|
||||
get_tokenizer,
|
||||
resolve_tokenizer_args,
|
||||
)
|
||||
|
||||
|
||||
class TestTokenizer(TokenizerLike):
|
||||
@ -40,10 +47,22 @@ class TestTokenizer(TokenizerLike):
|
||||
return True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("runner_type", ["generate", "pooling"])
|
||||
def test_resolve_tokenizer_args_idempotent(runner_type):
|
||||
tokenizer_mode, tokenizer_name, args, kwargs = resolve_tokenizer_args(
|
||||
"facebook/opt-125m",
|
||||
runner_type=runner_type,
|
||||
)
|
||||
|
||||
assert (tokenizer_mode, tokenizer_name, args, kwargs) == resolve_tokenizer_args(
|
||||
tokenizer_name, *args, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def test_customized_tokenizer():
|
||||
TokenizerRegistry.register("test_tokenizer", __name__, TestTokenizer.__name__)
|
||||
|
||||
tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer", "abc")
|
||||
tokenizer = TokenizerRegistry.load_tokenizer("test_tokenizer", "abc")
|
||||
assert isinstance(tokenizer, TestTokenizer)
|
||||
assert tokenizer.path_or_repo_id == "abc"
|
||||
assert tokenizer.bos_token_id == 0
|
||||
|
||||
@ -13,12 +13,9 @@ from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall
|
||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolParser
|
||||
from vllm.tokenizers import (
|
||||
MistralTokenizer,
|
||||
TokenizerLike,
|
||||
get_tokenizer,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
||||
@ -22,10 +22,14 @@ from tests.v1.attention.utils import (
|
||||
)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops import flashmla
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend
|
||||
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend,
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import split_prefill_chunks
|
||||
|
||||
SPARSE_BACKEND_BATCH_SPECS = {
|
||||
name: BATCH_SPECS[name]
|
||||
@ -114,8 +118,12 @@ def _quantize_dequantize_fp8_ds_mla(
|
||||
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_sparse_backend_decode_correctness(
|
||||
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size
|
||||
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init
|
||||
):
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is required for sparse MLA decode test")
|
||||
@ -320,28 +328,29 @@ def test_sparse_backend_decode_correctness(
|
||||
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
|
||||
|
||||
impl_cls = FlashMLASparseBackend.get_impl_cls()
|
||||
impl = impl_cls(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
indexer=mock_indexer,
|
||||
)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
impl = impl_cls(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
indexer=mock_indexer,
|
||||
)
|
||||
|
||||
impl.process_weights_after_loading(dtype)
|
||||
impl.process_weights_after_loading(dtype)
|
||||
|
||||
layer = MockAttentionLayer(device)
|
||||
out_buffer = torch.empty(
|
||||
@ -366,22 +375,192 @@ def test_sparse_backend_decode_correctness(
|
||||
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5)
|
||||
|
||||
|
||||
def _triton_convert_reference_impl(
|
||||
req_ids: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
token_indices: torch.Tensor,
|
||||
block_size: int,
|
||||
num_topk_tokens: int,
|
||||
HAS_PREFILL_WORKSPACE: bool = False,
|
||||
prefill_workspace_request_ids: torch.Tensor | None = None,
|
||||
prefill_workspace_starts: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Reference implementation for triton_convert_req_index_to_global_index."""
|
||||
num_tokens = req_ids.shape[0]
|
||||
max_blocks_per_req = block_table.shape[1]
|
||||
result = torch.empty(
|
||||
num_tokens, num_topk_tokens, dtype=torch.int32, device=req_ids.device
|
||||
)
|
||||
|
||||
for token_id in range(num_tokens):
|
||||
req_id = req_ids[token_id].item()
|
||||
|
||||
# Determine if this token uses workspace or paged cache
|
||||
use_prefill_workspace = False
|
||||
workspace_start = 0
|
||||
if HAS_PREFILL_WORKSPACE and prefill_workspace_request_ids is not None:
|
||||
assert prefill_workspace_starts is not None
|
||||
prefill_req_id = prefill_workspace_request_ids[token_id].item()
|
||||
if prefill_req_id >= 0:
|
||||
use_prefill_workspace = True
|
||||
workspace_start = prefill_workspace_starts[prefill_req_id].item()
|
||||
|
||||
for idx_id in range(num_topk_tokens):
|
||||
token_idx = token_indices[token_id, idx_id].item()
|
||||
|
||||
if token_idx == -1:
|
||||
result[token_id, idx_id] = -1
|
||||
elif use_prefill_workspace:
|
||||
# Prefill + using prefill workspace: map to workspace offset
|
||||
result[token_id, idx_id] = workspace_start + token_idx
|
||||
else:
|
||||
# Decode: map to paged cache
|
||||
block_id = token_idx // block_size
|
||||
if block_id >= max_blocks_per_req:
|
||||
result[token_id, idx_id] = -1
|
||||
else:
|
||||
block_num = block_table[req_id, block_id].item()
|
||||
offset = token_idx % block_size
|
||||
result[token_id, idx_id] = block_num * block_size + offset
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16, 64, 128])
|
||||
@pytest.mark.parametrize("num_topk_tokens", [128, 256, 512])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_triton_convert_req_index_to_global_index_decode_only(
|
||||
block_size, num_topk_tokens
|
||||
):
|
||||
device = torch.device("cuda")
|
||||
num_tokens = 8
|
||||
num_requests = 4
|
||||
max_blocks_per_req = 10
|
||||
|
||||
req_id = torch.randint(
|
||||
0, num_requests, (num_tokens,), dtype=torch.int32, device=device
|
||||
)
|
||||
block_table = torch.randint(
|
||||
0, 100, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
token_indices = torch.randint(
|
||||
0,
|
||||
block_size * max_blocks_per_req,
|
||||
(num_tokens, num_topk_tokens),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Set some to -1 to test masking
|
||||
token_indices[0, :10] = -1
|
||||
token_indices[3, 50:60] = -1
|
||||
|
||||
# Set some to out of bounds
|
||||
token_indices[2, 100:110] = max_blocks_per_req * block_size
|
||||
token_indices[6, 150:160] = max_blocks_per_req * block_size
|
||||
|
||||
result = triton_convert_req_index_to_global_index(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
BLOCK_SIZE=block_size,
|
||||
NUM_TOPK_TOKENS=num_topk_tokens,
|
||||
)
|
||||
|
||||
reference_result = _triton_convert_reference_impl(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
block_size,
|
||||
num_topk_tokens,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(result, reference_result, rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size):
|
||||
device = torch.device("cuda")
|
||||
num_requests = 4
|
||||
max_blocks_per_req = 8
|
||||
num_topk_tokens = 128
|
||||
|
||||
# First 6 tokens are decode (reqs 0, 1), last 6 are prefill (reqs 2, 3)
|
||||
req_id = torch.tensor(
|
||||
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=torch.int32, device=device
|
||||
)
|
||||
prefill_workspace_request_ids = torch.tensor(
|
||||
[-1, -1, -1, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# Workspace starts for the 2 prefill reqs: req 2 starts at 0, req 3 starts at 100
|
||||
prefill_workspace_starts = torch.tensor([0, 100], dtype=torch.int32, device=device)
|
||||
|
||||
block_table = torch.randint(
|
||||
0, 50, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
|
||||
)
|
||||
token_indices = torch.randint(
|
||||
0,
|
||||
block_size * max_blocks_per_req,
|
||||
(req_id.shape[0], num_topk_tokens),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Set some to -1 to test masking
|
||||
token_indices[0, :10] = -1
|
||||
token_indices[3, 50:60] = -1
|
||||
|
||||
# Set some to out of bounds
|
||||
token_indices[2, 100:110] = max_blocks_per_req * block_size
|
||||
token_indices[6, 150:160] = max_blocks_per_req * block_size
|
||||
|
||||
result = triton_convert_req_index_to_global_index(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
BLOCK_SIZE=block_size,
|
||||
NUM_TOPK_TOKENS=num_topk_tokens,
|
||||
HAS_PREFILL_WORKSPACE=True,
|
||||
prefill_workspace_request_ids=prefill_workspace_request_ids,
|
||||
prefill_workspace_starts=prefill_workspace_starts,
|
||||
)
|
||||
|
||||
reference_result = _triton_convert_reference_impl(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
block_size,
|
||||
num_topk_tokens,
|
||||
HAS_PREFILL_WORKSPACE=True,
|
||||
prefill_workspace_request_ids=prefill_workspace_request_ids,
|
||||
prefill_workspace_starts=prefill_workspace_starts,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(result, reference_result, rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens,max_buf,start,expected",
|
||||
"seq_lens,max_buf,expected",
|
||||
[
|
||||
# Basic split: totals per chunk ≤ max_buf
|
||||
(torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
|
||||
# Non-zero start index
|
||||
(torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
|
||||
# Exact fits should split between items when adding the next would
|
||||
# overflow
|
||||
(torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
|
||||
(torch.tensor([2, 3, 4, 2]), 5, [(0, 2), (2, 3), (3, 4)]),
|
||||
# Exact fits should split between items when adding the next would overflow
|
||||
(torch.tensor([5, 5, 5]), 5, [(0, 1), (1, 2), (2, 3)]),
|
||||
# All requests fit in a single chunk
|
||||
(torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
|
||||
# Large buffer with non-zero start
|
||||
(torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
|
||||
(torch.tensor([1, 1, 1]), 10, [(0, 3)]),
|
||||
# Large buffer
|
||||
(torch.tensor([4, 4, 4]), 100, [(0, 3)]),
|
||||
],
|
||||
)
|
||||
def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
|
||||
out = split_prefill_chunks(seq_lens, max_buf, start)
|
||||
def test_split_prefill_chunks(seq_lens, max_buf, expected):
|
||||
out = split_prefill_chunks(seq_lens, max_buf)
|
||||
assert out == expected
|
||||
|
||||
@ -13,6 +13,7 @@ import torch
|
||||
|
||||
from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.utils.import_utils import has_deep_ep
|
||||
|
||||
# Detect Blackwell / B200 (compute capability 10.x)
|
||||
try:
|
||||
@ -44,6 +45,7 @@ DEEPEP_BACKENDS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not has_deep_ep(), reason="These tests require deep_ep to run")
|
||||
@pytest.mark.parametrize("all2all_backend", DEEPEP_BACKENDS)
|
||||
@pytest.mark.xfail(
|
||||
IS_BLACKWELL,
|
||||
|
||||
@ -16,6 +16,16 @@ from vllm.platforms import current_platform
|
||||
MTP_SIMILARITY_RATE = 0.8
|
||||
|
||||
|
||||
def _skip_if_insufficient_gpus_for_tp(tp_size: int):
|
||||
"""Skip test if available GPUs < tp_size on ROCm."""
|
||||
if current_platform.is_rocm():
|
||||
available_gpus = torch.cuda.device_count()
|
||||
if available_gpus < tp_size:
|
||||
pytest.skip(
|
||||
f"Test requires {tp_size} GPUs, but only {available_gpus} available"
|
||||
)
|
||||
|
||||
|
||||
def get_test_prompts(mm_enabled: bool):
|
||||
prompt_types = ["repeat", "sentence"]
|
||||
if mm_enabled:
|
||||
@ -455,6 +465,8 @@ def test_eagle_correctness(
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
method, model_name, spec_model_name, tp_size = model_setup
|
||||
_skip_if_insufficient_gpus_for_tp(tp_size)
|
||||
|
||||
max_model_len = 2048
|
||||
max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
|
||||
|
||||
@ -525,6 +537,7 @@ def test_mtp_correctness(
|
||||
m.setenv("VLLM_MLA_DISABLE", "1")
|
||||
|
||||
method, model_name, tp_size = model_setup
|
||||
_skip_if_insufficient_gpus_for_tp(tp_size)
|
||||
|
||||
ref_llm = LLM(
|
||||
model=model_name,
|
||||
|
||||
@ -528,9 +528,11 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("top_logprobs", [0, 3])
|
||||
def test_spec_decode_logprobs(
|
||||
logprobs_mode: LogprobsMode,
|
||||
model_setup: tuple[str, str, str],
|
||||
top_logprobs: int,
|
||||
):
|
||||
"""Spec decode logprobs should match those of the base model.
|
||||
|
||||
@ -543,7 +545,7 @@ def test_spec_decode_logprobs(
|
||||
|
||||
prompt = "Hello world " * 50
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
|
||||
temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
|
||||
)
|
||||
method, model_name, spec_model_name = model_setup
|
||||
max_model_len = 256
|
||||
|
||||
@ -111,7 +111,7 @@ def create_sampling_metadata(
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
generators=generators,
|
||||
max_num_logprobs=0,
|
||||
max_num_logprobs=None,
|
||||
no_penalties=no_penalties,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
frequency_penalties=frequency_penalties,
|
||||
|
||||
@ -43,6 +43,7 @@ FILES = [
|
||||
"vllm/worker",
|
||||
"vllm/v1/core",
|
||||
"vllm/v1/engine",
|
||||
"vllm/v1/executor",
|
||||
"vllm/v1/metrics",
|
||||
"vllm/v1/pool",
|
||||
"vllm/v1/sample",
|
||||
@ -60,7 +61,6 @@ SEPARATE_GROUPS = [
|
||||
"vllm/model_executor",
|
||||
# v1 related
|
||||
"vllm/v1/attention",
|
||||
"vllm/v1/executor",
|
||||
"vllm/v1/kv_offload",
|
||||
"vllm/v1/spec_decode",
|
||||
"vllm/v1/structured_output",
|
||||
|
||||
@ -2403,6 +2403,29 @@ def cp_gather_cache(
|
||||
)
|
||||
|
||||
|
||||
def cp_gather_and_upconvert_fp8_kv_cache(
|
||||
src_cache: torch.Tensor,
|
||||
dst: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
workspace_starts: torch.Tensor,
|
||||
batch_size: int,
|
||||
) -> None:
|
||||
"""Gather and upconvert FP8 KV cache to BF16 workspace.
|
||||
|
||||
Args:
|
||||
src_cache: FP8 KV cache [num_blocks, block_size, 656]
|
||||
dst: BF16 output workspace [total_tokens, 576]
|
||||
block_table: Block indices [num_reqs, max_blocks]
|
||||
seq_lens: Sequence lengths [num_reqs]
|
||||
workspace_starts: Workspace start offsets [num_reqs]
|
||||
batch_size: Number of requests
|
||||
"""
|
||||
torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache(
|
||||
src_cache, dst, block_table, seq_lens, workspace_starts, batch_size
|
||||
)
|
||||
|
||||
|
||||
def indexer_k_quant_and_cache(
|
||||
k: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer."""
|
||||
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
@ -17,6 +18,7 @@ from vllm.attention.backends.abstract import (
|
||||
)
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
@ -524,6 +526,14 @@ class MultiHeadAttention(nn.Module):
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
self.fa_version = None
|
||||
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
|
||||
self.fa_version = get_flash_attn_version()
|
||||
assert self._flash_attn_varlen_func is not None
|
||||
self._flash_attn_varlen_func = functools.partial(
|
||||
self._flash_attn_varlen_func, fa_version=self.fa_version
|
||||
)
|
||||
|
||||
logger.info_once(
|
||||
f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
|
||||
)
|
||||
|
||||
@ -355,7 +355,7 @@ def kernel_unified_attention_2d(
|
||||
@triton.jit
|
||||
def kernel_unified_attention_3d(
|
||||
segm_output_ptr,
|
||||
# [num_tokens, num_query_heads, num_segments, head_size]
|
||||
# [num_tokens, num_query_heads, num_segments, head_size_padded]
|
||||
segm_max_ptr, # [num_tokens, num_query_heads, num_segments]
|
||||
segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments]
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
@ -749,6 +749,11 @@ def unified_attention(
|
||||
q_descale,
|
||||
k_descale,
|
||||
v_descale,
|
||||
seq_threshold_3D=None,
|
||||
num_par_softmax_segments=None,
|
||||
softmax_segm_output=None,
|
||||
softmax_segm_max=None,
|
||||
softmax_segm_expsum=None,
|
||||
alibi_slopes=None,
|
||||
output_scale=None,
|
||||
qq_bias=None,
|
||||
@ -793,8 +798,19 @@ def unified_attention(
|
||||
TILE_SIZE_PREFILL = 32
|
||||
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
|
||||
|
||||
# if batch contains a prefill
|
||||
if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
|
||||
# Launch the 2D kernel if
|
||||
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
|
||||
# 2. The batch includes at least one prefill request, or
|
||||
# 3. The number of sequences exceeds the configured threshold
|
||||
if (
|
||||
seq_threshold_3D is None
|
||||
or num_par_softmax_segments is None
|
||||
or softmax_segm_output is None
|
||||
or softmax_segm_max is None
|
||||
or softmax_segm_expsum is None
|
||||
or max_seqlen_q > 1
|
||||
or num_seqs > seq_threshold_3D
|
||||
):
|
||||
kernel_unified_attention_2d[
|
||||
(
|
||||
total_num_q_blocks,
|
||||
@ -847,37 +863,12 @@ def unified_attention(
|
||||
USE_FP8=output_scale is not None,
|
||||
)
|
||||
else:
|
||||
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
|
||||
# value that showed good performance in tests
|
||||
NUM_SEGMENTS = 16
|
||||
|
||||
segm_output = torch.empty(
|
||||
q.shape[0],
|
||||
num_query_heads,
|
||||
NUM_SEGMENTS,
|
||||
triton.next_power_of_2(head_size),
|
||||
dtype=torch.float32,
|
||||
device=q.device,
|
||||
)
|
||||
segm_max = torch.empty(
|
||||
q.shape[0],
|
||||
num_query_heads,
|
||||
NUM_SEGMENTS,
|
||||
dtype=torch.float32,
|
||||
device=q.device,
|
||||
)
|
||||
segm_expsum = torch.empty(
|
||||
q.shape[0],
|
||||
num_query_heads,
|
||||
NUM_SEGMENTS,
|
||||
dtype=torch.float32,
|
||||
device=q.device,
|
||||
)
|
||||
|
||||
kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
|
||||
segm_output_ptr=segm_output,
|
||||
segm_max_ptr=segm_max,
|
||||
segm_expsum_ptr=segm_expsum,
|
||||
kernel_unified_attention_3d[
|
||||
(total_num_q_blocks, num_kv_heads, num_par_softmax_segments)
|
||||
](
|
||||
segm_output_ptr=softmax_segm_output,
|
||||
segm_max_ptr=softmax_segm_max,
|
||||
segm_expsum_ptr=softmax_segm_expsum,
|
||||
query_ptr=q,
|
||||
key_cache_ptr=k,
|
||||
value_cache_ptr=v,
|
||||
@ -917,13 +908,13 @@ def unified_attention(
|
||||
BLOCK_Q=BLOCK_Q,
|
||||
num_seqs=num_seqs,
|
||||
BLOCK_M=BLOCK_M,
|
||||
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
|
||||
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
|
||||
)
|
||||
reduce_segments[(q.shape[0], num_query_heads)](
|
||||
output_ptr=out,
|
||||
segm_output_ptr=segm_output,
|
||||
segm_max_ptr=segm_max,
|
||||
segm_expsum_ptr=segm_expsum,
|
||||
segm_output_ptr=softmax_segm_output,
|
||||
segm_max_ptr=softmax_segm_max,
|
||||
segm_expsum_ptr=softmax_segm_expsum,
|
||||
seq_lens_ptr=seqused_k,
|
||||
num_seqs=num_seqs,
|
||||
num_query_heads=num_query_heads,
|
||||
@ -936,6 +927,6 @@ def unified_attention(
|
||||
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
|
||||
query_start_len_ptr=cu_seqlens_q,
|
||||
BLOCK_Q=BLOCK_Q,
|
||||
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
|
||||
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
|
||||
USE_FP8=output_scale is not None,
|
||||
)
|
||||
|
||||
@ -141,7 +141,25 @@ class CompilerManager:
|
||||
# we use ast.literal_eval to parse the data
|
||||
# because it is a safe way to parse Python literals.
|
||||
# do not use eval(), it is unsafe.
|
||||
self.cache = ast.literal_eval(f.read())
|
||||
cache = ast.literal_eval(f.read())
|
||||
|
||||
def check_type(value, ty):
|
||||
if not isinstance(value, ty):
|
||||
raise TypeError(f"Expected {ty} but got {type(value)} for {value}")
|
||||
|
||||
def parse_key(key: Any) -> tuple[Range, int, str]:
|
||||
range_tuple, graph_index, compiler_name = key
|
||||
check_type(graph_index, int)
|
||||
check_type(compiler_name, str)
|
||||
if isinstance(range_tuple, tuple):
|
||||
start, end = range_tuple
|
||||
check_type(start, int)
|
||||
check_type(end, int)
|
||||
range_tuple = Range(start=start, end=end)
|
||||
check_type(range_tuple, Range)
|
||||
return range_tuple, graph_index, compiler_name
|
||||
|
||||
self.cache = {parse_key(key): value for key, value in cache.items()}
|
||||
|
||||
self.compiler.initialize_cache(
|
||||
cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
|
||||
|
||||
@ -28,7 +28,7 @@ from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.torch_utils import supports_dynamo
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
|
||||
|
||||
from .monitor import start_monitoring_torch_compile
|
||||
|
||||
@ -316,7 +316,13 @@ def _support_torch_compile(
|
||||
def _mark_dynamic_inputs(mod, type, *args, **kwargs):
|
||||
def mark_dynamic(arg, dims):
|
||||
if type == DynamicShapesType.UNBACKED:
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
if is_torch_equal_or_newer("2.10.0.dev"):
|
||||
for dim in dims:
|
||||
torch._dynamo.decorators.mark_unbacked(
|
||||
arg, dim, hint_override=arg.size()[dim]
|
||||
)
|
||||
else:
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
else:
|
||||
torch._dynamo.mark_dynamic(arg, dims)
|
||||
|
||||
@ -350,7 +356,13 @@ def _support_torch_compile(
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
if is_torch_equal_or_newer("2.10.0.dev"):
|
||||
for dim in dims:
|
||||
torch._dynamo.decorators.mark_unbacked(
|
||||
arg, dim, hint_override=arg.size()[dim]
|
||||
)
|
||||
else:
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# torch.compiler.is_compiling() means we are inside the compilation
|
||||
@ -488,6 +500,12 @@ def _support_torch_compile(
|
||||
if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
|
||||
fx_config_patches["backed_size_oblivious"] = True
|
||||
|
||||
# Prepare inductor config patches
|
||||
# assume_32bit_indexing is only available in torch 2.10.0.dev+
|
||||
inductor_config_patches = {}
|
||||
if is_torch_equal_or_newer("2.10.0.dev"):
|
||||
inductor_config_patches["assume_32bit_indexing"] = True
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
InliningInstructionTranslator, "inline_call_", patched_inline_call
|
||||
@ -496,6 +514,7 @@ def _support_torch_compile(
|
||||
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
||||
torch.fx.experimental._config.patch(**fx_config_patches),
|
||||
_torch27_patch_tensor_subclasses(),
|
||||
torch._inductor.config.patch(**inductor_config_patches),
|
||||
):
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
|
||||
|
||||
@ -171,22 +171,24 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
|
||||
compiled_ptr = self.check_invariants_and_forward
|
||||
|
||||
aot_context = nullcontext()
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
if hasattr(torch._dynamo.config, "enable_aot_compile"):
|
||||
torch._dynamo.config.enable_aot_compile = True
|
||||
aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
|
||||
else:
|
||||
msg = "torch._dynamo.config.enable_aot_compile is not "
|
||||
msg += "available. AOT compile is disabled and please "
|
||||
msg += "upgrade PyTorch version to use AOT compile."
|
||||
logger.warning(msg)
|
||||
|
||||
self._compiled_callable = torch.compile(
|
||||
compiled_ptr,
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
backend=backend,
|
||||
options=options,
|
||||
)
|
||||
with aot_context:
|
||||
self._compiled_callable = torch.compile(
|
||||
compiled_ptr,
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
backend=backend,
|
||||
options=options,
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
|
||||
|
||||
@ -539,6 +539,11 @@ class ModelConfig:
|
||||
|
||||
self.original_max_model_len = self.max_model_len
|
||||
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
self.mm_processor_cache_gb = 0
|
||||
logger.info("Encoder-decoder model detected, disabling mm processor cache.")
|
||||
|
||||
# Init multimodal config if needed
|
||||
if self._model_info.supports_multimodal:
|
||||
if (
|
||||
|
||||
@ -73,14 +73,28 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
||||
)
|
||||
|
||||
|
||||
def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
|
||||
def getattr_iter(
|
||||
object: object, names: Iterable[str], default: Any, warn: bool = False
|
||||
) -> Any:
|
||||
"""
|
||||
A helper function that retrieves an attribute from an object which may
|
||||
have multiple possible names. This is useful when fetching attributes from
|
||||
arbitrary `transformers.PretrainedConfig` instances.
|
||||
|
||||
In the case where the first name in `names` is the preferred name, and
|
||||
any other names are deprecated aliases, setting `warn=True` will log a
|
||||
warning when a deprecated name is used.
|
||||
"""
|
||||
for name in names:
|
||||
for i, name in enumerate(names):
|
||||
if hasattr(object, name):
|
||||
if warn and i > 0:
|
||||
logger.warning_once(
|
||||
"%s contains a deprecated attribute name '%s'. "
|
||||
"Please use the preferred attribute name '%s' instead.",
|
||||
type(object).__name__,
|
||||
name,
|
||||
names[0],
|
||||
)
|
||||
return getattr(object, name)
|
||||
return default
|
||||
|
||||
|
||||
@ -750,27 +750,17 @@ class VllmConfig:
|
||||
# TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
|
||||
self._set_compile_ranges()
|
||||
|
||||
if self.model_config and self.model_config.is_encoder_decoder:
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
self.scheduler_config.max_num_encoder_input_tokens = (
|
||||
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
|
||||
if (
|
||||
self.model_config
|
||||
and self.model_config.architecture == "WhisperForConditionalGeneration"
|
||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
|
||||
):
|
||||
logger.warning(
|
||||
"Whisper is known to have issues with "
|
||||
"forked workers. If startup is hanging, "
|
||||
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
||||
"to 'spawn'."
|
||||
)
|
||||
logger.debug(
|
||||
"Encoder-decoder model detected: setting "
|
||||
"`max_num_encoder_input_tokens` to encoder length (%s)",
|
||||
self.scheduler_config.max_num_encoder_input_tokens,
|
||||
)
|
||||
if (
|
||||
self.model_config.architecture == "WhisperForConditionalGeneration"
|
||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
|
||||
):
|
||||
logger.warning(
|
||||
"Whisper is known to have issues with "
|
||||
"forked workers. If startup is hanging, "
|
||||
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
||||
"to 'spawn'."
|
||||
)
|
||||
|
||||
if (
|
||||
self.kv_events_config is not None
|
||||
|
||||
@ -27,7 +27,14 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
|
||||
LMCacheAsyncLookupServer,
|
||||
)
|
||||
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
|
||||
from lmcache.v1.plugin.runtime_plugin_launcher import RuntimePluginLauncher
|
||||
|
||||
try:
|
||||
from lmcache.v1.plugin.runtime_plugin_launcher import RuntimePluginLauncher
|
||||
except ImportError:
|
||||
# Backwards compatibility for lmcache <= 0.3.10-post1
|
||||
from lmcache.v1.plugin.plugin_launcher import (
|
||||
PluginLauncher as RuntimePluginLauncher,
|
||||
)
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
@ -1649,7 +1649,13 @@ class EngineArgs:
|
||||
"attention_backend and attention_config.backend "
|
||||
"are mutually exclusive"
|
||||
)
|
||||
attention_config.backend = self.attention_backend
|
||||
# Convert string to enum if needed (CLI parsing returns a string)
|
||||
if isinstance(self.attention_backend, str):
|
||||
attention_config.backend = AttentionBackendEnum[
|
||||
self.attention_backend.upper()
|
||||
]
|
||||
else:
|
||||
attention_config.backend = self.attention_backend
|
||||
|
||||
load_config = self.create_load_config()
|
||||
|
||||
|
||||
@ -324,12 +324,12 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
id=origin_chunk.id,
|
||||
content=[],
|
||||
model=origin_chunk.model,
|
||||
),
|
||||
usage=AnthropicUsage(
|
||||
input_tokens=origin_chunk.usage.prompt_tokens
|
||||
if origin_chunk.usage
|
||||
else 0,
|
||||
output_tokens=0,
|
||||
usage=AnthropicUsage(
|
||||
input_tokens=origin_chunk.usage.prompt_tokens
|
||||
if origin_chunk.usage
|
||||
else 0,
|
||||
output_tokens=0,
|
||||
),
|
||||
),
|
||||
)
|
||||
first_item = False
|
||||
|
||||
@ -9,7 +9,7 @@ from collections import Counter, defaultdict, deque
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from functools import cached_property, lru_cache, partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
|
||||
|
||||
import jinja2
|
||||
import jinja2.ext
|
||||
@ -49,11 +49,20 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import SupportsMultiModal
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
|
||||
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
|
||||
from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -620,6 +629,44 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _extract_embeds(tensors: list[torch.Tensor]):
|
||||
if len(tensors) == 0:
|
||||
return tensors
|
||||
|
||||
if len(tensors) == 1:
|
||||
tensors[0]._is_single_item = True # type: ignore
|
||||
return tensors[0] # To keep backwards compatibility for single item input
|
||||
|
||||
first_shape = tensors[0].shape
|
||||
if all(t.shape == first_shape for t in tensors):
|
||||
return torch.stack(tensors)
|
||||
|
||||
return tensors
|
||||
|
||||
|
||||
def _get_embeds_data(items_by_modality: dict[str, list[Any]], modality: str):
|
||||
embeds_key = f"{modality}_embeds"
|
||||
embeds = items_by_modality[embeds_key]
|
||||
|
||||
if len(embeds) == 0:
|
||||
return embeds
|
||||
if is_list_of(embeds, torch.Tensor):
|
||||
return _extract_embeds(embeds)
|
||||
if is_list_of(embeds, dict):
|
||||
if not embeds:
|
||||
return {}
|
||||
|
||||
first_keys = set(embeds[0].keys())
|
||||
if any(set(item.keys()) != first_keys for item in embeds[1:]):
|
||||
raise ValueError(
|
||||
"All dictionaries in the list of embeddings must have the same keys."
|
||||
)
|
||||
|
||||
return {k: _extract_embeds([item[k] for item in embeds]) for k in first_keys}
|
||||
|
||||
return embeds
|
||||
|
||||
|
||||
class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
"""
|
||||
Tracks multi-modal items in a given request and ensures that the number
|
||||
@ -688,11 +735,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
def all_mm_uuids(self) -> MultiModalUUIDDict | None:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
mm_uuids = {}
|
||||
|
||||
uuids_by_modality = dict(self._uuids_by_modality)
|
||||
if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
if "audio" in uuids_by_modality and "audio_embeds" in uuids_by_modality:
|
||||
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
|
||||
|
||||
mm_uuids = {}
|
||||
if "image_embeds" in uuids_by_modality:
|
||||
mm_uuids["image"] = uuids_by_modality["image_embeds"]
|
||||
if "image" in uuids_by_modality:
|
||||
@ -703,6 +753,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios
|
||||
if "video" in uuids_by_modality:
|
||||
mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos
|
||||
|
||||
return mm_uuids
|
||||
|
||||
@abstractmethod
|
||||
@ -714,29 +765,25 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
|
||||
def all_mm_data(self) -> MultiModalDataDict | None:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
mm_inputs = {}
|
||||
|
||||
items_by_modality = dict(self._items_by_modality)
|
||||
if "image" in items_by_modality and "image_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
|
||||
|
||||
mm_inputs = {}
|
||||
if "image_embeds" in items_by_modality:
|
||||
image_embeds_lst = items_by_modality["image_embeds"]
|
||||
mm_inputs["image"] = (
|
||||
image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
|
||||
)
|
||||
mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
|
||||
if "image" in items_by_modality:
|
||||
mm_inputs["image"] = items_by_modality["image"] # A list of images
|
||||
if "audio_embeds" in items_by_modality:
|
||||
audio_embeds_lst = items_by_modality["audio_embeds"]
|
||||
mm_inputs["audio"] = (
|
||||
audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
|
||||
)
|
||||
mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
|
||||
if "audio" in items_by_modality:
|
||||
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
|
||||
if "video" in items_by_modality:
|
||||
mm_inputs["video"] = items_by_modality["video"] # A list of videos
|
||||
|
||||
return mm_inputs
|
||||
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
@ -747,38 +794,32 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
|
||||
async def all_mm_data(self) -> MultiModalDataDict | None:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
mm_inputs = {}
|
||||
items_by_modality = {}
|
||||
for modality, items in self._items_by_modality.items():
|
||||
coros = []
|
||||
for item in items:
|
||||
if item is not None:
|
||||
coros.append(item)
|
||||
else:
|
||||
coros.append(asyncio.sleep(0))
|
||||
items_by_modality[modality] = await asyncio.gather(*coros)
|
||||
|
||||
coros_by_modality = {
|
||||
modality: [item or asyncio.sleep(0) for item in items]
|
||||
for modality, items in self._items_by_modality.items()
|
||||
}
|
||||
items_by_modality: dict[str, list[object | None]] = {
|
||||
modality: await asyncio.gather(*coros)
|
||||
for modality, coros in coros_by_modality.items()
|
||||
}
|
||||
if "image" in items_by_modality and "image_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
|
||||
|
||||
mm_inputs = {}
|
||||
if "image_embeds" in items_by_modality:
|
||||
image_embeds_lst = items_by_modality["image_embeds"]
|
||||
mm_inputs["image"] = (
|
||||
image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
|
||||
)
|
||||
mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
|
||||
if "image" in items_by_modality:
|
||||
mm_inputs["image"] = items_by_modality["image"] # A list of images
|
||||
if "audio_embeds" in items_by_modality:
|
||||
audio_embeds_lst = items_by_modality["audio_embeds"]
|
||||
mm_inputs["audio"] = (
|
||||
audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
|
||||
)
|
||||
mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
|
||||
if "audio" in items_by_modality:
|
||||
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
|
||||
if "video" in items_by_modality:
|
||||
mm_inputs["video"] = items_by_modality["video"] # A list of videos
|
||||
|
||||
return mm_inputs
|
||||
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
@ -1792,7 +1833,7 @@ def apply_hf_chat_template(
|
||||
|
||||
|
||||
def apply_mistral_chat_template(
|
||||
tokenizer: MistralTokenizer,
|
||||
tokenizer: "MistralTokenizer",
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
|
||||
@ -72,7 +72,8 @@ from vllm.platforms import current_platform
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.collection_utils import as_iter, is_list_of
|
||||
from vllm.utils.counter import Counter
|
||||
|
||||
@ -232,7 +232,177 @@ def parse_response_input(
|
||||
return msg
|
||||
|
||||
|
||||
def parse_chat_inputs_to_harmony_messages(chat_msgs: list) -> list[Message]:
|
||||
"""
|
||||
Parse a list of messages from request.messages in the Chat Completion API to
|
||||
Harmony messages.
|
||||
"""
|
||||
msgs: list[Message] = []
|
||||
tool_id_names: dict[str, str] = {}
|
||||
|
||||
# Collect tool id to name mappings for tool response recipient values
|
||||
for chat_msg in chat_msgs:
|
||||
for tool_call in chat_msg.get("tool_calls", []):
|
||||
tool_id_names[tool_call.get("id")] = tool_call.get("function", {}).get(
|
||||
"name"
|
||||
)
|
||||
|
||||
for chat_msg in chat_msgs:
|
||||
msgs.extend(parse_chat_input_to_harmony_message(chat_msg, tool_id_names))
|
||||
|
||||
msgs = auto_drop_analysis_messages(msgs)
|
||||
return msgs
|
||||
|
||||
|
||||
def auto_drop_analysis_messages(msgs: list[Message]) -> list[Message]:
|
||||
"""
|
||||
Harmony models expect the analysis messages (representing raw chain of thought) to
|
||||
be dropped after an assistant message to the final channel is produced from the
|
||||
reasoning of those messages.
|
||||
|
||||
The openai-harmony library does this if the very last assistant message is to the
|
||||
final channel, but it does not handle the case where we're in longer multi-turn
|
||||
conversations and the client gave us reasoning content from previous turns of
|
||||
the conversation with multiple assistant messages to the final channel in the
|
||||
conversation.
|
||||
|
||||
So, we find the index of the last assistant message to the final channel and drop
|
||||
all analysis messages that precede it, leaving only the analysis messages that
|
||||
are relevant to the current part of the conversation.
|
||||
"""
|
||||
last_assistant_final_index = -1
|
||||
for i in range(len(msgs) - 1, -1, -1):
|
||||
msg = msgs[i]
|
||||
if msg.author.role == "assistant" and msg.channel == "final":
|
||||
last_assistant_final_index = i
|
||||
break
|
||||
|
||||
cleaned_msgs: list[Message] = []
|
||||
for i, msg in enumerate(msgs):
|
||||
if i < last_assistant_final_index and msg.channel == "analysis":
|
||||
continue
|
||||
cleaned_msgs.append(msg)
|
||||
|
||||
return cleaned_msgs
|
||||
|
||||
|
||||
def flatten_chat_text_content(content: str | list | None) -> str | None:
|
||||
"""
|
||||
Extract the text parts from a chat message content field and flatten them
|
||||
into a single string.
|
||||
"""
|
||||
if isinstance(content, list):
|
||||
return "".join(
|
||||
item.get("text", "")
|
||||
for item in content
|
||||
if isinstance(item, dict) and item.get("type") == "text"
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
def parse_chat_input_to_harmony_message(
|
||||
chat_msg, tool_id_names: dict[str, str] | None = None
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Parse a message from request.messages in the Chat Completion API to
|
||||
Harmony messages.
|
||||
"""
|
||||
tool_id_names = tool_id_names or {}
|
||||
|
||||
if not isinstance(chat_msg, dict):
|
||||
# Handle Pydantic models
|
||||
chat_msg = chat_msg.model_dump(exclude_none=True)
|
||||
|
||||
role = chat_msg.get("role")
|
||||
msgs: list[Message] = []
|
||||
|
||||
# Assistant message with tool calls
|
||||
tool_calls = chat_msg.get("tool_calls", [])
|
||||
|
||||
if role == "assistant" and tool_calls:
|
||||
content = flatten_chat_text_content(chat_msg.get("content"))
|
||||
if content:
|
||||
commentary_msg = Message.from_role_and_content(Role.ASSISTANT, content)
|
||||
commentary_msg = commentary_msg.with_channel("commentary")
|
||||
msgs.append(commentary_msg)
|
||||
|
||||
reasoning_content = chat_msg.get("reasoning") or chat_msg.get(
|
||||
"reasoning_content"
|
||||
)
|
||||
if reasoning_content:
|
||||
analysis_msg = Message.from_role_and_content(
|
||||
Role.ASSISTANT, reasoning_content
|
||||
)
|
||||
analysis_msg = analysis_msg.with_channel("analysis")
|
||||
msgs.append(analysis_msg)
|
||||
|
||||
for call in tool_calls:
|
||||
func = call.get("function", {})
|
||||
name = func.get("name", "")
|
||||
arguments = func.get("arguments", "") or ""
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
|
||||
msg = msg.with_channel("commentary")
|
||||
msg = msg.with_recipient(f"functions.{name}")
|
||||
# Officially, this should be `<|constrain|>json` but there is not clear
|
||||
# evidence that improves accuracy over `json` and some anecdotes to the
|
||||
# contrary. Further testing of the different content_types is needed.
|
||||
msg = msg.with_content_type("json")
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
# Tool role message (tool output)
|
||||
if role == "tool":
|
||||
tool_call_id = chat_msg.get("tool_call_id", "")
|
||||
name = tool_id_names.get(tool_call_id, "")
|
||||
content = chat_msg.get("content", "") or ""
|
||||
content = flatten_chat_text_content(content)
|
||||
|
||||
msg = (
|
||||
Message.from_author_and_content(
|
||||
Author.new(Role.TOOL, f"functions.{name}"), content
|
||||
)
|
||||
.with_channel("commentary")
|
||||
.with_recipient("assistant")
|
||||
)
|
||||
return [msg]
|
||||
|
||||
# Non-tool reasoning content
|
||||
reasoning_content = chat_msg.get("reasoning") or chat_msg.get("reasoning_content")
|
||||
if role == "assistant" and reasoning_content:
|
||||
analysis_msg = Message.from_role_and_content(Role.ASSISTANT, reasoning_content)
|
||||
analysis_msg = analysis_msg.with_channel("analysis")
|
||||
msgs.append(analysis_msg)
|
||||
|
||||
# Default: user/assistant/system messages with content
|
||||
content = chat_msg.get("content") or ""
|
||||
if content is None:
|
||||
content = ""
|
||||
if isinstance(content, str):
|
||||
contents = [TextContent(text=content)]
|
||||
else:
|
||||
# TODO: Support refusal.
|
||||
contents = [TextContent(text=c.get("text", "")) for c in content]
|
||||
|
||||
# Only add assistant messages if they have content, as reasoning or tool calling
|
||||
# assistant messages were already added above.
|
||||
if role == "assistant" and contents and contents[0].text:
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
# Send non-tool assistant messages to the final channel
|
||||
msg = msg.with_channel("final")
|
||||
msgs.append(msg)
|
||||
# For user/system/developer messages, add them directly even if no content.
|
||||
elif role != "assistant":
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
msgs.append(msg)
|
||||
|
||||
return msgs
|
||||
|
||||
|
||||
def parse_input_to_harmony_message(chat_msg) -> list[Message]:
|
||||
"""
|
||||
Parse a message from request.previous_input_messages in the Responsees API to
|
||||
Harmony messages.
|
||||
"""
|
||||
if not isinstance(chat_msg, dict):
|
||||
# Handle Pydantic models
|
||||
chat_msg = chat_msg.model_dump(exclude_none=True)
|
||||
@ -258,14 +428,7 @@ def parse_input_to_harmony_message(chat_msg) -> list[Message]:
|
||||
if role == "tool":
|
||||
name = chat_msg.get("name", "")
|
||||
content = chat_msg.get("content", "") or ""
|
||||
if isinstance(content, list):
|
||||
# Handle array format for tool message content
|
||||
# by concatenating all text parts.
|
||||
content = "".join(
|
||||
item.get("text", "")
|
||||
for item in content
|
||||
if isinstance(item, dict) and item.get("type") == "text"
|
||||
)
|
||||
content = flatten_chat_text_content(content)
|
||||
|
||||
msg = Message.from_author_and_content(
|
||||
Author.new(Role.TOOL, f"functions.{name}"), content
|
||||
@ -623,20 +786,40 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
|
||||
def parse_chat_output(
|
||||
token_ids: Sequence[int],
|
||||
) -> tuple[str | None, str | None, bool]:
|
||||
"""
|
||||
Parse the output of a Harmony chat completion into reasoning and final content.
|
||||
Note that when the `openai` tool parser is used, serving_chat only uses this
|
||||
for the reasoning content and gets the final content from the tool call parser.
|
||||
|
||||
When the `openai` tool parser is not enabled, or when `GptOssReasoningParser` is
|
||||
in use,this needs to return the final content without any tool calls parsed.
|
||||
|
||||
Empty reasoning or final content is returned as None instead of an empty string.
|
||||
"""
|
||||
parser = parse_output_into_messages(token_ids)
|
||||
output_msgs = parser.messages
|
||||
is_tool_call = False # TODO: update this when tool call is supported
|
||||
if len(output_msgs) == 0:
|
||||
# The generation has stopped during reasoning.
|
||||
reasoning = parser.current_content
|
||||
final_content = None
|
||||
elif len(output_msgs) == 1:
|
||||
# The generation has stopped during final message.
|
||||
reasoning = output_msgs[0].content[0].text
|
||||
final_content = parser.current_content
|
||||
else:
|
||||
reasoning_msg = output_msgs[:-1]
|
||||
final_msg = output_msgs[-1]
|
||||
reasoning = "\n".join([msg.content[0].text for msg in reasoning_msg])
|
||||
final_content = final_msg.content[0].text
|
||||
|
||||
# Get completed messages from the parser
|
||||
reasoning_texts = [
|
||||
msg.content[0].text for msg in output_msgs if msg.channel == "analysis"
|
||||
]
|
||||
final_texts = [
|
||||
msg.content[0].text for msg in output_msgs if msg.channel != "analysis"
|
||||
]
|
||||
|
||||
# Extract partial messages from the parser
|
||||
if parser.current_channel == "analysis" and parser.current_content:
|
||||
reasoning_texts.append(parser.current_content)
|
||||
elif parser.current_channel != "analysis" and parser.current_content:
|
||||
final_texts.append(parser.current_content)
|
||||
|
||||
# Flatten multiple messages into a single string
|
||||
reasoning: str | None = "\n".join(reasoning_texts)
|
||||
final_content: str | None = "\n".join(final_texts)
|
||||
|
||||
# Return None instead of empty string since existing callers check for None
|
||||
reasoning = reasoning or None
|
||||
final_content = final_content or None
|
||||
|
||||
return reasoning, final_content, is_tool_call
|
||||
|
||||
@ -27,8 +27,8 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
get_stop_tokens_for_assistant_actions,
|
||||
get_streamable_parser_for_assistant,
|
||||
get_system_message,
|
||||
parse_chat_inputs_to_harmony_messages,
|
||||
parse_chat_output,
|
||||
parse_input_to_harmony_message,
|
||||
render_for_completion,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
@ -822,6 +822,9 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
if delta_message is not None:
|
||||
harmony_tools_streamed[i] = True
|
||||
elif cur_channel == "commentary":
|
||||
# Tool call preambles meant to be shown to the user
|
||||
delta_message = DeltaMessage(content=delta_text)
|
||||
else:
|
||||
delta_message = None
|
||||
# handle streaming deltas for tools with named tool_choice
|
||||
@ -1770,6 +1773,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
):
|
||||
messages: list[OpenAIMessage] = []
|
||||
|
||||
# because of issues with pydantic we need to potentially
|
||||
# re-serialize the tool_calls field of the request
|
||||
# for more info: see comment in `maybe_serialize_tool_calls`
|
||||
maybe_serialize_tool_calls(request)
|
||||
|
||||
# Add system message.
|
||||
# NOTE: In Chat Completion API, browsing is enabled by default
|
||||
# if the model supports it. TODO: Support browsing.
|
||||
@ -1788,8 +1796,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
messages.append(dev_msg)
|
||||
|
||||
# Add user message.
|
||||
for chat_msg in request.messages:
|
||||
messages.extend(parse_input_to_harmony_message(chat_msg))
|
||||
messages.extend(parse_chat_inputs_to_harmony_messages(request.messages))
|
||||
|
||||
# Render prompt token ids.
|
||||
prompt_token_ids = render_for_completion(messages)
|
||||
|
||||
@ -117,7 +117,9 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import DeepseekV32Tokenizer, MistralTokenizer, TokenizerLike
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.deepseekv32 import DeepseekV32Tokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tracing import (
|
||||
contains_trace_headers,
|
||||
extract_trace_headers,
|
||||
|
||||
@ -22,7 +22,8 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -21,7 +21,8 @@ from vllm.entrypoints.openai.protocol import (
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user