Merge branch 'main' into model_arch_cfg

This commit is contained in:
Xingyu Liu 2025-12-12 15:30:09 -08:00 committed by GitHub
commit 97e1752922
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
226 changed files with 7491 additions and 1451 deletions

View File

@ -15,6 +15,21 @@ steps:
env:
DOCKER_BUILDKIT: "1"
- label: "Build arm64 wheel - CUDA 13.0"
depends_on: ~
id: build-wheel-arm64-cuda-13-0
agents:
queue: arm64_cpu_queue_postmerge
commands:
# #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here:
# https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
env:
DOCKER_BUILDKIT: "1"
# aarch64 build
- label: "Build arm64 CPU wheel"
depends_on: ~
@ -25,7 +40,7 @@ steps:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_BUILD_ACL=ON --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/scripts/upload-wheels.sh"
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
env:
DOCKER_BUILDKIT: "1"
@ -39,7 +54,7 @@ steps:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/scripts/upload-wheels.sh"
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_31"
env:
DOCKER_BUILDKIT: "1"
@ -52,7 +67,21 @@ steps:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/scripts/upload-wheels.sh"
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
env:
DOCKER_BUILDKIT: "1"
# x86 CPU wheel build
- label: "Build x86 CPU wheel"
depends_on: ~
id: build-wheel-x86-cpu
agents:
queue: cpu_queue_postmerge
commands:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --build-arg VLLM_CPU_AMXBF16=true --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
env:
DOCKER_BUILDKIT: "1"

View File

@ -372,6 +372,17 @@ if __name__ == "__main__":
print(f"Found {len(wheel_files)} wheel files for version {version}: {wheel_files}")
# keep only "official" files for a non-nightly version (specifed by cli args)
PY_VERSION_RE = re.compile(r"^\d+\.\d+\.\d+([a-zA-Z0-9.+-]*)?$")
if PY_VERSION_RE.match(version):
# upload-wheels.sh ensures no "dev" is in args.version
wheel_files = list(
filter(lambda x: version in x and "dev" not in x, wheel_files)
)
print(f"Non-nightly version detected, wheel files used: {wheel_files}")
else:
print("Nightly version detected, keeping all wheel files.")
# Generate index and metadata, assuming wheels and indices are stored as:
# s3://vllm-wheels/{version}/<wheel files>
# s3://vllm-wheels/<anything>/<index files>

View File

@ -34,9 +34,10 @@ if [[ ${#wheel_files[@]} -ne 1 ]]; then
fi
wheel="${wheel_files[0]}"
# current build image uses ubuntu 20.04, which corresponds to manylinux_2_31
# default build image uses ubuntu 20.04, which corresponds to manylinux_2_31
# we also accept params as manylinux tag
# refer to https://github.com/mayeut/pep600_compliance?tab=readme-ov-file#acceptable-distros-to-build-wheels
manylinux_version="manylinux_2_31"
manylinux_version="${1:-manylinux_2_31}"
# Rename 'linux' to the appropriate manylinux version in the wheel filename
if [[ "$wheel" != *"linux"* ]]; then
@ -96,8 +97,11 @@ if [[ "$BUILDKITE_BRANCH" == "main" && "$BUILDKITE_PULL_REQUEST" == "false" ]];
aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/nightly/"
fi
# copy to /<pure_version>/ only if it does not have "dev" in the version
# re-generate and copy to /<pure_version>/ only if it does not have "dev" in the version
if [[ "$version" != *"dev"* ]]; then
echo "Uploading indices to overwrite /$pure_version/"
echo "Re-generating indices for /$pure_version/"
rm -rf "$INDICES_OUTPUT_DIR/*"
mkdir -p "$INDICES_OUTPUT_DIR"
$PYTHON .buildkite/scripts/generate-nightly-index.py --version "$pure_version" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" --comment "version $pure_version" $alias_arg
aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/$pure_version/"
fi

View File

@ -326,10 +326,10 @@ steps:
commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
- label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45
- label: V1 Test e2e + engine # 65min
timeout_in_minutes: 90
mirror_hardwares: [amdexperimental]
agent_pool: mi325_1
agent_pool: mi325_4
# grade: Blocking
source_file_dependencies:
- vllm/
@ -435,7 +435,7 @@ steps:
- label: Examples Test # 30min
timeout_in_minutes: 45
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
working_dir: "/vllm-workspace/examples"
@ -455,7 +455,6 @@ steps:
# for multi-modal models
- python3 offline_inference/audio_language.py --seed 0
- python3 offline_inference/vision_language.py --seed 0
- python3 offline_inference/vision_language_pooling.py --seed 0
- python3 offline_inference/vision_language_multi_image.py --seed 0
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
# for pooling models

View File

@ -836,7 +836,7 @@ steps:
- tests/models/multimodal
no_gpu: true
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- "pip install git+https://github.com/TIGER-AI-Lab/Mantis.git || echo 'Mantis installation skipped (decord not available on CPU-only environment)'"
- pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
- label: Multi-Modal Processor Test
@ -1346,6 +1346,7 @@ steps:
- label: Prime-RL Integration Test # 15min
timeout_in_minutes: 30
optional: true
soft_fail: true
num_gpus: 2
working_dir: "/vllm-workspace"
source_file_dependencies:
@ -1379,4 +1380,4 @@ steps:
num_gpus: 2
working_dir: "/vllm-workspace"
commands:
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1

View File

@ -384,7 +384,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH
PYTHONPATH=$ENV{PYTHONPATH}
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
RESULT_VARIABLE marlin_generation_result
OUTPUT_VARIABLE marlin_generation_result
@ -822,7 +822,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$ENV{PYTHONPATH}
${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT}
RESULT_VARIABLE machete_generation_result
OUTPUT_VARIABLE machete_generation_output
@ -1004,7 +1004,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH
PYTHONPATH=$ENV{PYTHONPATH}
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
RESULT_VARIABLE moe_marlin_generation_result
OUTPUT_VARIABLE moe_marlin_generation_output

View File

@ -32,7 +32,6 @@ def benchmark_propose(args):
model_config = ModelConfig(
model="facebook/opt-125m",
task="generate",
max_model_len=args.num_token + args.num_spec_token,
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",

View File

@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark script comparing torch.cat vs direct copy for k_nope/k_pe concatenation
in MLA (Multi-head Latent Attention) prefill.
This validates that the optimization from commit 8d4142bd is beneficial across
various batch sizes, not just the originally tested batch size of 32768.
"""
import time
from collections.abc import Callable
import torch
# DeepSeek-V3 MLA dimensions
NUM_HEADS = 128
QK_NOPE_HEAD_DIM = 128
PE_DIM = 64
def cat_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
"""Original torch.cat approach with expand."""
return torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
def direct_copy_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
"""Optimized direct copy approach (avoids expand + cat overhead)."""
k = torch.empty(
(*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
dtype=k_nope.dtype,
device=k_nope.device,
)
k[..., : k_nope.shape[-1]] = k_nope
k[..., k_nope.shape[-1] :] = k_pe
return k
def benchmark_method(
method: Callable,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
num_warmup: int = 10,
num_iters: int = 100,
) -> float:
"""Benchmark a concatenation method and return mean latency in ms."""
# Warmup
for _ in range(num_warmup):
_ = method(k_nope, k_pe)
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_iters):
_ = method(k_nope, k_pe)
torch.cuda.synchronize()
end = time.perf_counter()
return (end - start) / num_iters * 1000 # Convert to ms
@torch.inference_mode()
def run_benchmark(dtype: torch.dtype, dtype_name: str):
"""Run benchmark for a specific dtype."""
torch.set_default_device("cuda")
# Batch sizes to test (powers of 2 from 32 to 65536)
batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]
print("=" * 80)
print("Benchmark: torch.cat vs direct copy for MLA k_nope/k_pe concatenation")
print("=" * 80)
print(
f"Tensor shapes: k_nope=[B, {NUM_HEADS}, {QK_NOPE_HEAD_DIM}], "
f"k_pe=[B, 1, {PE_DIM}]"
)
print(f"dtype: {dtype_name}")
print()
print(
f"{'Batch Size':>12} | {'cat (ms)':>10} | {'direct (ms)':>12} | "
f"{'Speedup':>8} | {'Reduction':>10}"
)
print("-" * 70)
results = []
for batch_size in batch_sizes:
# Create input tensors (generate in float32 then convert for FP8 compatibility)
k_nope = torch.randn(
batch_size, NUM_HEADS, QK_NOPE_HEAD_DIM, dtype=torch.float32, device="cuda"
).to(dtype)
k_pe = torch.randn(
batch_size, 1, PE_DIM, dtype=torch.float32, device="cuda"
).to(dtype)
# Benchmark both methods
cat_time = benchmark_method(cat_method, k_nope, k_pe)
direct_time = benchmark_method(direct_copy_method, k_nope, k_pe)
speedup = cat_time / direct_time
reduction = (1 - direct_time / cat_time) * 100
results.append((batch_size, cat_time, direct_time, speedup, reduction))
print(
f"{batch_size:>12} | {cat_time:>10.3f} | {direct_time:>12.3f} | "
f"{speedup:>7.2f}x | {reduction:>9.1f}%"
)
print("=" * 80)
# Summary statistics
speedups = [r[3] for r in results]
print("\nSpeedup summary:")
print(f" Min: {min(speedups):.2f}x")
print(f" Max: {max(speedups):.2f}x")
print(f" Mean: {sum(speedups) / len(speedups):.2f}x")
# Find crossover point
crossover_batch = None
for batch_size, _, _, speedup, _ in results:
if speedup >= 1.0:
crossover_batch = batch_size
break
print("\nConclusion:")
if crossover_batch:
print(f" - Direct copy becomes beneficial at batch size >= {crossover_batch}")
# Filter for large batches (>= 512 which is typical for prefill)
large_batch_speedups = [r[3] for r in results if r[0] >= 512]
if large_batch_speedups:
avg_large = sum(large_batch_speedups) / len(large_batch_speedups)
print(f" - For batch sizes >= 512: avg speedup = {avg_large:.2f}x")
print(" - MLA prefill typically uses large batches, so optimization is effective")
return results
@torch.inference_mode()
def main():
# Test bfloat16
print("\n")
run_benchmark(torch.bfloat16, "bfloat16")
# Test float8_e4m3fn
print("\n")
run_benchmark(torch.float8_e4m3fn, "float8_e4m3fn")
if __name__ == "__main__":
main()

View File

@ -99,7 +99,6 @@ def benchmark_mrope(
# the parameters to compute the q k v size based on tp_size
mrope_helper_class = get_rope(
head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position,
is_neox_style=is_neox_style,
rope_parameters=rope_parameters,

View File

@ -32,8 +32,8 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device):
def benchmark(batch_size, seq_len, num_heads, provider):
dtype = torch.bfloat16
max_position = 8192
base = 10000
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
rope_parameters = {"partial_rotary_factor": rotary_dim / head_size}
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
rope = rope.to(dtype=dtype, device=device)
cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)

View File

@ -140,16 +140,21 @@ function(vllm_prepare_torch_gomp_shim TORCH_GOMP_SHIM_DIR)
run_python(_VLLM_TORCH_GOMP_PATH
"
import os, glob
try:
import torch
torch_pkg = os.path.dirname(torch.__file__)
site_root = os.path.dirname(torch_pkg)
torch_libs = os.path.join(site_root, 'torch.libs')
print(glob.glob(os.path.join(torch_libs, 'libgomp-*.so*'))[0])
except:
print('')
import torch
torch_pkg = os.path.dirname(torch.__file__)
site_root = os.path.dirname(torch_pkg)
# Search both torch.libs and torch/lib
roots = [os.path.join(site_root, 'torch.libs'), os.path.join(torch_pkg, 'lib')]
candidates = []
for root in roots:
if not os.path.isdir(root):
continue
candidates.extend(glob.glob(os.path.join(root, 'libgomp*.so*')))
print(candidates[0] if candidates else '')
"
"failed to probe torch.libs for libgomp")
"failed to probe for libgomp")
if(_VLLM_TORCH_GOMP_PATH STREQUAL "" OR NOT EXISTS "${_VLLM_TORCH_GOMP_PATH}")
return()

View File

@ -1,6 +1,7 @@
#pragma once
#include <torch/all.h>
#include <c10/util/Optional.h>
#include <map>
#include <vector>
@ -58,6 +59,15 @@ void cp_gather_cache(
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
// Gather and upconvert FP8 KV cache to BF16 workspace
void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size);
// Indexer K quantization and cache function
void indexer_k_quant_and_cache(
torch::Tensor& k, // [num_tokens, head_dim]
@ -72,4 +82,4 @@ void cp_gather_indexer_k_quant_cache(
torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::Tensor& block_table, // [batch_size, num_blocks]
const torch::Tensor& cu_seq_lens); // [batch_size + 1]
const torch::Tensor& cu_seq_lens); // [batch_size + 1]

View File

@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/Optional.h>
#include "cuda_utils.h"
#include "cuda_compat.h"
@ -514,7 +515,8 @@ __global__ void indexer_k_quant_and_cache_kernel(
const int quant_block_size, // quantization block size
const int cache_block_size, // cache block size
const int cache_stride, // stride for each token in kv_cache
const bool use_ue8m0 // use ue8m0 scale format
const bool use_ue8m0 // use ue8m0 scale format
) {
constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x;
@ -1061,6 +1063,82 @@ void gather_and_maybe_dequant_cache(
}
namespace vllm {
// Gather and upconvert FP8 KV cache tokens to BF16 workspace
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
__nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ seq_lens, // [BATCH]
const int32_t* __restrict__ workspace_starts, // [BATCH]
const int32_t block_size, const int32_t head_dim,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride) {
const int64_t bid = blockIdx.x; // Batch ID
const int32_t num_splits = gridDim.y;
const int32_t split = blockIdx.y;
const int32_t seq_start = workspace_starts[bid];
const int32_t seq_len = seq_lens[bid];
const int32_t tot_slots = seq_len;
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);
const int32_t split_start = split * split_slots;
const int32_t split_end = min((split + 1) * split_slots, tot_slots);
const bool is_active_split = (split_start < tot_slots);
if (!is_active_split) return;
// Adjust the pointer for the block_table for this batch
const int32_t batch_offset = bid * block_table_stride;
int32_t offset = split_start;
int32_t offset_div = offset / block_size;
offset = offset % block_size;
const int32_t* batch_block_table = block_table + batch_offset;
// Adjust dst pointer based on the cumulative sequence lengths
dst += seq_start * dst_entry_stride;
const int tid = threadIdx.x;
// Process each token in this split
for (int pid = split_start; pid < split_end; ++pid) {
auto block_id = batch_block_table[offset_div];
const uint8_t* token_ptr =
src_cache + block_id * cache_block_stride + offset * cache_entry_stride;
__nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride;
// FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16)
const uint8_t* no_pe_ptr = token_ptr;
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
const __nv_bfloat16* rope_ptr =
reinterpret_cast<const __nv_bfloat16*>(token_ptr + 512 + 16);
// Parallelize fp8 dequant (512 elements) and rope copy (64 elements)
if (tid < 512) {
// FP8 dequantization
const int tile = tid >> 7; // each tile is 128 elements
const float scale = scales_ptr[tile];
const uint8_t val = no_pe_ptr[tid];
dst_ptr[tid] =
fp8::scaled_convert<__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale);
} else if (tid < 576) {
// Rope copy (64 bf16 elements)
const int rope_idx = tid - 512;
dst_ptr[512 + rope_idx] = rope_ptr[rope_idx];
}
// Move to next token
offset += 1;
if (offset == block_size) {
offset_div += 1;
offset = 0;
}
}
}
template <typename scalar_t>
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
// block_size.
@ -1202,6 +1280,57 @@ void cp_gather_cache(
}
}
void cp_gather_and_upconvert_fp8_kv_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
torch::Tensor const& dst, // [TOT_TOKENS, 576]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& seq_lens, // [BATCH]
torch::Tensor const& workspace_starts, // [BATCH]
int64_t batch_size) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t block_size = src_cache.size(1);
int32_t head_dim = dst.size(1);
TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32");
TORCH_CHECK(workspace_starts.dtype() == torch::kInt32,
"workspace_starts must be int32");
TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == seq_lens.device(),
"src_cache and seq_lens must be on the same device");
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
"src_cache and workspace_starts must be on the same device");
TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8");
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0);
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);
// Decide on the number of splits based on the batch size
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
dim3 grid(batch_size, num_splits);
dim3 block(576);
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
src_cache.data_ptr<uint8_t>(),
reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
block_table_stride, cache_block_stride, cache_entry_stride,
dst_entry_stride);
}
// Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \

View File

@ -481,8 +481,6 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
largest = value;
}
}
__syncwarp(); // Ensure all threads have valid data before reduction
// Get the top2 warpwise
T max1 = cg::reduce(tile, largest, cg::greater<T>());
@ -589,7 +587,6 @@ __global__ void group_idx_and_topk_idx_kernel(
int pre_count_equal_to_top_value = 0;
// Use loop to find the largset top_group
while (count_equal_to_top_value < target_num_min) {
__syncwarp(); // Ensure all threads have valid data before reduction
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
if (value == topk_group_value) {
value = neg_inf<T>();
@ -644,10 +641,8 @@ __global__ void group_idx_and_topk_idx_kernel(
}
}
queue.done();
__syncwarp();
// Get the topk_idx
queue.dumpIdx(s_topk_idx);
__syncwarp();
}
// Load the valid score value

View File

@ -860,4 +860,4 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
}
}

View File

@ -22,6 +22,62 @@ __device__ __forceinline__ float GroupReduceMax(float val) {
return val;
}
template <typename T, bool SCALE_UE8M0>
__device__ __forceinline__ float ComputeGroupScale(
const T* __restrict__ group_input, T* __restrict__ smem_group,
const int group_size, const int lane_id, const int threads_per_group,
const float eps, const float max_8bit) {
float local_absmax = eps;
constexpr int vec_size = 16 / sizeof(T);
// copy global -> shared & compute absmax
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
float abs_v = fabsf(static_cast<float>(src));
local_absmax = fmaxf(local_absmax, abs_v);
dst = src;
};
vllm::vectorize_with_alignment<vec_size>(
group_input, // in
smem_group, // out (shared)
group_size, // elements per group
lane_id, // thread id
threads_per_group, // stride in group
scalar_op_cache); // scalar handler
local_absmax = GroupReduceMax(local_absmax);
float y_s = local_absmax / max_8bit;
if constexpr (SCALE_UE8M0) {
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
}
return y_s;
}
template <typename T, typename DST_DTYPE>
__device__ __forceinline__ void QuantizeGroup(
const T* __restrict__ smem_group, DST_DTYPE* __restrict__ group_output,
const int group_size, const int lane_id, const int threads_per_group,
const float y_s, const float min_8bit, const float max_8bit) {
constexpr int vec_size = 16 / sizeof(T);
// quantize shared -> global 8-bit
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
dst = DST_DTYPE(q);
};
vllm::vectorize_with_alignment<vec_size>(
smem_group, // in (shared)
group_output, // out (global quant tensor)
group_size, // elements
lane_id, // tid
threads_per_group, // stride
scalar_op_quant); // scalar handler
}
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
bool SCALE_UE8M0 = false, typename scale_packed_t = float>
__global__ void per_token_group_quant_8bit_kernel(
@ -38,8 +94,6 @@ __global__ void per_token_group_quant_8bit_kernel(
const int64_t global_group_id = block_group_id + local_group_id;
const int64_t block_group_offset = global_group_id * group_size;
float local_absmax = eps;
using scale_element_t = float;
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
@ -68,30 +122,9 @@ __global__ void per_token_group_quant_8bit_kernel(
T* smem = reinterpret_cast<T*>(smem_raw);
T* smem_group = smem + local_group_id * group_size;
constexpr int vec_size = 16 / sizeof(T);
using vec_t = vllm::vec_n_t<T, vec_size>;
// copy global -> shared & compute absmax
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
float abs_v = fabsf(static_cast<float>(src));
local_absmax = fmaxf(local_absmax, abs_v);
dst = src;
};
vllm::vectorize_with_alignment<vec_size>(
group_input, // in
smem_group, // out (shared)
group_size, // elements per group
lane_id, // thread id
threads_per_group, // stride in group
scalar_op_cache); // scalar handler
local_absmax = GroupReduceMax(local_absmax);
float y_s = local_absmax / max_8bit;
if constexpr (SCALE_UE8M0) {
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
}
const float y_s = ComputeGroupScale<T, SCALE_UE8M0>(
group_input, smem_group, group_size, lane_id, threads_per_group, eps,
max_8bit);
scale_element_t y_s_quant = y_s;
@ -101,19 +134,24 @@ __global__ void per_token_group_quant_8bit_kernel(
__syncthreads();
// quantize shared -> global 8-bit
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
dst = DST_DTYPE(q);
};
QuantizeGroup<T, DST_DTYPE>(smem_group, group_output, group_size, lane_id,
threads_per_group, y_s, min_8bit, max_8bit);
}
vllm::vectorize_with_alignment<vec_size>(
smem_group, // in (shared)
group_output, // out (global quant tensor)
group_size, // elements
lane_id, // tid
threads_per_group, // stride
scalar_op_quant); // scalar handler
inline int GetGroupsPerBlock(int64_t num_groups) {
if (num_groups % 16 == 0) {
return 16;
}
if (num_groups % 8 == 0) {
return 8;
}
if (num_groups % 4 == 0) {
return 4;
}
if (num_groups % 2 == 0) {
return 2;
}
return 1;
}
void per_token_group_quant_8bit(const torch::Tensor& input,
@ -133,17 +171,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
constexpr int THREADS_PER_GROUP = 16;
int groups_per_block = 1;
if (num_groups % 16 == 0) {
groups_per_block = 16;
} else if (num_groups % 8 == 0) {
groups_per_block = 8;
} else if (num_groups % 4 == 0) {
groups_per_block = 4;
} else if (num_groups % 2 == 0) {
groups_per_block = 2;
}
const int groups_per_block = GetGroupsPerBlock(num_groups);
auto dst_type = output_q.scalar_type();
const int num_blocks = num_groups / groups_per_block;
@ -225,8 +253,6 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
const int64_t block_group_offset = global_group_id * group_size;
float local_absmax = eps;
const T* group_input = input + block_group_offset;
DST_DTYPE* group_output =
static_cast<DST_DTYPE*>(output_q) + block_group_offset;
@ -235,29 +261,9 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
extern __shared__ __align__(16) char smem_raw[];
T* smem = reinterpret_cast<T*>(smem_raw);
T* smem_group = smem + local_group_id * group_size;
constexpr int vec_size = 16 / sizeof(T);
using vec_t = vllm::vec_n_t<T, vec_size>;
// copy global -> shared & compute absmax
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
float abs_v = fabsf(static_cast<float>(src));
local_absmax = fmaxf(local_absmax, abs_v);
dst = src;
};
vllm::vectorize_with_alignment<vec_size>(
group_input, // in
smem_group, // out (shared)
group_size, // elements per group
lane_id, // thread id
threads_per_group, // stride in group
scalar_op_cache); // scalar handler
local_absmax = GroupReduceMax(local_absmax);
float y_s = local_absmax / max_8bit;
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
const float y_s =
ComputeGroupScale<T, true>(group_input, smem_group, group_size, lane_id,
threads_per_group, eps, max_8bit);
// pack 4 scales into a uint32
if (lane_id == 0) {
@ -284,19 +290,8 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
__syncthreads();
// quantize shared -> global 8-bit
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
dst = DST_DTYPE(q);
};
vllm::vectorize_with_alignment<vec_size>(
smem_group, // in (shared)
group_output, // out (global quant tensor)
group_size, // elements
lane_id, // tid
threads_per_group, // stride
scalar_op_quant); // scalar handler
QuantizeGroup<T, DST_DTYPE>(smem_group, group_output, group_size, lane_id,
threads_per_group, y_s, min_8bit, max_8bit);
}
void per_token_group_quant_8bit_packed(const torch::Tensor& input,
@ -337,17 +332,7 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
constexpr int THREADS_PER_GROUP = 16;
int groups_per_block = 1;
if (num_groups % 16 == 0) {
groups_per_block = 16;
} else if (num_groups % 8 == 0) {
groups_per_block = 8;
} else if (num_groups % 4 == 0) {
groups_per_block = 4;
} else if (num_groups % 2 == 0) {
groups_per_block = 2;
}
const int groups_per_block = GetGroupsPerBlock(num_groups);
auto dst_type = output_q.scalar_type();
const int num_blocks = num_groups / groups_per_block;

View File

@ -754,6 +754,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
cache_ops.def(
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
"batch_size) -> ()");
cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
&cp_gather_and_upconvert_fp8_kv_cache);
cache_ops.def(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "

View File

@ -26,3 +26,4 @@ The backends below live **outside** the main `vllm` repository and follow the
| Rebellions ATOM / REBEL NPU | `vllm-rbln` | <https://github.com/rebellions-sw/vllm-rbln> |
| IBM Spyre AIU | `vllm-spyre` | <https://github.com/vllm-project/vllm-spyre> |
| Cambricon MLU | `vllm-mlu` | <https://github.com/Cambricon/vllm-mlu> |
| Baidu Kunlun XPU | N/A, install from source | <https://github.com/baidu/vLLM-Kunlun> |

View File

@ -29,8 +29,27 @@ uv pip install --pre vllm==<version>+cpu --extra-index-url https://wheels.vllm.a
The `uv` approach works for vLLM `v0.6.6` and later. A unique feature of `uv` is that packages in `--extra-index-url` have [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). If the latest public release is `v0.6.6.post1`, `uv`'s behavior allows installing a commit before `v0.6.6.post1` by specifying the `--extra-index-url`. In contrast, `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version.
!!! note
Nightly wheels are currently unsupported for this architecture. (e.g. to bisect the behavior change, performance regression).
**Install the latest code**
LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides working pre-built Arm CPU wheels for every commit since `v0.11.2` on <https://wheels.vllm.ai/nightly>. For native CPU wheels, this index should be used:
* `https://wheels.vllm.ai/nightly/cpu/vllm`
To install from nightly index, copy the link address of the `*.whl` under this index to run, for example:
```bash
uv pip install -U https://wheels.vllm.ai/c756fb678184b867ed94e5613a529198f1aee423/vllm-0.13.0rc2.dev11%2Bgc756fb678.cpu-cp38-abi3-manylinux_2_31_aarch64.whl # current nightly build (the filename will change!)
```
**Install specific revisions**
If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), specify the full commit hash in the index:
https://wheels.vllm.ai/${VLLM_COMMIT}/cpu/vllm .
Then, copy the link address of the `*.whl` under this index to run:
```bash
uv pip install -U <wheel-url>
```
# --8<-- [end:pre-built-wheels]
# --8<-- [start:build-wheel-from-source]
@ -81,7 +100,23 @@ Testing has been conducted on AWS Graviton3 instances for compatibility.
# --8<-- [end:build-wheel-from-source]
# --8<-- [start:pre-built-images]
Currently, there are no pre-built Arm CPU images.
See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image.
Stable vLLM Docker images are being pre-built for Arm from version 0.12.0. Available image tags are here: [https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo).
Please replace `<version>` in the command below with a specific version string (e.g., `0.12.0`).
```bash
docker pull public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:v<version>
```
You can also access the latest code with Docker images. These are not intended for production use and are meant for CI and testing only. They will expire after several days.
The latest code can contain bugs and may not be stable. Please use it with caution.
```bash
export VLLM_COMMIT=6299628d326f429eba78736acb44e76749b281f5 # use full commit hash from the main branch
docker pull public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:${VLLM_COMMIT}-arm64-cpu
```
# --8<-- [end:pre-built-images]
# --8<-- [start:build-image-from-source]

View File

@ -316,10 +316,13 @@ We have split the `encode` task into two more specific token-wise tasks: `token_
### Remove softmax from PoolingParams
We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
We are going to remove `softmax` and `activation` from `PoolingParams` in v0.15. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
### as_reward_model
!!! warning
We are going to remove `--convert reward` in v0.15, use `--convert embed` instead.
Pooling models now default support all pooling, you can use it without any settings.
- Extracting hidden states prefers using `token_embed` task.

View File

@ -568,7 +568,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
```
!!! note
Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/qwen3_reranker.py](../../examples/pooling/score/qwen3_reranker.py).
Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/offline_reranker.py](../../examples/pooling/score/offline_reranker.py).
```bash
vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'

View File

@ -851,7 +851,7 @@ endpoints are compatible with both [Jina AI's re-rank API interface](https://jin
[Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with
popular open-source tools.
Code example: [examples/pooling/score/jinaai_rerank_client.py](../../examples/pooling/score/jinaai_rerank_client.py)
Code example: [examples/pooling/score/openai_reranker.py](../../examples/pooling/score/openai_reranker.py)
#### Example Request

View File

@ -4,6 +4,9 @@
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
@ -20,6 +23,11 @@ def parse_args():
def main(args: Namespace):
if current_platform.is_rocm():
args.attention_config = AttentionConfig(
backend=AttentionBackendEnum.FLEX_ATTENTION
)
# Sample prompts.
prompts = [
"Hello, my name is",

View File

@ -4,6 +4,9 @@
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
@ -20,6 +23,11 @@ def parse_args():
def main(args: Namespace):
if current_platform.is_rocm():
args.attention_config = AttentionConfig(
backend=AttentionBackendEnum.FLEX_ATTENTION
)
# Sample prompts.
text_1 = "What is the capital of France?"
texts_2 = [

View File

@ -33,6 +33,7 @@ import os
from time import sleep
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port
@ -222,6 +223,11 @@ if __name__ == "__main__":
from multiprocessing import Process
if current_platform.is_rocm():
from multiprocessing import set_start_method
set_start_method("spawn", force=True)
procs = []
for local_dp_rank, global_dp_rank in enumerate(
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)

View File

@ -21,7 +21,7 @@
# --worker \
# /abs/path/to/huggingface/cache \
# -e VLLM_HOST_IP=<worker_node_ip>
#
#
# Each worker requires a unique VLLM_HOST_IP value.
# Keep each terminal session open. Closing a session stops the associated Ray
# node and thereby shuts down the entire cluster.
@ -59,6 +59,34 @@ if [ "${NODE_TYPE}" != "--head" ] && [ "${NODE_TYPE}" != "--worker" ]; then
exit 1
fi
# Extract VLLM_HOST_IP from ADDITIONAL_ARGS (e.g. "-e VLLM_HOST_IP=...").
VLLM_HOST_IP=""
for ((i = 0; i < ${#ADDITIONAL_ARGS[@]}; i++)); do
arg="${ADDITIONAL_ARGS[$i]}"
case "${arg}" in
-e)
next="${ADDITIONAL_ARGS[$((i + 1))]:-}"
if [[ "${next}" == VLLM_HOST_IP=* ]]; then
VLLM_HOST_IP="${next#VLLM_HOST_IP=}"
break
fi
;;
-eVLLM_HOST_IP=* | VLLM_HOST_IP=*)
VLLM_HOST_IP="${arg#*=}"
break
;;
esac
done
# For the head node, HEAD_NODE_ADDRESS and VLLM_HOST_IP should be consistent.
if [[ "${NODE_TYPE}" == "--head" && -n "${VLLM_HOST_IP}" ]]; then
if [[ "${VLLM_HOST_IP}" != "${HEAD_NODE_ADDRESS}" ]]; then
echo "Warning: VLLM_HOST_IP (${VLLM_HOST_IP}) differs from head_node_ip (${HEAD_NODE_ADDRESS})."
echo "Using VLLM_HOST_IP as the head node address."
HEAD_NODE_ADDRESS="${VLLM_HOST_IP}"
fi
fi
# Generate a unique container name with random suffix.
# Docker container names must be unique on each host.
# The random suffix allows multiple Ray containers to run simultaneously on the same machine,
@ -74,36 +102,17 @@ cleanup() {
trap cleanup EXIT
# Build the Ray start command based on the node role.
# The head node manages the cluster and accepts connections on port 6379,
# The head node manages the cluster and accepts connections on port 6379,
# while workers connect to the head's address.
RAY_START_CMD="ray start --block"
if [ "${NODE_TYPE}" == "--head" ]; then
RAY_START_CMD+=" --head --port=6379"
RAY_START_CMD+=" --head --node-ip-address=${HEAD_NODE_ADDRESS} --port=6379"
else
RAY_START_CMD+=" --address=${HEAD_NODE_ADDRESS}:6379"
fi
# Parse VLLM_HOST_IP from additional args if present.
# This is needed for multi-NIC configurations where Ray needs explicit IP bindings.
VLLM_HOST_IP=""
for arg in "${ADDITIONAL_ARGS[@]}"; do
if [[ $arg == "-e" ]]; then
continue
if [ -n "${VLLM_HOST_IP}" ]; then
RAY_START_CMD+=" --node-ip-address=${VLLM_HOST_IP}"
fi
if [[ $arg == VLLM_HOST_IP=* ]]; then
VLLM_HOST_IP="${arg#VLLM_HOST_IP=}"
break
fi
done
# Build Ray IP environment variables if VLLM_HOST_IP is set.
# These variables ensure Ray binds to the correct network interface on multi-NIC systems.
RAY_IP_VARS=()
if [ -n "${VLLM_HOST_IP}" ]; then
RAY_IP_VARS=(
-e "RAY_NODE_IP_ADDRESS=${VLLM_HOST_IP}"
-e "RAY_OVERRIDE_NODE_IP_ADDRESS=${VLLM_HOST_IP}"
)
fi
# Launch the container with the assembled parameters.
@ -118,6 +127,5 @@ docker run \
--shm-size 10.24g \
--gpus all \
-v "${PATH_TO_HF_HOME}:/root/.cache/huggingface" \
"${RAY_IP_VARS[@]}" \
"${ADDITIONAL_ARGS[@]}" \
"${DOCKER_IMAGE}" -c "${RAY_START_CMD}"

View File

@ -50,4 +50,5 @@ ijson # Required for mistral streaming tool parser
setproctitle # Used to set process names for better debugging and monitoring
openai-harmony >= 0.0.3 # Required for gpt-oss
anthropic == 0.71.0
model-hosting-container-standards >= 0.1.9, < 1.0.0
model-hosting-container-standards >= 0.1.9, < 1.0.0
mcp

View File

@ -1,2 +1,2 @@
lmcache >= 0.3.10.post1
lmcache
nixl >= 0.7.1 # Required for disaggregated prefill

View File

@ -138,6 +138,17 @@ elif current_platform.is_rocm():
CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
def has_cuda_graph_wrapper_metadata() -> bool:
from importlib import import_module
try:
module = import_module("torch._inductor.utils")
module.CUDAGraphWrapperMetadata # noqa B018
except AttributeError:
return False
return True
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, matches, custom_ops",
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
@ -145,7 +156,20 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
# quant_fp4 only has the custom impl
+ list(flat_product(MODELS_FP4, [""])),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.parametrize(
"inductor_graph_partition",
[
pytest.param(
True,
marks=pytest.mark.skipif(
not has_cuda_graph_wrapper_metadata(),
reason="This test requires"
"torch._inductor.utils.CUDAGraphWrapperMetadata to run",
),
),
False,
],
)
def test_attn_quant(
model_name: str,
model_kwargs: dict[str, Any],

View File

@ -128,14 +128,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
class TestRotaryEmbedding(torch.nn.Module):
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
def __init__(self, head_dim=64, max_position=2048, base=10000):
super().__init__()
self.head_dim = head_dim
self.rotary_dim = rotary_dim or head_dim
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position,
rope_parameters={"rope_type": "default", "rope_theta": base},
)
@ -170,7 +168,6 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
rope_parameters={"rope_type": "default", "rope_theta": base},
)

View File

@ -202,6 +202,27 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
cleanup_dist_env_and_memory()
@pytest.fixture
def workspace_init():
"""Initialize the workspace manager for tests that need it.
This fixture initializes the workspace manager with a CUDA device
if available, and resets it after the test completes. Tests that
create a full vLLM engine should NOT use this fixture as the engine
will initialize the workspace manager itself.
"""
from vllm.v1.worker.workspace import (
init_workspace_manager,
reset_workspace_manager,
)
if torch.cuda.is_available():
device = torch.device("cuda:0")
init_workspace_manager(device)
yield
reset_workspace_manager()
@pytest.fixture(autouse=True)
def dynamo_reset():
yield

View File

@ -0,0 +1,276 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Test that the interaction between EPLB and FusedMoE Layer is okay for DP w/ NVFP4
from dataclasses import dataclass
import pytest
import torch
from tests.kernels.moe.utils import make_test_quant_config
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
get_dp_group,
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config,
ModelOptNvFp4FusedMoE,
)
from .eplb_utils import distributed_run, set_env_vars_and_device
@dataclass
class TestConfig:
num_layers: int
num_experts: int
num_local_experts: int
num_topk: int
hidden_size: int
intermediate_size: int
num_tokens: int
def make_fused_moe_layer(
rank: int,
layer_idx: int,
test_config: TestConfig,
) -> FusedMoE:
quant_config = None
device = torch.device(f"cuda:{rank}")
quant_config = ModelOptNvFp4Config(
is_checkpoint_nvfp4_serialized=True,
kv_cache_quant_algo=None,
exclude_modules=[],
)
fml = FusedMoE(
num_experts=test_config.num_experts,
top_k=test_config.num_topk,
hidden_size=test_config.hidden_size,
intermediate_size=test_config.intermediate_size,
prefix=f"dummy_layer_{layer_idx}",
activation="silu",
is_act_and_mul=True,
params_dtype=torch.bfloat16,
quant_config=quant_config,
)
nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml)
nvfp4_fused_moe.create_weights(
fml,
test_config.num_local_experts,
test_config.hidden_size,
test_config.intermediate_size,
params_dtype=torch.uint8,
global_num_experts=test_config.num_experts,
)
fml = fml.to(device)
w1_q, w2_q, quant_config = make_test_quant_config(
test_config.num_local_experts,
test_config.intermediate_size,
test_config.hidden_size,
in_dtype=torch.bfloat16,
quant_dtype="nvfp4",
block_shape=None,
per_act_token_quant=False,
)
fml.w13_weight.data = w1_q
fml.w2_weight.data = w2_q
fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5
fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5
fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5
fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5
fml.w2_weight_scale.data = (
torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5
).to(fml.w2_weight_scale.data.dtype)
fml.w13_weight_scale.data = (
torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5
).to(fml.w13_weight_scale.data.dtype)
nvfp4_fused_moe.process_weights_after_loading(fml)
fml.maybe_init_modular_kernel()
return fml
def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
set_env_vars_and_device(env)
vllm_config = VllmConfig()
vllm_config.parallel_config.data_parallel_size = world_size
vllm_config.parallel_config.enable_expert_parallel = True
with set_current_vllm_config(vllm_config):
ensure_model_parallel_initialized(
tensor_model_parallel_size=1, pipeline_model_parallel_size=1
)
ep_group = get_dp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
fml_layers = [
make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device)
for layer_idx in range(test_config.num_layers)
]
rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]
hidden_states = []
router_logits = []
for layer_idx in range(test_config.num_layers):
hidden_states.append(
torch.randn(
(test_config.num_tokens, test_config.hidden_size),
dtype=torch.bfloat16,
device=device,
)
)
router_logits.append(
torch.randn(
(test_config.num_tokens, test_config.num_experts),
dtype=torch.bfloat16,
device=device,
)
)
out_before_shuffle = []
with set_forward_context(
{},
num_tokens=test_config.num_tokens,
num_tokens_across_dp=torch.tensor(
[test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
),
vllm_config=vllm_config,
):
for lidx, fml in enumerate(fml_layers):
out_before_shuffle.append(
fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
)
indices = torch.zeros(
test_config.num_layers, test_config.num_experts, dtype=torch.long
)
for lidx in range(test_config.num_layers):
indices[lidx] = torch.Tensor(range(test_config.num_experts))
shuffled_indices = torch.zeros_like(indices)
for lidx in range(test_config.num_layers):
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
rearrange_expert_weights_inplace(
indices,
shuffled_indices,
rank_expert_weights,
ep_group,
is_profile=False,
)
num_global_experts = test_config.num_experts
logical_to_physical_map_list = []
for lidx, fml in enumerate(fml_layers):
physical_to_logical_map = shuffled_indices[lidx].to(device)
logical_to_physical_map = torch.empty(
(num_global_experts,), dtype=torch.int32, device=device
)
logical_to_physical_map[physical_to_logical_map] = torch.arange(
0, num_global_experts, dtype=torch.int32, device=device
)
logical_to_physical_map_list.append(
logical_to_physical_map.reshape(num_global_experts, 1)
)
logical_to_physical_map = torch.stack(logical_to_physical_map_list)
for lidx, fml in enumerate(fml_layers):
logical_replica_count = torch.ones(
(test_config.num_layers, num_global_experts),
dtype=torch.int32,
device=device,
)
fml.enable_eplb = True
fml.set_eplb_state(
lidx,
torch.zeros(
(test_config.num_layers, num_global_experts),
dtype=torch.int32,
device=device,
),
logical_to_physical_map,
logical_replica_count,
)
out_after_shuffle = []
with set_forward_context(
{},
num_tokens=test_config.num_tokens,
num_tokens_across_dp=torch.tensor(
[test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
),
vllm_config=vllm_config,
):
for lidx, fml in enumerate(fml_layers):
out_after_shuffle.append(
fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
)
for lidx in range(test_config.num_layers):
torch.testing.assert_close(
out_before_shuffle[lidx], out_after_shuffle[lidx], atol=1e-1, rtol=1e-1
)
@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("num_layers", [8])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("intermediate_size", [256])
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("backend", ["latency", "throughput"])
def test_eplb_fml(
world_size: int,
num_layers: int,
num_experts: int,
hidden_size: int,
intermediate_size: int,
num_tokens: int,
backend: str,
monkeypatch,
):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend)
if torch.cuda.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test")
num_local_experts = num_experts // world_size
num_topk = 4
test_config = TestConfig(
num_layers=num_layers,
num_experts=num_experts,
num_local_experts=num_local_experts,
num_topk=num_topk,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_tokens=num_tokens,
)
distributed_run(
_test_eplb_fml,
world_size,
test_config,
)

View File

@ -350,21 +350,35 @@ def test_human_readable_model_len():
assert args.max_model_len == 1_000_000
args = parser.parse_args(["--max-model-len", "10k"])
assert args.max_model_len == 10_000
args = parser.parse_args(["--max-model-len", "2g"])
assert args.max_model_len == 2_000_000_000
args = parser.parse_args(["--max-model-len", "2t"])
assert args.max_model_len == 2_000_000_000_000
# Capital
args = parser.parse_args(["--max-model-len", "3K"])
assert args.max_model_len == 1024 * 3
assert args.max_model_len == 2**10 * 3
args = parser.parse_args(["--max-model-len", "10M"])
assert args.max_model_len == 2**20 * 10
args = parser.parse_args(["--max-model-len", "4G"])
assert args.max_model_len == 2**30 * 4
args = parser.parse_args(["--max-model-len", "4T"])
assert args.max_model_len == 2**40 * 4
# Decimal values
args = parser.parse_args(["--max-model-len", "10.2k"])
assert args.max_model_len == 10200
# ..truncated to the nearest int
args = parser.parse_args(["--max-model-len", "10.212345k"])
args = parser.parse_args(["--max-model-len", "10.2123451234567k"])
assert args.max_model_len == 10212
args = parser.parse_args(["--max-model-len", "10.2123451234567m"])
assert args.max_model_len == 10212345
args = parser.parse_args(["--max-model-len", "10.2123451234567g"])
assert args.max_model_len == 10212345123
args = parser.parse_args(["--max-model-len", "10.2123451234567t"])
assert args.max_model_len == 10212345123456
# Invalid (do not allow decimals with binary multipliers)
for invalid in ["1a", "pwd", "10.24", "1.23M"]:
for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
with pytest.raises(ArgumentError):
args = parser.parse_args(["--max-model-len", invalid])
parser.parse_args(["--max-model-len", invalid])

View File

@ -1,21 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from openai.types.responses import ResponseFunctionToolCall, ResponseReasoningItem
from openai.types.responses.response_output_item import McpCall
from openai_harmony import Author, Message, Role, TextContent
from tests.entrypoints.openai.utils import verify_harmony_messages
from vllm.entrypoints.openai.parser.harmony_utils import (
auto_drop_analysis_messages,
get_encoding,
has_custom_tools,
parse_chat_input_to_harmony_message,
parse_chat_output,
parse_input_to_harmony_message,
parse_output_message,
)
class TestParseInputToHarmonyMessage:
"""Tests for parse_input_to_harmony_message function."""
class TestCommonParseInputToHarmonyMessage:
"""
Tests for scenarios that are common to both Chat Completion
parse_chat_input_to_harmony_message and Responsees API
parse_input_to_harmony_message functions.
"""
def test_assistant_message_with_tool_calls(self):
@pytest.fixture(
params=[parse_chat_input_to_harmony_message, parse_input_to_harmony_message]
)
def parse_function(self, request):
return request.param
def test_assistant_message_with_tool_calls(self, parse_function):
"""Test parsing assistant message with tool calls."""
chat_msg = {
"role": "assistant",
@ -35,7 +51,7 @@ class TestParseInputToHarmonyMessage:
],
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_function(chat_msg)
assert len(messages) == 2
@ -53,7 +69,7 @@ class TestParseInputToHarmonyMessage:
assert messages[1].recipient == "functions.search_web"
assert messages[1].content_type == "json"
def test_assistant_message_with_empty_tool_call_arguments(self):
def test_assistant_message_with_empty_tool_call_arguments(self, parse_function):
"""Test parsing assistant message with tool call having None arguments."""
chat_msg = {
"role": "assistant",
@ -67,12 +83,152 @@ class TestParseInputToHarmonyMessage:
],
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].content[0].text == ""
assert messages[0].recipient == "functions.get_current_time"
def test_system_message(self, parse_function):
"""Test parsing system message."""
chat_msg = {
"role": "system",
"content": "You are a helpful assistant",
}
messages = parse_function(chat_msg)
assert len(messages) == 1
# System messages are converted using Message.from_dict
# which should preserve the role
assert messages[0].author.role == Role.SYSTEM
def test_developer_message(self, parse_function):
"""Test parsing developer message."""
chat_msg = {
"role": "developer",
"content": "Use concise language",
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.DEVELOPER
def test_user_message_with_string_content(self, parse_function):
"""Test parsing user message with string content."""
chat_msg = {
"role": "user",
"content": "What's the weather in San Francisco?",
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert messages[0].content[0].text == "What's the weather in San Francisco?"
def test_user_message_with_array_content(self, parse_function):
"""Test parsing user message with array content."""
chat_msg = {
"role": "user",
"content": [
{"text": "What's in this image? "},
{"text": "Please describe it."},
],
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert len(messages[0].content) == 2
assert messages[0].content[0].text == "What's in this image? "
assert messages[0].content[1].text == "Please describe it."
def test_assistant_message_with_string_content(self, parse_function):
"""Test parsing assistant message with string content (no tool calls)."""
chat_msg = {
"role": "assistant",
"content": "Hello! How can I help you today?",
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.ASSISTANT
assert messages[0].content[0].text == "Hello! How can I help you today?"
def test_pydantic_model_input(self, parse_function):
"""Test parsing Pydantic model input (has model_dump method)."""
class MockPydanticModel:
def model_dump(self, exclude_none=True):
return {
"role": "user",
"content": "Test message",
}
chat_msg = MockPydanticModel()
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert messages[0].content[0].text == "Test message"
def test_tool_call_with_missing_function_fields(self, parse_function):
"""Test parsing tool call with missing name or arguments."""
chat_msg = {
"role": "assistant",
"tool_calls": [
{
"function": {} # Missing both name and arguments
}
],
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].recipient == "functions."
assert messages[0].content[0].text == ""
def test_array_content_with_missing_text(self, parse_function):
"""Test parsing array content where text field is missing."""
chat_msg = {
"role": "user",
"content": [
{}, # Missing text field
{"text": "actual text"},
],
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert len(messages[0].content) == 2
assert messages[0].content[0].text == ""
assert messages[0].content[1].text == "actual text"
class TestParseInputToHarmonyMessage:
"""
Tests for scenarios that are specific to the Responses API
parse_input_to_harmony_message function.
"""
def test_message_with_empty_content(self):
"""Test parsing message with empty string content."""
chat_msg = {
"role": "user",
"content": "",
}
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].content[0].text == ""
def test_tool_message_with_string_content(self):
"""Test parsing tool message with string content."""
chat_msg = {
@ -111,6 +267,7 @@ class TestParseInputToHarmonyMessage:
assert len(messages) == 1
assert messages[0].author.role == Role.TOOL
assert messages[0].author.name == "functions.search_results"
assert messages[0].content[0].text == "Result 1: Result 2: Result 3"
def test_tool_message_with_empty_content(self):
@ -124,140 +281,564 @@ class TestParseInputToHarmonyMessage:
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.TOOL
assert messages[0].author.name == "functions.empty_tool"
assert messages[0].content[0].text == ""
def test_system_message(self):
"""Test parsing system message."""
chat_msg = {
"role": "system",
"content": "You are a helpful assistant",
}
messages = parse_input_to_harmony_message(chat_msg)
class TestParseChatInputToHarmonyMessage:
"""
Tests for scenarios that are specific to the Chat Completion API
parse_chat_input_to_harmony_message function.
"""
assert len(messages) == 1
# System messages are converted using Message.from_dict
# which should preserve the role
assert messages[0].author.role == Role.SYSTEM
def test_developer_message(self):
"""Test parsing developer message."""
chat_msg = {
"role": "developer",
"content": "Use concise language",
}
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.DEVELOPER
def test_user_message_with_string_content(self):
"""Test parsing user message with string content."""
chat_msg = {
"role": "user",
"content": "What's the weather in San Francisco?",
}
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert messages[0].content[0].text == "What's the weather in San Francisco?"
def test_user_message_with_array_content(self):
"""Test parsing user message with array content."""
chat_msg = {
"role": "user",
"content": [
{"text": "What's in this image? "},
{"text": "Please describe it."},
],
}
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert len(messages[0].content) == 2
assert messages[0].content[0].text == "What's in this image? "
assert messages[0].content[1].text == "Please describe it."
def test_assistant_message_with_string_content(self):
"""Test parsing assistant message with string content (no tool calls)."""
chat_msg = {
"role": "assistant",
"content": "Hello! How can I help you today?",
}
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.ASSISTANT
assert messages[0].content[0].text == "Hello! How can I help you today?"
def test_pydantic_model_input(self):
"""Test parsing Pydantic model input (has model_dump method)."""
class MockPydanticModel:
def model_dump(self, exclude_none=True):
return {
"role": "user",
"content": "Test message",
}
chat_msg = MockPydanticModel()
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert messages[0].content[0].text == "Test message"
def test_message_with_empty_content(self):
"""Test parsing message with empty string content."""
def test_user_message_with_empty_content(self):
chat_msg = {
"role": "user",
"content": "",
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].content[0].text == ""
verify_harmony_messages(
messages,
[
{
"role": "user",
"content": "",
},
],
)
def test_tool_call_with_missing_function_fields(self):
"""Test parsing tool call with missing name or arguments."""
def test_user_message_with_none_content(self):
chat_msg = {
"role": "user",
"content": None,
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "user",
"content": "",
},
],
)
def test_assistant_message_with_empty_content(self):
chat_msg = {
"role": "assistant",
"content": "",
}
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 0
def test_assistant_message_with_none_content(self):
chat_msg = {
"role": "assistant",
"content": None,
}
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 0
def test_assistant_message_with_content_but_empty_reasoning(self):
chat_msg = {
"role": "assistant",
"content": "The answer is 4.",
"reasoning": "",
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "final",
"content": "The answer is 4.",
},
],
)
def test_assistant_message_with_reasoning_but_empty_content(self):
chat_msg = {
"role": "assistant",
"reasoning": "I'm thinking about the user's question.",
"content": "",
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "analysis",
"content": "I'm thinking about the user's question.",
},
],
)
def test_assistant_message_with_reasoning_but_none_content(self):
chat_msg = {
"role": "assistant",
"reasoning": "I'm thinking about the user's question.",
"content": None,
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "analysis",
"content": "I'm thinking about the user's question.",
},
],
)
def test_assistant_message_with_tool_calls_but_no_content(self):
chat_msg = {
"role": "assistant",
"tool_calls": [
{
"function": {} # Missing both name and arguments
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
}
}
],
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].recipient == "functions."
assert messages[0].content[0].text == ""
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": '{"location": "San Francisco"}',
"content_type": "json",
},
],
)
def test_array_content_with_missing_text(self):
"""Test parsing array content where text field is missing."""
def test_assistant_message_with_tool_calls_and_content(self):
chat_msg = {
"role": "user",
"role": "assistant",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
}
}
],
"content": "I'll call the tool.",
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "commentary",
"content": "I'll call the tool.",
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": '{"location": "San Francisco"}',
"content_type": "json",
},
],
)
def test_assistant_message_with_tool_calls_and_reasoning(self):
chat_msg = {
"role": "assistant",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
}
}
],
"reasoning": "I should use the get_weather tool.",
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "analysis",
"content": "I should use the get_weather tool.",
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": '{"location": "San Francisco"}',
"content_type": "json",
},
],
)
def test_assistant_message_with_tool_calls_and_reasoning_and_content(self):
chat_msg = {
"role": "assistant",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
}
}
],
"reasoning": "I should use the get_weather tool.",
"content": "I'll call the tool.",
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "commentary",
"content": "I'll call the tool.",
},
{
"role": "assistant",
"channel": "analysis",
"content": "I should use the get_weather tool.",
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": '{"location": "San Francisco"}',
"content_type": "json",
},
],
)
def test_tool_message_with_string_content(self):
tool_id_names = {
"call_123": "get_weather",
}
chat_msg = {
"role": "tool",
"tool_call_id": "call_123",
"content": "The weather in San Francisco is sunny, 72°F",
}
messages = parse_chat_input_to_harmony_message(
chat_msg, tool_id_names=tool_id_names
)
verify_harmony_messages(
messages,
[
{
"role": "tool",
"name": "functions.get_weather",
"content": "The weather in San Francisco is sunny, 72°F",
"channel": "commentary",
},
],
)
def test_tool_message_with_array_content(self):
tool_id_names = {
"call_123": "search_results",
}
chat_msg = {
"role": "tool",
"tool_call_id": "call_123",
"content": [
{}, # Missing text field
{"text": "actual text"},
{"type": "text", "text": "Result 1: "},
{"type": "text", "text": "Result 2: "},
{
"type": "image",
"url": "http://example.com/img.png",
}, # Should be ignored
{"type": "text", "text": "Result 3"},
],
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(
chat_msg, tool_id_names=tool_id_names
)
assert len(messages) == 1
assert len(messages[0].content) == 2
assert messages[0].content[0].text == ""
assert messages[0].content[1].text == "actual text"
verify_harmony_messages(
messages,
[
{
"role": "tool",
"name": "functions.search_results",
"content": "Result 1: Result 2: Result 3",
"channel": "commentary",
},
],
)
def test_tool_message_with_empty_content(self):
tool_id_names = {
"call_123": "empty_tool",
}
chat_msg = {
"role": "tool",
"tool_call_id": "call_123",
"content": "",
}
messages = parse_chat_input_to_harmony_message(
chat_msg, tool_id_names=tool_id_names
)
verify_harmony_messages(
messages,
[
{
"role": "tool",
"name": "functions.empty_tool",
"content": "",
"channel": "commentary",
},
],
)
def test_tool_message_with_none_content(self):
tool_id_names = {
"call_123": "empty_tool",
}
chat_msg = {
"role": "tool",
"tool_call_id": "call_123",
"content": None,
}
messages = parse_chat_input_to_harmony_message(
chat_msg, tool_id_names=tool_id_names
)
verify_harmony_messages(
messages,
[
{
"role": "tool",
"name": "functions.empty_tool",
"content": "",
"channel": "commentary",
},
],
)
class TestAutoDropAnalysisMessages:
def test_no_analysis_messages(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
assert cleaned_messages == messages
def test_only_analysis_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
assert cleaned_messages == messages
def test_multiple_analysis_messages_without_final_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking even more."
).with_channel("analysis"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
assert cleaned_messages == messages
def test_only_final_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
assert cleaned_messages == messages
def test_drops_one_analysis_messages_before_final_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
Message.from_role_and_content(
Role.ASSISTANT, "I should think harder."
).with_channel("analysis"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
# Should have dropped the first analysis message
assert cleaned_messages == messages[1:]
def test_drops_all_analysis_messages_before_final_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking even more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
Message.from_role_and_content(
Role.ASSISTANT, "I should think harder."
).with_channel("analysis"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
# Should have dropped the first 3 analysis messages
assert cleaned_messages == messages[3:]
def test_multiple_analysis_messages_with_multiple_final_messages(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking even more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
Message.from_role_and_content(
Role.ASSISTANT, "I should think harder."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 5."
).with_channel("final"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
# Should have dropped all those analysis messages
assert len(cleaned_messages) == 2
assert cleaned_messages[0].content[0].text == "The answer is 4."
assert cleaned_messages[1].content[0].text == "The answer is 5."
def test_drops_non_assistant_analysis_messages(self) -> None:
messages = [
Message.from_role_and_content(
Role.TOOL, "The tool thinks we should think harder."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
# Should have dropped the analysis message
assert cleaned_messages == messages[1:]
class TestParseChatOutput:
def test_parse_chat_output_interrupted_first_message(self) -> None:
harmony_str = "<|channel|>final<|message|>I'm in the middle of answering"
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning is None
assert final_content == "I'm in the middle of answering"
def test_parse_chat_output_interrupted_reasoning_first_message(self) -> None:
harmony_str = "<|channel|>analysis<|message|>I'm in the middle of thinking"
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning == "I'm in the middle of thinking"
assert final_content is None
def test_parse_chat_output_complete_reasoning_interrupted_content(self) -> None:
harmony_str = (
"<|channel|>analysis<|message|>I'm thinking.<|end|>"
"<|start|>assistant<|channel|>final"
"<|message|>I'm in the middle of answering"
)
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning == "I'm thinking."
assert final_content == "I'm in the middle of answering"
def test_parse_chat_output_complete_content(self) -> None:
harmony_str = "<|channel|>final<|message|>The answer is 4.<|end|>"
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning is None
assert final_content == "The answer is 4."
def test_parse_chat_output_complete_commentary(self) -> None:
harmony_str = (
"<|channel|>commentary<|message|>I need to call some tools.<|end|>"
)
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning is None
assert final_content == "I need to call some tools."
def test_parse_chat_output_complete_reasoning(self) -> None:
harmony_str = (
"<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
)
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning == "I've thought hard about this."
assert final_content is None
def test_parse_chat_output_complete_reasoning_and_content(self) -> None:
harmony_str = (
"<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
"<|start|>assistant<|channel|>final<|message|>The answer is 4.<|end|>"
)
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning == "I've thought hard about this."
assert final_content == "The answer is 4."
class TestParseOutputMessage:

View File

@ -79,9 +79,12 @@ async def test_anthropic_streaming(client: anthropic.AsyncAnthropic):
assert chunk_count > 0
assert first_chunk is not None, "message_start chunk was never observed"
assert first_chunk.usage is not None, "first chunk should include usage stats"
assert first_chunk.usage["output_tokens"] == 0
assert first_chunk.usage["input_tokens"] > 5
assert first_chunk.message is not None, "first chunk should include message"
assert first_chunk.message.usage is not None, (
"first chunk should include usage stats"
)
assert first_chunk.message.usage.output_tokens == 0
assert first_chunk.message.usage.input_tokens > 5
@pytest.mark.asyncio

View File

@ -11,13 +11,25 @@ import pytest_asyncio
from openai import OpenAI
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM
from ...utils import RemoteOpenAIServer
from .utils import (
accumulate_streaming_response,
verify_chat_response,
verify_harmony_messages,
)
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
@ -728,3 +740,635 @@ async def test_serving_chat_data_parallel_rank_extraction():
# Verify that data_parallel_rank defaults to None
assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None
class TestServingChatWithHarmony:
"""
These tests ensure Chat Completion requests are being properly converted into
Harmony messages and Harmony response messages back into Chat Completion responses.
These tests are not exhaustive, but each one was created to cover a specific case
that we got wrong but is now fixed.
Any changes to the tests and their expectations may result in changes to the
accuracy of model prompting and responses generated. It is suggested to run
an evaluation or benchmarking suite (such as bfcl multi_turn) to understand
any impact of changes in how we prompt Harmony models.
"""
@pytest.fixture(params=[False, True], ids=["non_streaming", "streaming"])
def stream(self, request) -> bool:
"""Parameterize tests to run in both non-streaming and streaming modes."""
return request.param
@pytest.fixture()
def mock_engine(self) -> AsyncLLM:
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
return mock_engine
@pytest.fixture()
def serving_chat(self, mock_engine) -> OpenAIServingChat:
chat = _build_serving_chat(mock_engine)
chat.use_harmony = True
chat.tool_parser = ToolParserManager.get_tool_parser("openai")
return chat
def mock_request_output_from_req_and_token_ids(
self, req: ChatCompletionRequest, token_ids: list[int], finished: bool = False
) -> RequestOutput:
# Our tests don't use most fields, so just get the token ids correct
completion_output = CompletionOutput(
index=0,
text="",
token_ids=token_ids,
cumulative_logprob=0.0,
logprobs=None,
)
return RequestOutput(
request_id=req.request_id,
prompt=[],
prompt_token_ids=[],
prompt_logprobs=None,
outputs=[completion_output],
finished=finished,
)
@pytest.fixture
def weather_tools(self) -> list[dict[str, Any]]:
return [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string"},
},
"required": ["location"],
},
},
},
]
@pytest.fixture
def weather_messages_start(self) -> list[dict[str, Any]]:
return [
{
"role": "user",
"content": "What's the weather like in Paris today?",
},
]
async def generate_response_from_harmony_str(
self,
serving_chat: OpenAIServingChat,
req: ChatCompletionRequest,
harmony_str: str,
stream: bool = False,
) -> ChatCompletionResponse:
harmony_token_ids = get_encoding().encode(harmony_str, allowed_special="all")
async def result_generator():
if stream:
for token_id in harmony_token_ids:
yield self.mock_request_output_from_req_and_token_ids(
req, [token_id]
)
yield self.mock_request_output_from_req_and_token_ids(
req, [], finished=True
)
else:
yield self.mock_request_output_from_req_and_token_ids(
req, harmony_token_ids, finished=True
)
generator_func = (
serving_chat.chat_completion_stream_generator
if stream
else serving_chat.chat_completion_full_generator
)
result = generator_func(
request=req,
result_generator=result_generator(),
request_id=req.request_id,
model_name=req.model,
conversation=[],
tokenizer=get_tokenizer(req.model),
request_metadata=RequestResponseMetadata(
request_id=req.request_id,
model_name=req.model,
),
)
if stream:
return await accumulate_streaming_response(result)
return await result
@pytest.mark.asyncio
async def test_simple_chat(self, serving_chat, stream):
messages = [{"role": "user", "content": "what is 1+1?"}]
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
reasoning_str = "We need to think really hard about this."
final_str = "The answer is 2."
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
f"<|start|>assistant<|channel|>final<|message|>{final_str}<|end|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(response, content=final_str, reasoning=reasoning_str)
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
# The analysis message should be dropped on subsequent inputs because
# of the subsequent assistant message to the final channel.
{"role": "assistant", "channel": "final", "content": final_str},
],
)
@pytest.mark.asyncio
async def test_tool_call_response_with_content(
self, serving_chat, stream, weather_tools, weather_messages_start
):
tools = weather_tools
messages = list(weather_messages_start)
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer", "tool_definitions": ["get_weather"]},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
commentary_str = "We'll call get_weather."
tool_args_str = '{"location": "Paris"}'
response_str = (
f"<|channel|>commentary<|message|>{commentary_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response,
content=commentary_str,
tool_calls=[("get_weather", tool_args_str)],
)
tool_call = response.choices[0].message.tool_calls[0]
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "20 degrees Celsius",
},
)
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "commentary",
"content": commentary_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
],
)
@pytest.mark.asyncio
async def test_tools_and_reasoning(
self, serving_chat, stream, weather_tools, weather_messages_start
):
tools = weather_tools
messages = list(weather_messages_start)
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer", "tool_definitions": ["get_weather"]},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
reasoning_str = "I'll call get_weather."
tool_args_str = '{"location": "Paris"}'
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response,
reasoning=reasoning_str,
tool_calls=[("get_weather", tool_args_str)],
)
tool_call = response.choices[0].message.tool_calls[0]
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "20 degrees Celsius",
},
)
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "analysis",
"content": reasoning_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
],
)
@pytest.mark.asyncio
async def test_multi_turn_tools_and_reasoning(
self, serving_chat, stream, weather_tools, weather_messages_start
):
tools = weather_tools
messages = list(weather_messages_start)
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer", "tool_definitions": ["get_weather"]},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
reasoning_str = "I'll call get_weather."
paris_tool_args_str = '{"location": "Paris"}'
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{paris_tool_args_str}<|call|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response,
reasoning=reasoning_str,
tool_calls=[("get_weather", paris_tool_args_str)],
)
tool_call = response.choices[0].message.tool_calls[0]
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "20 degrees Celsius",
},
)
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "analysis",
"content": reasoning_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": paris_tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
],
)
# Test the Chat Completion response for the second turn's output
paris_weather_str = "The weather in Paris today is 20 degrees Celsius."
response_str = f"<|channel|>final<|message|>{paris_weather_str}<|end|>"
response_2 = await self.generate_response_from_harmony_str(
serving_chat, req_2, response_str, stream=stream
)
verify_chat_response(response_2, content=paris_weather_str)
# Add the output messages from the second turn as input to the third turn
for choice in response_2.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add a new user message for the third turn
messages.append(
{
"role": "user",
"content": "What's the weather like in Boston today?",
},
)
# Test the Harmony messages for the third turn's input
req_3 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_3, _, _ = serving_chat._make_request_with_harmony(req_3)
verify_harmony_messages(
input_messages_3,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": paris_tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
{
"role": "assistant",
"channel": "final",
"content": paris_weather_str,
},
{"role": "user", "content": messages[-1]["content"]},
],
)
# Test the Chat Completion response for the third turn's output
reasoning_str = "I'll call get_weather."
boston_tool_args_str = '{"location": "Boston"}'
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{boston_tool_args_str}<|call|>"
)
response_3 = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response_3,
reasoning=reasoning_str,
tool_calls=[("get_weather", boston_tool_args_str)],
)
tool_call = response_3.choices[0].message.tool_calls[0]
# Add the output messages from the third turn as input to the fourth turn
for choice in response_3.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "10 degrees Celsius",
},
)
# Test the Harmony messages for the fourth turn's input
req_4 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_4, _, _ = serving_chat._make_request_with_harmony(req_4)
verify_harmony_messages(
input_messages_4,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{"role": "assistant"},
{"role": "tool"},
{
"role": "assistant",
"channel": "final",
},
{"role": "user"},
{
"role": "assistant",
"channel": "analysis",
"content": reasoning_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": boston_tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "10 degrees Celsius",
},
],
)
@pytest.mark.asyncio
async def test_non_tool_reasoning(self, serving_chat):
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": "What's 2+2?",
},
{
"role": "assistant",
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
"content": "4",
},
]
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
# The reasoning that would have resulted in an analysis message is
# dropped because of a later assistant message to the final channel.
{
"role": "assistant",
"channel": "final",
"content": messages[1]["content"],
},
],
)
@pytest.mark.asyncio
async def test_non_tool_reasoning_empty_content(self, serving_chat):
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": "What's 2+2?",
},
{
"role": "assistant",
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
"content": "",
},
]
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
{
"role": "assistant",
"channel": "analysis",
"content": messages[1]["reasoning"],
},
],
)
@pytest.mark.asyncio
async def test_non_tool_reasoning_empty_content_list(self, serving_chat):
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": "What's 2+2?",
},
{
"role": "assistant",
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
"content": [],
},
]
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
{
"role": "assistant",
"channel": "analysis",
"content": messages[1]["reasoning"],
},
],
)

View File

@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import AsyncGenerator
from typing import Any
from vllm.entrypoints.openai.protocol import (
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionStreamResponse,
ChatMessage,
UsageInfo,
)
async def accumulate_streaming_response(
stream_generator: AsyncGenerator[str, None],
) -> ChatCompletionResponse:
"""
Accumulate streaming SSE chunks into a complete ChatCompletionResponse.
This helper parses the SSE format and builds up the complete response
by combining all the delta chunks.
"""
accumulated_content = ""
accumulated_reasoning = None
accumulated_tool_calls: list[dict[str, Any]] = []
role = None
finish_reason = None
response_id = None
created = None
model = None
index = 0
async for chunk_str in stream_generator:
# Skip empty lines and [DONE] marker
if not chunk_str.strip() or chunk_str.strip() == "data: [DONE]":
continue
# Parse SSE format: "data: {json}\n\n"
if chunk_str.startswith("data: "):
json_str = chunk_str[6:].strip()
try:
chunk_data = json.loads(json_str)
# print(f"DEBUG: Parsed chunk_data: {chunk_data}")
chunk = ChatCompletionStreamResponse(**chunk_data)
# Store metadata from first chunk
if response_id is None:
response_id = chunk.id
created = chunk.created
model = chunk.model
# Process each choice in the chunk
for choice in chunk.choices:
if choice.delta.role:
role = choice.delta.role
if choice.delta.content:
accumulated_content += choice.delta.content
if choice.delta.reasoning:
if accumulated_reasoning is None:
accumulated_reasoning = ""
accumulated_reasoning += choice.delta.reasoning
if choice.delta.tool_calls:
# Accumulate tool calls
for tool_call_delta in choice.delta.tool_calls:
# Find or create the tool call at this index
while len(accumulated_tool_calls) <= tool_call_delta.index:
accumulated_tool_calls.append(
{
"id": None,
"type": "function",
"function": {"name": "", "arguments": ""},
}
)
if tool_call_delta.id:
accumulated_tool_calls[tool_call_delta.index]["id"] = (
tool_call_delta.id
)
if tool_call_delta.function:
if tool_call_delta.function.name:
accumulated_tool_calls[tool_call_delta.index][
"function"
]["name"] += tool_call_delta.function.name
if tool_call_delta.function.arguments:
accumulated_tool_calls[tool_call_delta.index][
"function"
]["arguments"] += tool_call_delta.function.arguments
if choice.finish_reason:
finish_reason = choice.finish_reason
if choice.index is not None:
index = choice.index
except json.JSONDecodeError:
continue
# Build the final message
message_kwargs = {
"role": role or "assistant",
"content": accumulated_content if accumulated_content else None,
"reasoning": accumulated_reasoning,
}
# Only include tool_calls if there are any
if accumulated_tool_calls:
message_kwargs["tool_calls"] = [
{"id": tc["id"], "type": tc["type"], "function": tc["function"]}
for tc in accumulated_tool_calls
]
message = ChatMessage(**message_kwargs)
# Build the final response
choice = ChatCompletionResponseChoice(
index=index,
message=message,
finish_reason=finish_reason or "stop",
)
# Create usage info (with dummy values for tests)
usage = UsageInfo(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
response = ChatCompletionResponse(
id=response_id or "chatcmpl-test",
object="chat.completion",
created=created or 0,
model=model or "test-model",
choices=[choice],
usage=usage,
)
return response
def verify_harmony_messages(
messages: list[Any], expected_messages: list[dict[str, Any]]
):
assert len(messages) == len(expected_messages)
for msg, expected in zip(messages, expected_messages):
if "role" in expected:
assert msg.author.role == expected["role"]
if "author_name" in expected:
assert msg.author.name == expected["author_name"]
if "channel" in expected:
assert msg.channel == expected["channel"]
if "recipient" in expected:
assert msg.recipient == expected["recipient"]
if "content" in expected:
assert msg.content[0].text == expected["content"]
if "content_type" in expected:
assert msg.content_type == expected["content_type"]
if "tool_definitions" in expected:
# Check that the tool definitions match the expected list of tool names
actual_tools = [t.name for t in msg.content[0].tools["functions"].tools]
assert actual_tools == expected["tool_definitions"]
def verify_chat_response(
response: ChatCompletionResponse,
content: str | None = None,
reasoning: str | None = None,
tool_calls: list[tuple[str, str]] | None = None,
):
assert len(response.choices) == 1
message = response.choices[0].message
if content is not None:
assert message.content == content
else:
assert not message.content
if reasoning is not None:
assert message.reasoning == reasoning
else:
assert not message.reasoning
if tool_calls:
assert message.tool_calls is not None
assert len(message.tool_calls) == len(tool_calls)
for tc, (expected_name, expected_args) in zip(message.tool_calls, tool_calls):
assert tc.function.name == expected_name
assert tc.function.arguments == expected_args
else:
assert not message.tool_calls

View File

@ -7,6 +7,7 @@ import torch
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2
NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256]
@ -22,6 +23,10 @@ QDTYPES = (
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
# 0: use 2D kernel for decode
# 8: use 3D kernel for decode
SEQ_THRESHOLD_3D_VALUES = [0, 8]
def ref_paged_attn(
query: torch.Tensor,
@ -92,6 +97,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("soft_cap", [None, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("q_dtype", QDTYPES)
@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES)
@torch.inference_mode()
def test_triton_unified_attn(
seq_lens: list[tuple[int, int]],
@ -103,6 +109,7 @@ def test_triton_unified_attn(
soft_cap: float | None,
num_blocks: int,
q_dtype: torch.dtype | None,
seq_threshold_3D: int,
) -> None:
torch.set_default_device("cuda")
@ -152,6 +159,21 @@ def test_triton_unified_attn(
k_descale = torch.rand(scale_shape, dtype=torch.float32)
v_descale = torch.rand(scale_shape, dtype=torch.float32)
num_par_softmax_segments = 16
head_size_padded = next_power_of_2(head_size)
softmax_segm_output = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
dtype=torch.float32,
)
softmax_segm_max = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
dtype=torch.float32,
)
softmax_segm_expsum = torch.empty(
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
dtype=torch.float32,
)
unified_attention(
q=maybe_quantized_query,
k=maybe_quantized_key_cache,
@ -169,6 +191,11 @@ def test_triton_unified_attn(
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
seq_threshold_3D=seq_threshold_3D,
num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output,
softmax_segm_max=softmax_segm_max,
softmax_segm_expsum=softmax_segm_expsum,
)
ref_output = ref_paged_attn(

View File

@ -116,7 +116,6 @@ def test_mrope(
mrope_helper_class = get_rope(
head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position,
is_neox_style=is_neox_style,
rope_parameters=config.rope_parameters,
@ -185,7 +184,6 @@ def test_mrope_torch_compile_tracing(
mrope_helper_class = get_rope(
head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position,
is_neox_style=is_neox_style,
rope_parameters=config.rope_parameters,

View File

@ -83,8 +83,12 @@ def test_rotary_embedding(
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
rope = get_rope(head_size, rotary_dim, max_position, is_neox_style, rope_parameters)
rope_parameters = {
"rope_type": "default",
"rope_theta": rope_theta,
"partial_rotary_factor": rotary_dim / head_size,
}
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
rope = rope.to(dtype=dtype, device=torch.get_default_device())
positions = torch.randint(0, max_position, (batch_size, seq_len))
@ -150,9 +154,9 @@ def test_rope_module_cache():
if rotary_dim is None:
rotary_dim = head_size
rope_parameters["rope_theta"] = rope_theta
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
rope = get_rope(
head_size,
rotary_dim,
max_position,
is_neox_style,
rope_parameters,
@ -177,9 +181,9 @@ def test_rope_module_cache():
if rotary_dim is None:
rotary_dim = head_size
rope_parameters["rope_theta"] = rope_theta
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
rope = get_rope(
head_size,
rotary_dim,
max_position,
is_neox_style,
rope_parameters,

View File

@ -27,7 +27,7 @@ BLOCK_SIZE = [128, 128]
@pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert
@pytest.mark.parametrize("topk", [2, 4])
def test_batched_deepgemm_vs_triton(
E: int, T: int, K: int, N: int, topk: int, monkeypatch
E: int, T: int, K: int, N: int, topk: int, monkeypatch, workspace_init
):
"""Compare BatchedDeepGemmExperts to BatchedTritonExperts."""

View File

@ -248,6 +248,7 @@ def test_fused_moe_batched_experts(
per_act_token_quant: bool,
block_shape: list[int] | None,
input_scales: bool,
workspace_init,
):
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
and those tests will be skipped on unsupported hardware."""

View File

@ -137,7 +137,7 @@ def setup_cuda():
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(
M, N, K, E, topk, block_size, dtype, seed, monkeypatch
M, N, K, E, topk, block_size, dtype, seed, monkeypatch, workspace_init
):
if topk > E:
pytest.skip(f"Skipping test; topk={topk} > E={E}")

View File

@ -274,6 +274,7 @@ def test_cutlass_moe_8_bit_no_graph(
per_act_token: bool,
per_out_ch: bool,
monkeypatch,
workspace_init,
ep_size: int | None = None,
):
current_platform.seed_everything(7)
@ -329,6 +330,7 @@ def test_cutlass_moe_8_bit_cuda_graph(
per_act_token: bool,
per_out_ch: bool,
monkeypatch,
workspace_init,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
@ -385,9 +387,19 @@ def test_cutlass_moe_8_bit_EP(
per_out_channel: bool,
ep_size: int,
monkeypatch,
workspace_init,
):
test_cutlass_moe_8_bit_no_graph(
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
m,
n,
k,
e,
topk,
per_act_token,
per_out_channel,
monkeypatch,
workspace_init,
ep_size,
)
@ -419,9 +431,19 @@ def test_cutlass_moe_8_bit_EP_large(
per_out_channel: bool,
ep_size: int,
monkeypatch,
workspace_init,
):
test_cutlass_moe_8_bit_no_graph(
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
m,
n,
k,
e,
topk,
per_act_token,
per_out_channel,
monkeypatch,
workspace_init,
ep_size,
)
@ -445,6 +467,7 @@ def test_run_cutlass_moe_fp8(
per_act_token: bool,
per_out_channel: bool,
ep_size: int,
workspace_init,
):
current_platform.seed_everything(7)
with set_current_vllm_config(vllm_config):

View File

@ -29,6 +29,7 @@ from vllm.utils.deep_gemm import (
is_deep_gemm_supported,
)
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
from vllm.v1.worker.workspace import init_workspace_manager
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
@ -363,6 +364,9 @@ def _test_deepep_deepgemm_moe(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
):
device = torch.device(f"cuda:{pgi.local_rank}")
init_workspace_manager(device)
current_platform.seed_everything(pgi.rank)
w1 = w1.to(device=torch.cuda.current_device())
@ -445,6 +449,7 @@ def test_ht_deepep_deepgemm_moe(
topk: int,
world_dp_size: tuple[int, int],
disable_deepgemm_ue8m0,
workspace_init,
):
"""
Tests for High-Throughput DeepEP + DeepGemm integration.
@ -518,6 +523,7 @@ def test_ll_deepep_deepgemm_moe(
block_size: list[int],
world_dp_size: tuple[int, int],
disable_deepgemm_ue8m0,
workspace_init,
):
"""
Tests for Low-Latency DeepEP + DeepGemm integration.

View File

@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep
from vllm.v1.worker.workspace import init_workspace_manager
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
@ -342,6 +343,9 @@ def _deep_ep_moe(
use_fp8_dispatch: bool,
per_act_token_quant: bool,
):
device = torch.device(f"cuda:{pgi.local_rank}")
init_workspace_manager(device)
if not low_latency_mode:
assert not use_fp8_dispatch, (
"FP8 dispatch interface is available only in low-latency mode"
@ -437,6 +441,7 @@ def test_deep_ep_moe(
topk: int,
world_dp_size: tuple[int, int],
per_act_token_quant: bool,
workspace_init,
):
low_latency_mode = False
use_fp8_dispatch = False
@ -492,6 +497,7 @@ def test_low_latency_deep_ep_moe(
topk: int,
world_dp_size: tuple[int, int],
use_fp8_dispatch: bool,
workspace_init,
):
low_latency_mode = True

View File

@ -143,7 +143,7 @@ NUM_EXPERTS = [32]
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_init):
with monkeypatch.context() as mp:
mp.setenv("VLLM_USE_DEEP_GEMM", "1")

View File

@ -206,6 +206,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
topk: int,
activation: str,
monkeypatch,
workspace_init,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")

View File

@ -51,7 +51,14 @@ MNK_FACTORS = [
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, activation: str
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
activation: str,
workspace_init,
):
current_platform.seed_everything(7)
with set_current_vllm_config(

View File

@ -269,7 +269,7 @@ class Case:
)
@pytest.mark.parametrize("num_token", [2])
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
def test_equiv(num_token, a_dtype, w_dtype, tp):
def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
from triton_kernels.tensor_details import layout
if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"):

View File

@ -16,6 +16,7 @@ from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.worker.workspace import init_workspace_manager
from .modular_kernel_tools.common import (
Config,
@ -77,6 +78,10 @@ def rank_worker(
weights: WeightTensors,
verbose: bool,
):
# Initialize workspace manager in child process
device = torch.device(f"cuda:{pgi.local_rank}")
init_workspace_manager(device)
current_platform.seed_everything(pgi.rank)
# sanity check
@ -300,6 +305,7 @@ def test_modular_kernel_combinations_singlegpu(
chunk_size: int | None,
world_size: int,
pytestconfig,
workspace_init,
):
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
and those tests will be skipped on unsupported hardware."""

View File

@ -209,6 +209,7 @@ def test_oai_triton_moe(
num_experts: int,
topk: int,
unfused: bool,
workspace_init,
):
current_platform.seed_everything(0)
(

View File

@ -231,6 +231,7 @@ def test_fused_moe(
padding: bool,
chunk_size: int,
monkeypatch,
workspace_init,
):
current_platform.seed_everything(7)

View File

@ -40,7 +40,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, workspace_init
):
current_platform.seed_everything(7)
with set_current_vllm_config(

View File

@ -46,6 +46,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from vllm.v1.worker.workspace import init_workspace_manager
from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
@ -181,6 +182,7 @@ def test_fused_moe_batched_experts(
e: int,
topk: int,
dtype: torch.dtype,
workspace_init,
):
current_platform.seed_everything(7)
@ -863,6 +865,9 @@ def _pplx_test_loop(
make_weights: bool,
test_fn: Callable,
):
device = torch.device(f"cuda:{pgi.local_rank}")
init_workspace_manager(device)
def format_result(msg, ex=None):
if ex is not None:
x = str(ex)

View File

@ -30,16 +30,11 @@ def ref_dynamic_per_token_quant(
if quant_dtype == torch.int8
else torch.finfo(quant_dtype)
)
qtype_traits_max = (
ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else qtype_traits.max
)
qtype_traits_min = (
-ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else qtype_traits.min
use_fp8fnuz = (
current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype()
)
qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max
qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)

View File

@ -18,7 +18,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
IS_SUPPORTED_BY_GPU = (
current_platform.is_cuda() and current_platform.get_device_capability()[0] >= 9
)
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:

View File

@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ScaledMM kernel selection logic (CPU-only)
Run `pytest tests/kernels/quantization/test_scaled_mm_kernel_selection.py`.
"""
import inspect
from abc import ABC
import pytest
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
CPUScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearKernel,
)
pytestmark = pytest.mark.cpu_test
def test_is_supported_is_abstract():
"""Test that is_supported() is properly defined as abstract."""
assert issubclass(ScaledMMLinearKernel, ABC)
assert hasattr(ScaledMMLinearKernel, "is_supported")
def test_cpu_kernel_implements_is_supported():
"""Test that CPUScaledMMLinearKernel implements is_supported() method."""
assert hasattr(CPUScaledMMLinearKernel, "is_supported"), (
"CPUScaledMMLinearKernel missing is_supported() method"
)
# Verify it's a classmethod by checking if it can be called with the class
# and by checking the method type
assert inspect.ismethod(CPUScaledMMLinearKernel.is_supported) or inspect.isfunction(
CPUScaledMMLinearKernel.is_supported
), "CPUScaledMMLinearKernel.is_supported() should be a classmethod"
# Verify it can be called as a classmethod
result, reason = CPUScaledMMLinearKernel.is_supported()
assert isinstance(result, bool), "is_supported() should return a bool"
assert reason is None or isinstance(reason, str), "reason should be str or None"
def test_aiter_kernel_implements_is_supported():
"""Test that AiterScaledMMLinearKernel implements is_supported() method."""
assert hasattr(AiterScaledMMLinearKernel, "is_supported"), (
"AiterScaledMMLinearKernel missing is_supported() method"
)
# Verify it's a classmethod by checking if it can be called with the class
# and by checking the method type
assert inspect.ismethod(
AiterScaledMMLinearKernel.is_supported
) or inspect.isfunction(AiterScaledMMLinearKernel.is_supported), (
"AiterScaledMMLinearKernel.is_supported() should be a classmethod"
)
# Verify it can be called as a classmethod
# (will return False on CPU, which is expected)
result, reason = AiterScaledMMLinearKernel.is_supported()
assert isinstance(result, bool), "is_supported() should return a bool"
assert reason is None or isinstance(reason, str), "reason should be str or None"
# On CPU, it should return False with a reason about requiring ROCm
# This validates the method works correctly even on non-ROCm platforms
def test_cpu_kernel_accepts_all_configs():
"""Test that CPUScaledMMLinearKernel accepts all config combinations."""
configs = [
ScaledMMLinearLayerConfig(
is_channelwise=False,
is_static_input_scheme=True,
input_symmetric=True,
),
ScaledMMLinearLayerConfig(
is_channelwise=True,
is_static_input_scheme=False,
input_symmetric=False,
),
]
for config in configs:
can_impl, reason = CPUScaledMMLinearKernel.can_implement(config)
assert can_impl, (
f"CPUScaledMMLinearKernel should accept config {config}: {reason}"
)

View File

@ -17,7 +17,6 @@ def test_idefics_multimodal(
with vllm_runner(
model_name="HuggingFaceM4/Idefics3-8B-Llama3",
runner="pooling",
task="classify",
convert="classify",
load_format="dummy",
max_model_len=512,
@ -86,7 +85,6 @@ def test_gemma_multimodal(
with vllm_runner(
model_name="google/gemma-3-4b-it",
runner="pooling",
task="classify",
convert="classify",
load_format="auto",
hf_overrides=update_config,

View File

@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import ImageTestAssets
from ...utils import build_model_context
@pytest.mark.parametrize("model_id", ["google/gemma-3-4b-it"])
def test_get_image_size_with_most_features(
image_assets: ImageTestAssets, model_id: str
):
ctx = build_model_context(
model_id,
mm_processor_kwargs={"do_pan_and_scan": True},
limit_mm_per_prompt={"image": 1},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
hf_processor_mm_kwargs: dict[str, object] = {}
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
max_image_size = processor.info.get_image_size_with_most_features()
max_tokens = processor.info.get_num_image_tokens(
image_width=max_image_size.width,
image_height=max_image_size.height,
processor=hf_processor,
)
prompt = "<start_of_image>"
image_seq_length = hf_processor.image_seq_length
for asset in image_assets:
mm_data = {"image": [asset.pil_image]}
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
mm_kwargs_data = processed_inputs["mm_kwargs"].get_data()
num_patches_tensor = mm_kwargs_data["num_patches"]
tokens = int(num_patches_tensor.item()) * image_seq_length
assert tokens <= max_tokens

View File

@ -53,3 +53,38 @@ def test_processor_override(
assert img_tok_count == expected_toks_per_img * num_imgs
assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
assert pixel_shape[1] == expected_pixels_shape[1]
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
@pytest.mark.parametrize("max_pixels", [1280 * 28 * 28, 1283 * 28 * 28])
def test_get_image_size_with_most_features(
image_assets: ImageTestAssets,
model_id: str,
max_pixels: int,
):
ctx = build_model_context(
model_id,
mm_processor_kwargs={"max_pixels": max_pixels},
limit_mm_per_prompt={"image": 1},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
hf_processor_mm_kwargs: dict[str, object] = {}
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
merge_size = processor.info.get_hf_config().vision_config.spatial_merge_size
max_image_size = processor.info.get_image_size_with_most_features()
max_tokens = processor.info.get_num_image_tokens(
image_width=max_image_size.width,
image_height=max_image_size.height,
image_processor=hf_processor.image_processor,
)
prompt = "<|vision_start|><|image_pad|><|vision_end|>"
for asset in image_assets:
mm_data = {"image": [asset.pil_image]}
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
grid_thw = processed_inputs["mm_kwargs"].get_data()["image_grid_thw"].tolist()
t, h, w = grid_thw[0]
tokens = (t * h * w) // (merge_size**2)
assert tokens < max_tokens

View File

@ -173,10 +173,7 @@ class _HfExamplesInfo:
_TEXT_GENERATION_EXAMPLE_MODELS = {
# [Decoder-only]
"AfmoeForCausalLM": _HfExamplesInfo(
"arcee-ai/Trinity-Nano",
is_available_online=False,
),
"AfmoeForCausalLM": _HfExamplesInfo("arcee-ai/Trinity-Nano-Preview"),
"ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-Instruct-2509"),
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True),
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True),

View File

@ -147,7 +147,7 @@ def test_video_backend_handles_broken_frames(monkeypatch: pytest.MonkeyPatch):
"""
Regression test for handling videos with broken frames.
This test uses a pre-corrupted video file (assets/corrupted.mp4) that
contains broken/unreadable frames to verify the video loader handles
contains broken frames to verify the video loader handles
them gracefully without crashing and returns accurate metadata.
"""
with monkeypatch.context() as m:
@ -177,3 +177,125 @@ def test_video_backend_handles_broken_frames(monkeypatch: pytest.MonkeyPatch):
f"Expected fewer than {metadata['total_num_frames']} frames, "
f"but loaded {frames.shape[0]} frames"
)
@VIDEO_LOADER_REGISTRY.register("test_video_backend_override_1")
class TestVideoBackendOverride1(VideoLoader):
"""Test loader that returns FAKE_OUTPUT_1 to verify backend selection."""
@classmethod
def load_bytes(
cls, data: bytes, num_frames: int = -1, **kwargs
) -> tuple[npt.NDArray, dict]:
return FAKE_OUTPUT_1, {"video_backend": "test_video_backend_override_1"}
@VIDEO_LOADER_REGISTRY.register("test_video_backend_override_2")
class TestVideoBackendOverride2(VideoLoader):
"""Test loader that returns FAKE_OUTPUT_2 to verify backend selection."""
@classmethod
def load_bytes(
cls, data: bytes, num_frames: int = -1, **kwargs
) -> tuple[npt.NDArray, dict]:
return FAKE_OUTPUT_2, {"video_backend": "test_video_backend_override_2"}
def test_video_media_io_backend_kwarg_override(monkeypatch: pytest.MonkeyPatch):
"""
Test that video_backend kwarg can override the VLLM_VIDEO_LOADER_BACKEND
environment variable.
This allows users to dynamically select a different video backend
via --media-io-kwargs without changing the global env var, which is
useful when plugins set a default backend but a specific request
needs a different one.
"""
with monkeypatch.context() as m:
# Set the env var to one backend
m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_video_backend_override_1")
imageio = ImageMediaIO()
# Without video_backend kwarg, should use env var backend
videoio_default = VideoMediaIO(imageio, num_frames=10)
frames_default, metadata_default = videoio_default.load_bytes(b"test")
np.testing.assert_array_equal(frames_default, FAKE_OUTPUT_1)
assert metadata_default["video_backend"] == "test_video_backend_override_1"
# With video_backend kwarg, should override env var
videoio_override = VideoMediaIO(
imageio, num_frames=10, video_backend="test_video_backend_override_2"
)
frames_override, metadata_override = videoio_override.load_bytes(b"test")
np.testing.assert_array_equal(frames_override, FAKE_OUTPUT_2)
assert metadata_override["video_backend"] == "test_video_backend_override_2"
def test_video_media_io_backend_kwarg_not_passed_to_loader(
monkeypatch: pytest.MonkeyPatch,
):
"""
Test that video_backend kwarg is consumed by VideoMediaIO and NOT passed
through to the underlying video loader's load_bytes method.
This ensures the kwarg is properly popped from kwargs before forwarding.
"""
@VIDEO_LOADER_REGISTRY.register("test_reject_video_backend_kwarg")
class RejectVideoBackendKwargLoader(VideoLoader):
"""Test loader that fails if video_backend is passed through."""
@classmethod
def load_bytes(
cls, data: bytes, num_frames: int = -1, **kwargs
) -> tuple[npt.NDArray, dict]:
# This should never receive video_backend in kwargs
if "video_backend" in kwargs:
raise AssertionError(
"video_backend should be consumed by VideoMediaIO, "
"not passed to loader"
)
return FAKE_OUTPUT_1, {"received_kwargs": list(kwargs.keys())}
with monkeypatch.context() as m:
m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_reject_video_backend_kwarg")
imageio = ImageMediaIO()
# Even when video_backend is provided, it should NOT be passed to loader
videoio = VideoMediaIO(
imageio,
num_frames=10,
video_backend="test_reject_video_backend_kwarg",
other_kwarg="should_pass_through",
)
# This should NOT raise AssertionError
frames, metadata = videoio.load_bytes(b"test")
np.testing.assert_array_equal(frames, FAKE_OUTPUT_1)
# Verify other kwargs are still passed through
assert "other_kwarg" in metadata["received_kwargs"]
def test_video_media_io_backend_env_var_fallback(monkeypatch: pytest.MonkeyPatch):
"""
Test that when video_backend kwarg is None or not provided,
VideoMediaIO falls back to VLLM_VIDEO_LOADER_BACKEND env var.
"""
with monkeypatch.context() as m:
m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_video_backend_override_2")
imageio = ImageMediaIO()
# Explicit None should fall back to env var
videoio_none = VideoMediaIO(imageio, num_frames=10, video_backend=None)
frames_none, metadata_none = videoio_none.load_bytes(b"test")
np.testing.assert_array_equal(frames_none, FAKE_OUTPUT_2)
assert metadata_none["video_backend"] == "test_video_backend_override_2"
# Not providing video_backend should also fall back to env var
videoio_missing = VideoMediaIO(imageio, num_frames=10)
frames_missing, metadata_missing = videoio_missing.load_bytes(b"test")
np.testing.assert_array_equal(frames_missing, FAKE_OUTPUT_2)
assert metadata_missing["video_backend"] == "test_video_backend_override_2"

View File

@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "minimax_m2_append_think"
end_token = "</think>"
# MiniMax M2 model path
REASONING_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
@pytest.fixture(scope="module")
def minimax_m2_tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
# =============================================================================
# MiniMaxM2AppendThinkReasoningParser behavior:
# - Prepends <think> to the beginning of the output
# - Does NOT separate reasoning and content
# - Returns everything as content (with <think> prepended)
# - reasoning is always None
#
# This parser is used when you want to keep the raw output with <think> added
# =============================================================================
# Case: simple output with end token
SIMPLE_OUTPUT = {
"output": "This is reasoning</think>This is response",
"reasoning": None,
"content": "<think>This is reasoning</think>This is response",
"is_reasoning_end": True,
}
# Case: output without end token (reasoning in progress)
NO_END_TOKEN = {
"output": "This is reasoning in progress",
"reasoning": None,
"content": "<think>This is reasoning in progress",
"is_reasoning_end": False,
}
# Case: only end token
ONLY_END_TOKEN = {
"output": "</think>This is response",
"reasoning": None,
"content": "<think></think>This is response",
"is_reasoning_end": True,
}
# Case: multiple lines
MULTIPLE_LINES = {
"output": "Line 1\nLine 2</think>Response 1\nResponse 2",
"reasoning": None,
"content": "<think>Line 1\nLine 2</think>Response 1\nResponse 2",
"is_reasoning_end": True,
}
# Case: empty output (non-streaming prepends <think>)
EMPTY = {
"output": "",
"reasoning": None,
"content": "<think>",
"is_reasoning_end": False,
}
# Case: empty output streaming (no tokens = no output)
EMPTY_STREAMING = {
"output": "",
"reasoning": None,
"content": None,
"is_reasoning_end": False,
}
# Case: special characters
SPECIAL_CHARS = {
"output": "Let me think... 1+1=2</think>Yes!",
"reasoning": None,
"content": "<think>Let me think... 1+1=2</think>Yes!",
"is_reasoning_end": True,
}
# Case: code in output
CODE_OUTPUT = {
"output": "```python\nprint('hi')\n```</think>Here's the code.",
"reasoning": None,
"content": "<think>```python\nprint('hi')\n```</think>Here's the code.",
"is_reasoning_end": True,
}
TEST_CASES = [
pytest.param(
False,
SIMPLE_OUTPUT,
id="simple_output",
),
pytest.param(
True,
SIMPLE_OUTPUT,
id="simple_output_streaming",
),
pytest.param(
False,
NO_END_TOKEN,
id="no_end_token",
),
pytest.param(
True,
NO_END_TOKEN,
id="no_end_token_streaming",
),
pytest.param(
False,
ONLY_END_TOKEN,
id="only_end_token",
),
pytest.param(
True,
ONLY_END_TOKEN,
id="only_end_token_streaming",
),
pytest.param(
False,
MULTIPLE_LINES,
id="multiple_lines",
),
pytest.param(
True,
MULTIPLE_LINES,
id="multiple_lines_streaming",
),
pytest.param(
False,
EMPTY,
id="empty",
),
pytest.param(
True,
EMPTY_STREAMING,
id="empty_streaming",
),
pytest.param(
False,
SPECIAL_CHARS,
id="special_chars",
),
pytest.param(
True,
SPECIAL_CHARS,
id="special_chars_streaming",
),
pytest.param(
False,
CODE_OUTPUT,
id="code_output",
),
pytest.param(
True,
CODE_OUTPUT,
id="code_output_streaming",
),
]
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
minimax_m2_tokenizer,
):
output = minimax_m2_tokenizer.tokenize(param_dict["output"])
# decode everything to tokens
output_tokens: list[str] = [
minimax_m2_tokenizer.convert_tokens_to_string([token]) for token in output
]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
minimax_m2_tokenizer
)
reasoning, content = run_reasoning_extraction(
parser, output_tokens, streaming=streaming
)
assert reasoning == param_dict["reasoning"]
assert content == param_dict["content"]
# Test is_reasoning_end
output_ids = minimax_m2_tokenizer.convert_tokens_to_ids(output)
is_reasoning_end = parser.is_reasoning_end(output_ids)
assert is_reasoning_end == param_dict["is_reasoning_end"]

View File

@ -0,0 +1,230 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "minimax_m2"
end_token = "</think>"
# MiniMax M2 model path
REASONING_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
@pytest.fixture(scope="module")
def minimax_m2_tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
# =============================================================================
# MiniMax M2 specific behavior:
# - Model does NOT generate <think> start token
# - Model only generates </think> end token
# - All content before </think> is reasoning
# - All content after </think> is the actual response (content)
# =============================================================================
# Case: reasoning + end token + content (typical case)
SIMPLE_REASONING = {
"output": "This is a reasoning section</think>This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
# Case: reasoning + end token only (no content after)
COMPLETE_REASONING = {
"output": "This is a reasoning section</think>",
"reasoning": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
# Case: no end token yet (streaming in progress, all is reasoning)
NO_END_TOKEN = {
"output": "This is reasoning in progress",
"reasoning": "This is reasoning in progress",
"content": None,
"is_reasoning_end": False,
}
# Case: multiple lines of reasoning
MULTIPLE_LINES = {
"output": "First line\nSecond line</think>Response first line\nResponse second",
"reasoning": "First line\nSecond line",
"content": "Response first line\nResponse second",
"is_reasoning_end": True,
}
# Case: only end token (empty reasoning, immediate response)
SHORTEST_REASONING_NO_STREAMING = {
"output": "</think>This is the response",
"reasoning": "",
"content": "This is the response",
"is_reasoning_end": True,
}
# Case: only end token streaming (reasoning is None because it's just the token)
SHORTEST_REASONING_STREAMING = {
"output": "</think>This is the response",
"reasoning": None,
"content": "This is the response",
"is_reasoning_end": True,
}
# Case: empty output
EMPTY = {
"output": "",
"reasoning": "",
"content": None,
"is_reasoning_end": False,
}
# Case: empty streaming
EMPTY_STREAMING = {
"output": "",
"reasoning": None,
"content": None,
"is_reasoning_end": False,
}
# Case: long reasoning with special characters
SPECIAL_CHARS = {
"output": "Let me think... 1+1=2, right?</think>Yes, 1+1=2.",
"reasoning": "Let me think... 1+1=2, right?",
"content": "Yes, 1+1=2.",
"is_reasoning_end": True,
}
# Case: reasoning with code blocks
CODE_IN_REASONING = {
"output": "```python\nprint('hello')\n```</think>Here is the code.",
"reasoning": "```python\nprint('hello')\n```",
"content": "Here is the code.",
"is_reasoning_end": True,
}
TEST_CASES = [
# Core cases: no start token (MiniMax M2 actual behavior)
pytest.param(
False,
SIMPLE_REASONING,
id="simple_reasoning",
),
pytest.param(
True,
SIMPLE_REASONING,
id="simple_reasoning_streaming",
),
pytest.param(
False,
COMPLETE_REASONING,
id="complete_reasoning",
),
pytest.param(
True,
COMPLETE_REASONING,
id="complete_reasoning_streaming",
),
pytest.param(
False,
NO_END_TOKEN,
id="no_end_token",
),
pytest.param(
True,
NO_END_TOKEN,
id="no_end_token_streaming",
),
pytest.param(
False,
MULTIPLE_LINES,
id="multiple_lines",
),
pytest.param(
True,
MULTIPLE_LINES,
id="multiple_lines_streaming",
),
pytest.param(
False,
SHORTEST_REASONING_NO_STREAMING,
id="shortest_reasoning",
),
pytest.param(
True,
SHORTEST_REASONING_STREAMING,
id="shortest_reasoning_streaming",
),
pytest.param(
False,
EMPTY,
id="empty",
),
pytest.param(
True,
EMPTY_STREAMING,
id="empty_streaming",
),
pytest.param(
False,
SPECIAL_CHARS,
id="special_chars",
),
pytest.param(
True,
SPECIAL_CHARS,
id="special_chars_streaming",
),
pytest.param(
False,
CODE_IN_REASONING,
id="code_in_reasoning",
),
pytest.param(
True,
CODE_IN_REASONING,
id="code_in_reasoning_streaming",
),
]
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
minimax_m2_tokenizer,
):
output = minimax_m2_tokenizer.tokenize(param_dict["output"])
# decode everything to tokens
output_tokens: list[str] = [
minimax_m2_tokenizer.convert_tokens_to_string([token]) for token in output
]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
minimax_m2_tokenizer
)
reasoning, content = run_reasoning_extraction(
parser, output_tokens, streaming=streaming
)
assert reasoning == param_dict["reasoning"]
assert content == param_dict["content"]
# Test is_reasoning_end
output_ids = minimax_m2_tokenizer.convert_tokens_to_ids(output)
is_reasoning_end = parser.is_reasoning_end(output_ids)
assert is_reasoning_end == param_dict["is_reasoning_end"]
# Test extract_content
if param_dict["content"] is not None:
content = parser.extract_content_ids(output_ids)
assert content == minimax_m2_tokenizer.convert_tokens_to_ids(
minimax_m2_tokenizer.tokenize(param_dict["content"])
)
else:
content = parser.extract_content_ids(output)
assert content == []

View File

@ -18,47 +18,53 @@ def mistral_tokenizer():
return mistral_tokenizer
SIMPLE_REASONING = {
INVALID_SIMPLE_REASONING = {
"output": "This is a reasoning section[/THINK]This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
"reasoning": None,
"content": "This is a reasoning sectionThis is the rest",
"is_reasoning_end": False,
}
COMPLETE_REASONING = {
INVALID_COMPLETE_REASONING = {
"output": "This is a reasoning section[/THINK]",
"reasoning": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
"reasoning": None,
"content": "This is a reasoning section",
"is_reasoning_end": False,
}
NO_CONTENT = {
"output": "This is content",
"reasoning": "This is content",
"output": "[THINK]This is reasoning",
"reasoning": "This is reasoning",
"content": None,
"is_reasoning_end": False,
}
NO_REASONING = {
"output": "This is content",
"reasoning": None,
"content": "This is content",
"is_reasoning_end": False,
}
NO_REASONING_STREAMING = {
"output": "This is a reasoning section",
"reasoning": "This is a reasoning section",
"content": None,
"reasoning": None,
"content": "This is a reasoning section",
"is_reasoning_end": False,
}
MULTIPLE_LINES = {
INVALID_MULTIPLE_LINES = {
"output": "This\nThat[/THINK]This is the rest\nThat",
"reasoning": "This\nThat",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
"reasoning": None,
"content": "This\nThatThis is the rest\nThat",
"is_reasoning_end": False,
}
SHORTEST_REASONING_NO_STREAMING = {
"output": "[/THINK]This is the rest",
"reasoning": "",
"content": "This is the rest",
"is_reasoning_end": True,
}
SHORTEST_REASONING = {
INVALID_SHORTEST_REASONING_NO_STREAMING = {
"output": "[/THINK]This is the rest",
"reasoning": None,
"content": "This is the rest",
"is_reasoning_end": True,
"is_reasoning_end": False,
}
INVALID_SHORTEST_REASONING = {
"output": "[/THINK]This is the rest",
"reasoning": None,
"content": "This is the rest",
"is_reasoning_end": False,
}
REASONING_WITH_THINK = {
"output": "[THINK]This is a reasoning section[/THINK]This is the rest",
@ -78,17 +84,17 @@ MULTIPLE_LINES_WITH_THINK = {
"content": "This is the rest\nThat",
"is_reasoning_end": True,
}
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
"output": "[/THINK]This is the rest",
"reasoning": "",
"content": "This is the rest",
"is_reasoning_end": True,
}
SHORTEST_REASONING_WITH_THINK = {
INVALID_SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
"output": "[/THINK]This is the rest",
"reasoning": None,
"content": "This is the rest",
"is_reasoning_end": True,
"is_reasoning_end": False,
}
INVALID_SHORTEST_REASONING_WITH_THINK = {
"output": "[/THINK]This is the rest",
"reasoning": None,
"content": "This is the rest",
"is_reasoning_end": False,
}
THINK_NO_END = {
"output": "[THINK]This is a reasoning section",
@ -98,8 +104,8 @@ THINK_NO_END = {
}
EMPTY = {
"output": "",
"reasoning": "",
"content": None,
"reasoning": None,
"content": "",
"is_reasoning_end": False,
}
EMPTY_STREAMING = {
@ -109,47 +115,48 @@ EMPTY_STREAMING = {
"is_reasoning_end": False,
}
NEW_LINE = {
"output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
"output": "Before\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
"reasoning": "This is a reasoning section",
"content": "\nThis is the rest",
"content": "Before\n\nThis is the rest",
"is_reasoning_end": True,
}
# Streaming cannot handle new lines at the beginning of the output
# because we need to support [THINK]...[/THINK] and [/THINK]...
# We cannot know if the text before [THINK] is reasoning content
# or not.
NEW_LINE_STREAMING = {
"output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
"reasoning": "\nThis is a reasoning section",
"content": "\nThis is the rest",
"output": "Before\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
"reasoning": "This is a reasoning section",
"content": "Before\n\nThis is the rest",
"is_reasoning_end": True,
}
TEST_CASES = [
pytest.param(
False,
SIMPLE_REASONING,
id="simple_reasoning",
INVALID_SIMPLE_REASONING,
id="invalid_simple_reasoning",
),
pytest.param(
True,
SIMPLE_REASONING,
id="simple_reasoning_streaming",
INVALID_SIMPLE_REASONING,
id="invalid_simple_reasoning_streaming",
),
pytest.param(
False,
COMPLETE_REASONING,
id="complete_reasoning",
INVALID_COMPLETE_REASONING,
id="invalid_complete_reasoning",
),
pytest.param(
True,
COMPLETE_REASONING,
id="complete_reasoning_streaming",
INVALID_COMPLETE_REASONING,
id="invalid_complete_reasoning_streaming",
),
pytest.param(
False,
NO_CONTENT,
id="no_content_token",
id="no_content",
),
pytest.param(
False,
NO_REASONING,
id="no_reasoning",
),
pytest.param(
True,
@ -158,23 +165,23 @@ TEST_CASES = [
),
pytest.param(
False,
MULTIPLE_LINES,
id="multiple_lines",
INVALID_MULTIPLE_LINES,
id="invalid_multiple_lines",
),
pytest.param(
True,
MULTIPLE_LINES,
id="multiple_lines_streaming",
INVALID_MULTIPLE_LINES,
id="invalid_multiple_lines_streaming",
),
pytest.param(
True,
SHORTEST_REASONING,
id="shortest",
INVALID_SHORTEST_REASONING,
id="invalid_shortest",
),
pytest.param(
False,
SHORTEST_REASONING_NO_STREAMING,
id="shortest_streaming",
INVALID_SHORTEST_REASONING_NO_STREAMING,
id="invalid_shortest_streaming",
),
pytest.param(
False,
@ -208,13 +215,13 @@ TEST_CASES = [
),
pytest.param(
False,
SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
id="shortest_with_think",
INVALID_SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
id="invalid_shortest_with_think",
),
pytest.param(
True,
SHORTEST_REASONING_WITH_THINK,
id="shortest_with_think_streaming",
INVALID_SHORTEST_REASONING_WITH_THINK,
id="invalid_shortest_with_think_streaming",
),
pytest.param(
False,
@ -316,10 +323,26 @@ def test_mistral_reasoning(
# Test extract_content
if param_dict["content"] is not None:
content = parser.extract_content_ids(output_tokens)
assert content == mistral_tokenizer.tokenizer.encode(
param_dict["content"], bos=False, eos=False
# Handle the case where there are tokens outputted before Thinking.
# This should not occur if the model is well trained and prompted.
if "[THINK]" in param_dict["output"] and not param_dict["output"].startswith(
"[THINK]"
):
before_content = param_dict["output"].split("[THINK]")[0]
before_token_ids = mistral_tokenizer.tokenizer.encode(
before_content, bos=False, eos=False
)
left_to_encode = param_dict["content"][len(before_content) :]
# Normal situation.
else:
before_token_ids = []
left_to_encode = param_dict["content"]
content_tokens = parser.extract_content_ids(output_tokens)
expected_token_ids = before_token_ids + mistral_tokenizer.tokenizer.encode(
left_to_encode, bos=False, eos=False
)
assert content_tokens == expected_token_ids
else:
content = parser.extract_content_ids(output_tokens)
assert content == []

View File

@ -3,12 +3,45 @@
# for users who do not have any compilers installed on their system
set -e
set -x
merge_base_commit=$(git merge-base HEAD origin/main)
echo "Current merge base commit with main: $merge_base_commit"
echo "INFO: current merge base commit with main: $merge_base_commit"
git show --oneline -s $merge_base_commit
# test whether the metadata.json url is valid, retry each 3 minutes up to 5 times
# this avoids cumbersome error messages & manual retries in case the precompiled wheel
# for the given commit is still being built in the release pipeline
meta_json_url="https://wheels.vllm.ai/$merge_base_commit/vllm/metadata.json"
echo "INFO: will use metadata.json from $meta_json_url"
for i in {1..5}; do
echo "Checking metadata.json URL (attempt $i)..."
if curl --fail "$meta_json_url" > metadata.json; then
echo "INFO: metadata.json URL is valid."
# check whether it is valid json by python
if python3 -m json.tool metadata.json; then
echo "INFO: metadata.json is valid JSON. Proceeding with the test."
else
echo "CRITICAL: metadata.json exists but is not valid JSON, please do report in #sig-ci channel!"
exit 1
fi
break
fi
# failure handling
if [ $i -eq 5 ]; then
echo "ERROR: metadata.json URL is still not valid after 5 attempts."
echo "ERROR: Please check whether the precompiled wheel for commit $merge_base_commit exists."
echo " NOTE: If $merge_base_commit is a new commit on main, maybe try again after its release pipeline finishes."
echo " NOTE: If it fails, please report in #sig-ci channel."
exit 1
else
echo "WARNING: metadata.json URL is not valid. Retrying in 3 minutes..."
sleep 180
fi
done
set -x
cd /vllm-workspace/
# uninstall vllm
@ -29,6 +62,6 @@ python3 -c 'import vllm'
# Check if the clangd log file was created
if [ ! -f /tmp/changed.file ]; then
echo "changed.file was not created, python only compilation failed"
echo "ERROR: changed.file was not created, python only compilation failed"
exit 1
fi

View File

@ -22,10 +22,14 @@ from tests.v1.attention.utils import (
)
from vllm import _custom_ops as ops
from vllm.attention.ops import flashmla
from vllm.config import set_current_vllm_config
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend,
triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.utils import split_prefill_chunks
SPARSE_BACKEND_BATCH_SPECS = {
name: BATCH_SPECS[name]
@ -114,8 +118,12 @@ def _quantize_dequantize_fp8_ds_mla(
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
@pytest.mark.skipif(
torch.cuda.get_device_capability() < (9, 0),
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_sparse_backend_decode_correctness(
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init
):
if not torch.cuda.is_available():
pytest.skip("CUDA is required for sparse MLA decode test")
@ -320,28 +328,29 @@ def test_sparse_backend_decode_correctness(
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
impl_cls = FlashMLASparseBackend.get_impl_cls()
impl = impl_cls(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_b_proj=mock_kv_b_proj,
indexer=mock_indexer,
)
with set_current_vllm_config(vllm_config):
impl = impl_cls(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_b_proj=mock_kv_b_proj,
indexer=mock_indexer,
)
impl.process_weights_after_loading(dtype)
impl.process_weights_after_loading(dtype)
layer = MockAttentionLayer(device)
out_buffer = torch.empty(
@ -366,22 +375,192 @@ def test_sparse_backend_decode_correctness(
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5)
def _triton_convert_reference_impl(
req_ids: torch.Tensor,
block_table: torch.Tensor,
token_indices: torch.Tensor,
block_size: int,
num_topk_tokens: int,
HAS_PREFILL_WORKSPACE: bool = False,
prefill_workspace_request_ids: torch.Tensor | None = None,
prefill_workspace_starts: torch.Tensor | None = None,
) -> torch.Tensor:
"""Reference implementation for triton_convert_req_index_to_global_index."""
num_tokens = req_ids.shape[0]
max_blocks_per_req = block_table.shape[1]
result = torch.empty(
num_tokens, num_topk_tokens, dtype=torch.int32, device=req_ids.device
)
for token_id in range(num_tokens):
req_id = req_ids[token_id].item()
# Determine if this token uses workspace or paged cache
use_prefill_workspace = False
workspace_start = 0
if HAS_PREFILL_WORKSPACE and prefill_workspace_request_ids is not None:
assert prefill_workspace_starts is not None
prefill_req_id = prefill_workspace_request_ids[token_id].item()
if prefill_req_id >= 0:
use_prefill_workspace = True
workspace_start = prefill_workspace_starts[prefill_req_id].item()
for idx_id in range(num_topk_tokens):
token_idx = token_indices[token_id, idx_id].item()
if token_idx == -1:
result[token_id, idx_id] = -1
elif use_prefill_workspace:
# Prefill + using prefill workspace: map to workspace offset
result[token_id, idx_id] = workspace_start + token_idx
else:
# Decode: map to paged cache
block_id = token_idx // block_size
if block_id >= max_blocks_per_req:
result[token_id, idx_id] = -1
else:
block_num = block_table[req_id, block_id].item()
offset = token_idx % block_size
result[token_id, idx_id] = block_num * block_size + offset
return result
@pytest.mark.parametrize("block_size", [16, 64, 128])
@pytest.mark.parametrize("num_topk_tokens", [128, 256, 512])
@pytest.mark.skipif(
torch.cuda.get_device_capability() < (9, 0),
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_triton_convert_req_index_to_global_index_decode_only(
block_size, num_topk_tokens
):
device = torch.device("cuda")
num_tokens = 8
num_requests = 4
max_blocks_per_req = 10
req_id = torch.randint(
0, num_requests, (num_tokens,), dtype=torch.int32, device=device
)
block_table = torch.randint(
0, 100, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
)
token_indices = torch.randint(
0,
block_size * max_blocks_per_req,
(num_tokens, num_topk_tokens),
dtype=torch.int32,
device=device,
)
# Set some to -1 to test masking
token_indices[0, :10] = -1
token_indices[3, 50:60] = -1
# Set some to out of bounds
token_indices[2, 100:110] = max_blocks_per_req * block_size
token_indices[6, 150:160] = max_blocks_per_req * block_size
result = triton_convert_req_index_to_global_index(
req_id,
block_table,
token_indices,
BLOCK_SIZE=block_size,
NUM_TOPK_TOKENS=num_topk_tokens,
)
reference_result = _triton_convert_reference_impl(
req_id,
block_table,
token_indices,
block_size,
num_topk_tokens,
)
torch.testing.assert_close(result, reference_result, rtol=0, atol=0)
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.skipif(
torch.cuda.get_device_capability() < (9, 0),
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size):
device = torch.device("cuda")
num_requests = 4
max_blocks_per_req = 8
num_topk_tokens = 128
# First 6 tokens are decode (reqs 0, 1), last 6 are prefill (reqs 2, 3)
req_id = torch.tensor(
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=torch.int32, device=device
)
prefill_workspace_request_ids = torch.tensor(
[-1, -1, -1, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=torch.int32, device=device
)
# Workspace starts for the 2 prefill reqs: req 2 starts at 0, req 3 starts at 100
prefill_workspace_starts = torch.tensor([0, 100], dtype=torch.int32, device=device)
block_table = torch.randint(
0, 50, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
)
token_indices = torch.randint(
0,
block_size * max_blocks_per_req,
(req_id.shape[0], num_topk_tokens),
dtype=torch.int32,
device=device,
)
# Set some to -1 to test masking
token_indices[0, :10] = -1
token_indices[3, 50:60] = -1
# Set some to out of bounds
token_indices[2, 100:110] = max_blocks_per_req * block_size
token_indices[6, 150:160] = max_blocks_per_req * block_size
result = triton_convert_req_index_to_global_index(
req_id,
block_table,
token_indices,
BLOCK_SIZE=block_size,
NUM_TOPK_TOKENS=num_topk_tokens,
HAS_PREFILL_WORKSPACE=True,
prefill_workspace_request_ids=prefill_workspace_request_ids,
prefill_workspace_starts=prefill_workspace_starts,
)
reference_result = _triton_convert_reference_impl(
req_id,
block_table,
token_indices,
block_size,
num_topk_tokens,
HAS_PREFILL_WORKSPACE=True,
prefill_workspace_request_ids=prefill_workspace_request_ids,
prefill_workspace_starts=prefill_workspace_starts,
)
torch.testing.assert_close(result, reference_result, rtol=0, atol=0)
@pytest.mark.parametrize(
"seq_lens,max_buf,start,expected",
"seq_lens,max_buf,expected",
[
# Basic split: totals per chunk ≤ max_buf
(torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
# Non-zero start index
(torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
# Exact fits should split between items when adding the next would
# overflow
(torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
(torch.tensor([2, 3, 4, 2]), 5, [(0, 2), (2, 3), (3, 4)]),
# Exact fits should split between items when adding the next would overflow
(torch.tensor([5, 5, 5]), 5, [(0, 1), (1, 2), (2, 3)]),
# All requests fit in a single chunk
(torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
# Large buffer with non-zero start
(torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
(torch.tensor([1, 1, 1]), 10, [(0, 3)]),
# Large buffer
(torch.tensor([4, 4, 4]), 100, [(0, 3)]),
],
)
def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
out = split_prefill_chunks(seq_lens, max_buf, start)
def test_split_prefill_chunks(seq_lens, max_buf, expected):
out = split_prefill_chunks(seq_lens, max_buf)
assert out == expected

View File

@ -13,6 +13,7 @@ import torch
from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
from tests.utils import RemoteOpenAIServer
from vllm.utils.import_utils import has_deep_ep
# Detect Blackwell / B200 (compute capability 10.x)
try:
@ -44,6 +45,7 @@ DEEPEP_BACKENDS = [
]
@pytest.mark.skipif(not has_deep_ep(), reason="These tests require deep_ep to run")
@pytest.mark.parametrize("all2all_backend", DEEPEP_BACKENDS)
@pytest.mark.xfail(
IS_BLACKWELL,

View File

@ -16,6 +16,16 @@ from vllm.platforms import current_platform
MTP_SIMILARITY_RATE = 0.8
def _skip_if_insufficient_gpus_for_tp(tp_size: int):
"""Skip test if available GPUs < tp_size on ROCm."""
if current_platform.is_rocm():
available_gpus = torch.cuda.device_count()
if available_gpus < tp_size:
pytest.skip(
f"Test requires {tp_size} GPUs, but only {available_gpus} available"
)
def get_test_prompts(mm_enabled: bool):
prompt_types = ["repeat", "sentence"]
if mm_enabled:
@ -280,9 +290,20 @@ def test_speculators_model_integration(
@pytest.mark.parametrize(
["model_setup", "mm_enabled", "enable_chunked_prefill"],
["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
[
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
(
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
False,
False,
"auto",
),
(
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
False,
False,
"transformers",
),
pytest.param(
(
"eagle3",
@ -292,6 +313,7 @@ def test_speculators_model_integration(
),
False,
False,
"auto",
marks=pytest.mark.skip(
reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
),
@ -305,6 +327,7 @@ def test_speculators_model_integration(
),
False,
False,
"auto",
marks=pytest.mark.skip(
reason="Skipping due to its head_dim not being a a multiple of 32"
),
@ -318,6 +341,7 @@ def test_speculators_model_integration(
),
False,
True,
"auto",
marks=large_gpu_mark(min_gb=40),
), # works on 4x H100
(
@ -329,6 +353,7 @@ def test_speculators_model_integration(
),
False,
False,
"auto",
),
pytest.param(
(
@ -339,6 +364,7 @@ def test_speculators_model_integration(
),
False,
False,
"auto",
marks=large_gpu_mark(min_gb=80),
), # works on 4x H100
pytest.param(
@ -350,6 +376,7 @@ def test_speculators_model_integration(
),
True,
True,
"auto",
marks=large_gpu_mark(min_gb=80),
), # works on 4x H100
(
@ -361,10 +388,12 @@ def test_speculators_model_integration(
),
False,
False,
"auto",
),
],
ids=[
"qwen3_eagle3",
"qwen3_eagle3-transformers",
"qwen3_vl_eagle3",
"qwen2_5_vl_eagle3",
"llama3_eagle",
@ -381,6 +410,7 @@ def test_eagle_correctness(
model_setup: tuple[str, str, str, int],
mm_enabled: bool,
enable_chunked_prefill: bool,
model_impl: str,
attn_backend: str,
):
if attn_backend == "TREE_ATTN":
@ -389,6 +419,17 @@ def test_eagle_correctness(
"TREE_ATTN is flaky in the test disable for now until it can be "
"resolved (see https://github.com/vllm-project/vllm/issues/22922)"
)
if model_impl == "transformers":
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("5.0.0.dev")
if installed < required:
pytest.skip(
"Eagle3 with the Transformers modeling backend requires "
f"transformers>={required}, but got {installed}"
)
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
@ -424,6 +465,8 @@ def test_eagle_correctness(
m.setenv("VLLM_ROCM_USE_AITER", "1")
method, model_name, spec_model_name, tp_size = model_setup
_skip_if_insufficient_gpus_for_tp(tp_size)
max_model_len = 2048
max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
@ -448,6 +491,7 @@ def test_eagle_correctness(
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
model_impl=model_impl,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
@ -493,6 +537,7 @@ def test_mtp_correctness(
m.setenv("VLLM_MLA_DISABLE", "1")
method, model_name, tp_size = model_setup
_skip_if_insufficient_gpus_for_tp(tp_size)
ref_llm = LLM(
model=model_name,

View File

@ -0,0 +1,756 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
from vllm.distributed.kv_events import BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
LMCacheConnectorV1,
LMCacheKVEvents,
)
from vllm.v1.outputs import KVConnectorOutput
@pytest.fixture
def mock_lmcache_engine_event():
"""Create a mock event object that mimics what the lmcache engine returns."""
class MockEvent:
def __init__(
self,
block_hashes,
parent_block_hash,
token_ids,
lora_id,
block_size,
medium,
):
self.block_hashes = block_hashes
self.parent_block_hash = parent_block_hash
self.token_ids = token_ids
self.lora_id = lora_id
self.block_size = block_size
self.medium = medium
return MockEvent(
block_hashes=["hash1", "hash2"],
parent_block_hash="parent_hash",
token_ids=[1, 2, 3, 4],
lora_id=None,
block_size=16,
medium="GPU",
)
@pytest.fixture
def mock_connector():
"""Create a mock LMCacheConnectorV1 instance with mocked dependencies."""
connector = MagicMock(spec=LMCacheConnectorV1)
connector._kv_cache_events = None
connector._lmcache_engine = MagicMock()
# Make the methods use the real implementation
connector.get_kv_connector_kv_cache_events = (
LMCacheConnectorV1.get_kv_connector_kv_cache_events.__get__(
connector, LMCacheConnectorV1
)
)
connector.update_connector_output = (
LMCacheConnectorV1.update_connector_output.__get__(
connector, LMCacheConnectorV1
)
)
connector.take_events = LMCacheConnectorV1.take_events.__get__(
connector, LMCacheConnectorV1
)
return connector
class TestGetKVConnectorKVCacheEvents:
"""Test get_kv_connector_kv_cache_events method."""
def test_returns_none_when_no_events(self, mock_connector):
"""Test that None is returned when lmcache engine has no events."""
mock_connector._lmcache_engine.get_kv_events.return_value = None
result = mock_connector.get_kv_connector_kv_cache_events()
assert result is None
mock_connector._lmcache_engine.get_kv_events.assert_called_once()
def test_returns_none_when_empty_list(self, mock_connector):
"""Test that None is returned when lmcache engine returns empty list."""
mock_connector._lmcache_engine.get_kv_events.return_value = []
result = mock_connector.get_kv_connector_kv_cache_events()
assert result is None
def test_converts_single_event(self, mock_connector, mock_lmcache_engine_event):
"""Test conversion of a single event from lmcache engine format."""
mock_connector._lmcache_engine.get_kv_events.return_value = [
mock_lmcache_engine_event
]
result = mock_connector.get_kv_connector_kv_cache_events()
assert result is not None
assert isinstance(result, LMCacheKVEvents)
assert result.get_number_of_workers() == 1
events = result.get_all_events()
assert len(events) == 1
assert isinstance(events[0], BlockStored)
assert events[0].block_hashes == ["hash1", "hash2"]
assert events[0].parent_block_hash == "parent_hash"
assert events[0].token_ids == [1, 2, 3, 4]
assert events[0].lora_id is None
assert events[0].block_size == 16
assert events[0].medium == "GPU"
def test_converts_multiple_events(self, mock_connector):
"""Test conversion of multiple events from lmcache engine format."""
class MockEvent:
def __init__(self, i):
self.block_hashes = [f"hash{i}"]
self.parent_block_hash = f"parent{i}"
self.token_ids = [i]
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
events = [MockEvent(i) for i in range(5)]
mock_connector._lmcache_engine.get_kv_events.return_value = events
result = mock_connector.get_kv_connector_kv_cache_events()
assert result is not None
assert isinstance(result, LMCacheKVEvents)
converted_events = result.get_all_events()
assert len(converted_events) == 5
for i, event in enumerate(converted_events):
assert isinstance(event, BlockStored)
assert event.block_hashes == [f"hash{i}"]
assert event.parent_block_hash == f"parent{i}"
assert event.token_ids == [i]
def test_preserves_event_attributes(self, mock_connector):
"""Test that all event attributes are correctly preserved."""
class MockEventWithLora:
def __init__(self):
self.block_hashes = ["hash_a", "hash_b", "hash_c"]
self.parent_block_hash = "parent_xyz"
self.token_ids = [100, 200, 300]
self.lora_id = 42
self.block_size = 32
self.medium = "DISK"
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEventWithLora()
]
result = mock_connector.get_kv_connector_kv_cache_events()
events = result.get_all_events()
event = events[0]
assert event.block_hashes == ["hash_a", "hash_b", "hash_c"]
assert event.parent_block_hash == "parent_xyz"
assert event.token_ids == [100, 200, 300]
assert event.lora_id == 42
assert event.block_size == 32
assert event.medium == "DISK"
def test_handles_none_parent_block_hash(self, mock_connector):
"""Test handling of events with None parent_block_hash."""
class MockEventNoParent:
def __init__(self):
self.block_hashes = ["hash1"]
self.parent_block_hash = None
self.token_ids = [1, 2]
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEventNoParent()
]
result = mock_connector.get_kv_connector_kv_cache_events()
events = result.get_all_events()
assert events[0].parent_block_hash is None
class TestUpdateConnectorOutput:
"""Test update_connector_output method."""
def test_does_nothing_when_kv_cache_events_is_none(self, mock_connector):
"""Test that method returns early when kv_cache_events is None."""
connector_output = KVConnectorOutput(kv_cache_events=None)
mock_connector.update_connector_output(connector_output)
assert mock_connector._kv_cache_events is None
def test_does_nothing_when_kv_cache_events_is_not_lmcache_kv_events(
self, mock_connector
):
"""Test that method returns early when kv_cache_events is not
LMCacheKVEvents."""
# Create a mock object that is not LMCacheKVEvents
fake_events = MagicMock()
connector_output = KVConnectorOutput(kv_cache_events=fake_events)
mock_connector.update_connector_output(connector_output)
assert mock_connector._kv_cache_events is None
def test_sets_kv_cache_events_when_none(self, mock_connector):
"""Test that _kv_cache_events is set when it was None."""
kv_events = LMCacheKVEvents(num_workers=1)
event = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1, 2],
block_size=16,
lora_id=None,
medium="GPU",
)
kv_events.add_events([event])
connector_output = KVConnectorOutput(kv_cache_events=kv_events)
mock_connector.update_connector_output(connector_output)
assert mock_connector._kv_cache_events is kv_events
def test_adds_events_when_kv_cache_events_already_exists(self, mock_connector):
"""Test that events are added when _kv_cache_events already exists."""
# Set up existing events
existing_events = LMCacheKVEvents(num_workers=2)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
existing_events.add_events([event1])
existing_events.add_events([event1]) # Simulate 2 workers reporting
mock_connector._kv_cache_events = existing_events
# Create new events to add
new_events = LMCacheKVEvents(num_workers=1)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
new_events.add_events([event2])
connector_output = KVConnectorOutput(kv_cache_events=new_events)
mock_connector.update_connector_output(connector_output)
# Check that events were added
all_events = mock_connector._kv_cache_events.get_all_events()
assert len(all_events) == 3 # 2 from existing + 1 from new
assert event1 in all_events
assert event2 in all_events
def test_increments_workers_when_kv_cache_events_already_exists(
self, mock_connector
):
"""Test that worker count is incremented correctly."""
# Set up existing events with 2 workers
existing_events = LMCacheKVEvents(num_workers=2)
mock_connector._kv_cache_events = existing_events
# Create new events from 3 workers
new_events = LMCacheKVEvents(num_workers=3)
event = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
new_events.add_events([event])
connector_output = KVConnectorOutput(kv_cache_events=new_events)
mock_connector.update_connector_output(connector_output)
# Worker count should be 2 + 3 = 5
assert mock_connector._kv_cache_events.get_number_of_workers() == 5
def test_multiple_updates(self, mock_connector):
"""Test multiple consecutive updates."""
# First update
events1 = LMCacheKVEvents(num_workers=1)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
events1.add_events([event1])
output1 = KVConnectorOutput(kv_cache_events=events1)
mock_connector.update_connector_output(output1)
# Second update
events2 = LMCacheKVEvents(num_workers=2)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
events2.add_events([event2])
output2 = KVConnectorOutput(kv_cache_events=events2)
mock_connector.update_connector_output(output2)
# Third update
events3 = LMCacheKVEvents(num_workers=1)
event3 = BlockStored(
block_hashes=["hash3"],
parent_block_hash=None,
token_ids=[3],
block_size=16,
lora_id=None,
medium="GPU",
)
events3.add_events([event3])
output3 = KVConnectorOutput(kv_cache_events=events3)
mock_connector.update_connector_output(output3)
# Check final state
all_events = mock_connector._kv_cache_events.get_all_events()
assert len(all_events) == 3
assert mock_connector._kv_cache_events.get_number_of_workers() == 4 # 1+2+1
def test_updates_with_empty_events(self, mock_connector):
"""Test updating with empty event lists."""
# First update with actual events
events1 = LMCacheKVEvents(num_workers=1)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
events1.add_events([event1])
output1 = KVConnectorOutput(kv_cache_events=events1)
mock_connector.update_connector_output(output1)
# Second update with empty events
events2 = LMCacheKVEvents(num_workers=2)
# No events added
output2 = KVConnectorOutput(kv_cache_events=events2)
mock_connector.update_connector_output(output2)
# Should still have the original event
all_events = mock_connector._kv_cache_events.get_all_events()
assert len(all_events) == 1
assert mock_connector._kv_cache_events.get_number_of_workers() == 3
class TestTakeEvents:
"""Test take_events method."""
def test_yields_nothing_when_kv_cache_events_is_none(self, mock_connector):
"""Test that nothing is yielded when _kv_cache_events is None."""
mock_connector._kv_cache_events = None
events = list(mock_connector.take_events())
assert events == []
def test_yields_events_and_clears(self, mock_connector):
"""Test that events are yielded and then cleared."""
# Set up events
kv_events = LMCacheKVEvents(num_workers=1)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
kv_events.add_events([event1, event2])
mock_connector._kv_cache_events = kv_events
# Take events
events = list(mock_connector.take_events())
# Check that events were yielded
assert len(events) == 2
assert event1 in events
assert event2 in events
# Check that _kv_cache_events was cleared
assert mock_connector._kv_cache_events is None
def test_aggregates_before_yielding(self, mock_connector):
"""Test that events are aggregated before yielding."""
# Set up events from multiple workers
kv_events = LMCacheKVEvents(num_workers=3)
common_event = BlockStored(
block_hashes=["hash_common"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
uncommon_event = BlockStored(
block_hashes=["hash_uncommon"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
# All 3 workers report common_event
kv_events.add_events([common_event])
kv_events.add_events([common_event])
kv_events.add_events([common_event])
# Only 1 worker reports uncommon_event
kv_events.add_events([uncommon_event])
mock_connector._kv_cache_events = kv_events
# Take events
events = list(mock_connector.take_events())
# Only the common event should be yielded
assert len(events) == 1
assert events[0] == common_event
def test_multiple_take_events_calls(self, mock_connector):
"""Test calling take_events multiple times."""
# First call with events
kv_events1 = LMCacheKVEvents(num_workers=1)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
kv_events1.add_events([event1])
mock_connector._kv_cache_events = kv_events1
events1 = list(mock_connector.take_events())
assert len(events1) == 1
assert events1[0] == event1
assert mock_connector._kv_cache_events is None
# Second call with no events
events2 = list(mock_connector.take_events())
assert events2 == []
# Third call after adding new events
kv_events2 = LMCacheKVEvents(num_workers=1)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
kv_events2.add_events([event2])
mock_connector._kv_cache_events = kv_events2
events3 = list(mock_connector.take_events())
assert len(events3) == 1
assert events3[0] == event2
def test_yields_empty_after_aggregation_removes_all(self, mock_connector):
"""Test that nothing is yielded if aggregation removes all events."""
# Set up events from 2 workers with no common events
kv_events = LMCacheKVEvents(num_workers=2)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
# Worker 1 reports event1
kv_events.add_events([event1])
# Worker 2 reports event2
kv_events.add_events([event2])
mock_connector._kv_cache_events = kv_events
# Take events
events = list(mock_connector.take_events())
# No common events, so nothing should be yielded
assert events == []
assert mock_connector._kv_cache_events is None
class TestIntegrationScenarios:
"""Test integration scenarios."""
def test_full_workflow(self, mock_connector, mock_lmcache_engine_event):
"""Test a complete workflow from getting events to taking them."""
# Step 1: Get events from lmcache engine
mock_connector._lmcache_engine.get_kv_events.return_value = [
mock_lmcache_engine_event
]
kv_events = mock_connector.get_kv_connector_kv_cache_events()
assert kv_events is not None
assert len(kv_events.get_all_events()) == 1
# Step 2: Update connector output (simulate receiving from worker)
output1 = KVConnectorOutput(kv_cache_events=kv_events)
mock_connector.update_connector_output(output1)
assert mock_connector._kv_cache_events is not None
# Step 3: Take events
taken_events = list(mock_connector.take_events())
assert len(taken_events) == 1
assert mock_connector._kv_cache_events is None
def test_multiple_workers_workflow(self, mock_connector):
"""Test workflow with multiple workers."""
class MockEvent:
def __init__(self, hash_val):
self.block_hashes = [hash_val]
self.parent_block_hash = None
self.token_ids = [1]
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
# Worker 1
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEvent("hash_common"),
MockEvent("hash_worker1"),
]
kv_events1 = mock_connector.get_kv_connector_kv_cache_events()
output1 = KVConnectorOutput(kv_cache_events=kv_events1)
mock_connector.update_connector_output(output1)
# Worker 2
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEvent("hash_common"),
MockEvent("hash_worker2"),
]
kv_events2 = mock_connector.get_kv_connector_kv_cache_events()
output2 = KVConnectorOutput(kv_cache_events=kv_events2)
mock_connector.update_connector_output(output2)
# Take events (should only get common events)
taken_events = list(mock_connector.take_events())
# With aggregation, only events reported by both workers should be present
# In this case, hash_common was reported by both
event_hashes = [e.block_hashes[0] for e in taken_events]
assert "hash_common" in event_hashes
def test_empty_workflow(self, mock_connector):
"""Test workflow when there are no events at any stage."""
# Get events returns None
mock_connector._lmcache_engine.get_kv_events.return_value = None
kv_events = mock_connector.get_kv_connector_kv_cache_events()
assert kv_events is None
# Update with None
output = KVConnectorOutput(kv_cache_events=None)
mock_connector.update_connector_output(output)
# Take events
taken_events = list(mock_connector.take_events())
assert taken_events == []
assert mock_connector._kv_cache_events is None
def test_repeated_cycles(self, mock_connector):
"""Test multiple cycles of the complete workflow."""
class MockEvent:
def __init__(self, cycle_num):
self.block_hashes = [f"hash_cycle_{cycle_num}"]
self.parent_block_hash = None
self.token_ids = [cycle_num]
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
for cycle in range(3):
# Get events
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEvent(cycle)
]
kv_events = mock_connector.get_kv_connector_kv_cache_events()
# Update
output = KVConnectorOutput(kv_cache_events=kv_events)
mock_connector.update_connector_output(output)
# Take
taken_events = list(mock_connector.take_events())
# Verify
assert len(taken_events) == 1
assert taken_events[0].block_hashes[0] == f"hash_cycle_{cycle}"
assert mock_connector._kv_cache_events is None
def test_lmcache_kv_events_aggregation(self):
"""
Test LMCacheKVEvents aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.v1.outputs import ModelRunnerOutput
# Create KVOutputAggregator for 3 workers (simulating TP=3)
aggregator = KVOutputAggregator(expected_finished_count=3)
# Define common and unique events
common_event = BlockStored(
block_hashes=["hash_common"],
parent_block_hash="parent_common",
token_ids=[1, 2, 3],
block_size=16,
lora_id=None,
medium="GPU",
)
worker1_unique_event = BlockStored(
block_hashes=["hash_worker1"],
parent_block_hash="parent_w1",
token_ids=[4, 5],
block_size=16,
lora_id=None,
medium="GPU",
)
worker2_unique_event = BlockStored(
block_hashes=["hash_worker2"],
parent_block_hash="parent_w2",
token_ids=[6, 7],
block_size=16,
lora_id=None,
medium="GPU",
)
worker3_unique_event = BlockStored(
block_hashes=["hash_worker3"],
parent_block_hash="parent_w3",
token_ids=[8, 9],
block_size=16,
lora_id=None,
medium="GPU",
)
# Create events for each worker
# Worker 0: reports common event and its unique event
worker0_events = LMCacheKVEvents(num_workers=1)
worker0_events.add_events([common_event, worker1_unique_event])
# Worker 1: reports common event and its unique event
worker1_events = LMCacheKVEvents(num_workers=1)
worker1_events.add_events([common_event, worker2_unique_event])
# Worker 2: reports common event and its unique event
worker2_events = LMCacheKVEvents(num_workers=1)
worker2_events.add_events([common_event, worker3_unique_event])
# Create ModelRunnerOutput instances for each worker
worker_outputs = []
for i, worker_events in enumerate(
[worker0_events, worker1_events, worker2_events]
):
output = ModelRunnerOutput(
req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0},
sampled_token_ids=[[123]], # dummy token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=set([f"req_{i}_send"])
if i < 2
else None, # Workers 0,1 finished sending
finished_recving=set([f"req_{i}_recv"])
if i > 0
else None, # Workers 1,2 finished receiving
kv_cache_events=worker_events,
),
)
worker_outputs.append(output)
# Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
kv_cache_events = aggregated_output.kv_connector_output.kv_cache_events
assert isinstance(kv_cache_events, LMCacheKVEvents)
# After aggregation, events should be combined from all workers
# The aggregator doesn't automatically aggregate events, so we need to call
# aggregate() to get only common events
kv_cache_events.aggregate()
aggregated_events = kv_cache_events.get_all_events()
# Only the common event should remain after aggregation
# because it's the only event reported by all 3 workers
assert len(aggregated_events) == 1
assert aggregated_events[0] == common_event
# Verify the common event properties
assert aggregated_events[0].block_hashes == ["hash_common"]
assert aggregated_events[0].parent_block_hash == "parent_common"
assert aggregated_events[0].token_ids == [1, 2, 3]

View File

@ -528,9 +528,11 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
),
],
)
@pytest.mark.parametrize("top_logprobs", [0, 3])
def test_spec_decode_logprobs(
logprobs_mode: LogprobsMode,
model_setup: tuple[str, str, str],
top_logprobs: int,
):
"""Spec decode logprobs should match those of the base model.
@ -543,7 +545,7 @@ def test_spec_decode_logprobs(
prompt = "Hello world " * 50
sampling_params = SamplingParams(
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
)
method, model_name, spec_model_name = model_setup
max_model_len = 256

View File

@ -111,7 +111,7 @@ def create_sampling_metadata(
top_p=top_p,
top_k=top_k,
generators=generators,
max_num_logprobs=0,
max_num_logprobs=None,
no_penalties=no_penalties,
prompt_token_ids=prompt_token_ids,
frequency_penalties=frequency_penalties,

View File

@ -43,6 +43,7 @@ FILES = [
"vllm/worker",
"vllm/v1/core",
"vllm/v1/engine",
"vllm/v1/executor",
"vllm/v1/metrics",
"vllm/v1/pool",
"vllm/v1/sample",
@ -60,7 +61,6 @@ SEPARATE_GROUPS = [
"vllm/model_executor",
# v1 related
"vllm/v1/attention",
"vllm/v1/executor",
"vllm/v1/kv_offload",
"vllm/v1/spec_decode",
"vllm/v1/structured_output",

View File

@ -2403,6 +2403,29 @@ def cp_gather_cache(
)
def cp_gather_and_upconvert_fp8_kv_cache(
src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
seq_lens: torch.Tensor,
workspace_starts: torch.Tensor,
batch_size: int,
) -> None:
"""Gather and upconvert FP8 KV cache to BF16 workspace.
Args:
src_cache: FP8 KV cache [num_blocks, block_size, 656]
dst: BF16 output workspace [total_tokens, 576]
block_table: Block indices [num_reqs, max_blocks]
seq_lens: Sequence lengths [num_reqs]
workspace_starts: Workspace start offsets [num_reqs]
batch_size: Number of requests
"""
torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache(
src_cache, dst, block_table, seq_lens, workspace_starts, batch_size
)
def indexer_k_quant_and_cache(
k: torch.Tensor,
kv_cache: torch.Tensor,

View File

@ -294,6 +294,12 @@ class AttentionImpl(ABC, Generic[T]):
# Some features like decode context parallelism require the softmax lse.
can_return_lse_for_decode: bool = False
# Whether the attention impl supports Prefill Context Parallelism.
supports_pcp: bool = False
# Whether the attention impl(or ops) supports MTP
# when cp_kv_cache_interleave_size > 1
supports_mtp_with_cp_non_trivial_interleave_size: bool = False
# some attention backends might not always want to return lse
# even if they can return lse (for efficiency reasons)
need_to_return_lse_for_decode: bool = False

View File

@ -355,7 +355,7 @@ def kernel_unified_attention_2d(
@triton.jit
def kernel_unified_attention_3d(
segm_output_ptr,
# [num_tokens, num_query_heads, num_segments, head_size]
# [num_tokens, num_query_heads, num_segments, head_size_padded]
segm_max_ptr, # [num_tokens, num_query_heads, num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments]
query_ptr, # [num_tokens, num_query_heads, head_size]
@ -749,6 +749,11 @@ def unified_attention(
q_descale,
k_descale,
v_descale,
seq_threshold_3D=None,
num_par_softmax_segments=None,
softmax_segm_output=None,
softmax_segm_max=None,
softmax_segm_expsum=None,
alibi_slopes=None,
output_scale=None,
qq_bias=None,
@ -793,8 +798,19 @@ def unified_attention(
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
# if batch contains a prefill
if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
# Launch the 2D kernel if
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
# 2. The batch includes at least one prefill request, or
# 3. The number of sequences exceeds the configured threshold
if (
seq_threshold_3D is None
or num_par_softmax_segments is None
or softmax_segm_output is None
or softmax_segm_max is None
or softmax_segm_expsum is None
or max_seqlen_q > 1
or num_seqs > seq_threshold_3D
):
kernel_unified_attention_2d[
(
total_num_q_blocks,
@ -847,37 +863,12 @@ def unified_attention(
USE_FP8=output_scale is not None,
)
else:
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
# value that showed good performance in tests
NUM_SEGMENTS = 16
segm_output = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
triton.next_power_of_2(head_size),
dtype=torch.float32,
device=q.device,
)
segm_max = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)
segm_expsum = torch.empty(
q.shape[0],
num_query_heads,
NUM_SEGMENTS,
dtype=torch.float32,
device=q.device,
)
kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)](
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
kernel_unified_attention_3d[
(total_num_q_blocks, num_kv_heads, num_par_softmax_segments)
](
segm_output_ptr=softmax_segm_output,
segm_max_ptr=softmax_segm_max,
segm_expsum_ptr=softmax_segm_expsum,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
@ -917,13 +908,13 @@ def unified_attention(
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
)
reduce_segments[(q.shape[0], num_query_heads)](
output_ptr=out,
segm_output_ptr=segm_output,
segm_max_ptr=segm_max,
segm_expsum_ptr=segm_expsum,
segm_output_ptr=softmax_segm_output,
segm_max_ptr=softmax_segm_max,
segm_expsum_ptr=softmax_segm_expsum,
seq_lens_ptr=seqused_k,
num_seqs=num_seqs,
num_query_heads=num_query_heads,
@ -936,6 +927,6 @@ def unified_attention(
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
USE_FP8=output_scale is not None,
)

View File

@ -141,7 +141,25 @@ class CompilerManager:
# we use ast.literal_eval to parse the data
# because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe.
self.cache = ast.literal_eval(f.read())
cache = ast.literal_eval(f.read())
def check_type(value, ty):
if not isinstance(value, ty):
raise TypeError(f"Expected {ty} but got {type(value)} for {value}")
def parse_key(key: Any) -> tuple[Range, int, str]:
range_tuple, graph_index, compiler_name = key
check_type(graph_index, int)
check_type(compiler_name, str)
if isinstance(range_tuple, tuple):
start, end = range_tuple
check_type(start, int)
check_type(end, int)
range_tuple = Range(start=start, end=end)
check_type(range_tuple, Range)
return range_tuple, graph_index, compiler_name
self.cache = {parse_key(key): value for key, value in cache.items()}
self.compiler.initialize_cache(
cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix

View File

@ -171,22 +171,24 @@ class TorchCompileWithNoGuardsWrapper:
compiled_ptr = self.check_invariants_and_forward
aot_context = nullcontext()
if envs.VLLM_USE_AOT_COMPILE:
if hasattr(torch._dynamo.config, "enable_aot_compile"):
torch._dynamo.config.enable_aot_compile = True
aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
else:
msg = "torch._dynamo.config.enable_aot_compile is not "
msg += "available. AOT compile is disabled and please "
msg += "upgrade PyTorch version to use AOT compile."
logger.warning(msg)
self._compiled_callable = torch.compile(
compiled_ptr,
fullgraph=True,
dynamic=False,
backend=backend,
options=options,
)
with aot_context:
self._compiled_callable = torch.compile(
compiled_ptr,
fullgraph=True,
dynamic=False,
backend=backend,
options=options,
)
if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

View File

@ -544,6 +544,11 @@ class ModelConfig:
self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
if self.is_encoder_decoder:
self.mm_processor_cache_gb = 0
logger.info("Encoder-decoder model detected, disabling mm processor cache.")
# Init multimodal config if needed
if self._model_info.supports_multimodal:
if (
@ -800,6 +805,13 @@ class ModelConfig:
runner_type: RunnerType,
convert: ConvertOption,
) -> ConvertType:
if convert == "reward":
logger.warning(
"`--convert reward` is deprecated and will be removed in v0.15. "
"Please use `--convert embed` instead."
)
return "embed"
if convert != "auto":
return convert

View File

@ -317,11 +317,6 @@ class ParallelConfig:
"num_redundant_experts."
)
if self.prefill_context_parallel_size > 1:
raise ValueError(
"Prefill context parallelism is not fully supported. "
"Please set prefill_context_parallel_size to 1."
)
return self
@property

View File

@ -111,13 +111,15 @@ class PoolerConfig:
def get_use_activation(o: object):
if softmax := getattr(o, "softmax", None) is not None:
logger.warning_once(
"softmax will be deprecated, please use use_activation instead."
"softmax will be deprecated and will be removed in v0.15. "
"Please use use_activation instead."
)
return softmax
if activation := getattr(o, "activation", None) is not None:
logger.warning_once(
"activation will be deprecated, please use use_activation instead."
"activation will be deprecated and will be removed in v0.15. "
"Please use use_activation instead."
)
return activation

View File

@ -73,14 +73,28 @@ def get_field(cls: ConfigType, name: str) -> Field:
)
def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
def getattr_iter(
object: object, names: Iterable[str], default: Any, warn: bool = False
) -> Any:
"""
A helper function that retrieves an attribute from an object which may
have multiple possible names. This is useful when fetching attributes from
arbitrary `transformers.PretrainedConfig` instances.
In the case where the first name in `names` is the preferred name, and
any other names are deprecated aliases, setting `warn=True` will log a
warning when a deprecated name is used.
"""
for name in names:
for i, name in enumerate(names):
if hasattr(object, name):
if warn and i > 0:
logger.warning_once(
"%s contains a deprecated attribute name '%s'. "
"Please use the preferred attribute name '%s' instead.",
type(object).__name__,
name,
names[0],
)
return getattr(object, name)
return default

View File

@ -751,27 +751,17 @@ class VllmConfig:
# TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands
self._set_compile_ranges()
if self.model_config and self.model_config.is_encoder_decoder:
from vllm.multimodal import MULTIMODAL_REGISTRY
self.scheduler_config.max_num_encoder_input_tokens = (
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
if (
self.model_config
and self.model_config.architecture == "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'."
)
logger.debug(
"Encoder-decoder model detected: setting "
"`max_num_encoder_input_tokens` to encoder length (%s)",
self.scheduler_config.max_num_encoder_input_tokens,
)
if (
self.model_config.architecture == "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'."
)
if (
self.kv_events_config is not None
@ -821,11 +811,6 @@ class VllmConfig:
f"({self.parallel_config.cp_kv_cache_interleave_size})."
)
assert (
self.parallel_config.cp_kv_cache_interleave_size == 1
or self.speculative_config is None
), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
# Do this after all the updates to compilation_config.mode
self.compilation_config.set_splitting_ops_for_v1(
all2all_backend=self.parallel_config.all2all_backend,

View File

@ -5,7 +5,7 @@ import queue
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from collections import Counter, deque
from collections.abc import Callable
from dataclasses import asdict
from itertools import count
@ -54,11 +54,26 @@ class BlockStored(KVCacheEvent):
lora_id: int | None
medium: str | None
def __hash__(self) -> int:
return hash(
(
tuple(self.block_hashes),
self.parent_block_hash,
tuple(self.token_ids),
self.block_size,
self.lora_id,
self.medium,
)
)
class BlockRemoved(KVCacheEvent):
block_hashes: list[ExternalBlockHash]
medium: str | None
def __hash__(self) -> int:
return hash((tuple(self.block_hashes), self.medium))
class AllBlocksCleared(KVCacheEvent):
pass
@ -68,6 +83,119 @@ class KVEventBatch(EventBatch):
events: list[BlockStored | BlockRemoved | AllBlocksCleared]
class KVEventAggregator:
"""
Aggregates KV events across multiple workers.
Tracks how many times each event appears and returns only those
that were emitted by all workers.
"""
__slots__ = ("_event_counter", "_num_workers")
def __init__(self, num_workers: int) -> None:
if num_workers <= 0:
raise ValueError("num_workers must be greater than zero.")
self._event_counter: Counter[KVCacheEvent] = Counter()
self._num_workers: int = num_workers
def add_events(self, events: list[KVCacheEvent]) -> None:
"""
Add events from a worker batch.
:param events: List of KVCacheEvent objects.
"""
if not isinstance(events, list):
raise TypeError("events must be a list of KVCacheEvent.")
self._event_counter.update(events)
def get_common_events(self) -> list[KVCacheEvent]:
"""
Return events that appeared in all workers.
:return: List of events present in all workers.
"""
return [
event
for event, count in self._event_counter.items()
if count == self._num_workers
]
def get_all_events(self) -> list[KVCacheEvent]:
"""
Return all events for all workers.
:return: List of events for all workers.
"""
return list(self._event_counter.elements())
def clear_events(self) -> None:
"""
Clear all tracked events.
"""
self._event_counter.clear()
def increment_workers(self, count: int = 1) -> None:
"""
Increment the number of workers contributing events.
:param count: Number to increment the workers by.
"""
if count <= 0:
raise ValueError("count must be positive.")
self._num_workers += count
def reset_workers(self) -> None:
"""
Reset the number of workers to 1.
"""
self._num_workers = 1
def get_number_of_workers(self) -> int:
"""
Return the number of workers.
:return: int number of workers.
"""
return self._num_workers
def __repr__(self) -> str:
return (
f"<KVEventAggregator workers={self._num_workers}, "
f"events={len(self._event_counter)}>"
)
class KVConnectorKVEvents(ABC):
"""
Abstract base class for KV events.
Acts as a container for KV events from the connector.
"""
@abstractmethod
def add_events(self, events: list[KVCacheEvent]) -> None:
raise NotImplementedError
@abstractmethod
def aggregate(self) -> "KVConnectorKVEvents":
raise NotImplementedError
@abstractmethod
def increment_workers(self, count: int = 1) -> None:
raise NotImplementedError
@abstractmethod
def get_all_events(self) -> list[KVCacheEvent]:
raise NotImplementedError
@abstractmethod
def get_number_of_workers(self) -> int:
raise NotImplementedError
@abstractmethod
def clear_events(self) -> None:
raise NotImplementedError
class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches with data parallelism
support.

View File

@ -78,6 +78,7 @@ class KVOutputAggregator:
finished_sending = set[str]()
finished_recving = set[str]()
aggregated_kv_connector_stats = None
combined_kv_cache_events = None
invalid_block_ids = set[int]()
for model_runner_output in outputs:
assert model_runner_output is not None
@ -119,6 +120,19 @@ class KVOutputAggregator:
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
)
# Combine kv_cache_events from all workers.
if combined_kv_cache_events is None:
# Use the first worker's kv_cache events as start event list.
combined_kv_cache_events = kv_output.kv_cache_events
elif kv_cache_events := kv_output.kv_cache_events:
assert isinstance(
combined_kv_cache_events,
type(kv_cache_events),
)
worker_kv_cache_events = kv_cache_events.get_all_events()
combined_kv_cache_events.add_events(worker_kv_cache_events)
combined_kv_cache_events.increment_workers(1)
invalid_block_ids |= kv_output.invalid_block_ids
# select output of the worker specified by output_rank
@ -129,6 +143,7 @@ class KVOutputAggregator:
finished_sending=finished_sending or None,
finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None,
kv_cache_events=combined_kv_cache_events or None,
invalid_block_ids=invalid_block_ids,
expected_finished_count=self._expected_finished_count,
)

View File

@ -49,7 +49,7 @@ from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
@ -379,6 +379,14 @@ class KVConnectorBase_V1(ABC):
"""
return None
def get_kv_connector_kv_cache_events(self) -> Optional["KVConnectorKVEvents"]:
"""
Get the KV connector kv cache events collected during the last interval.
This function should be called by the model runner every time after the
model execution and before cleanup.
"""
return None
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
"""
Get the KVConnector handshake metadata for this connector.

View File

@ -1,14 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
import torch
from lmcache.integration.vllm.vllm_v1_adapter import (
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
)
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import (
BlockStored,
KVCacheEvent,
KVConnectorKVEvents,
KVEventAggregator,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
@ -16,6 +20,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
@ -26,6 +31,44 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
class LMCacheKVEvents(KVConnectorKVEvents):
"""
Concrete implementation of KVConnectorKVEvents using KVEventAggregator.
"""
def __init__(self, num_workers: int) -> None:
self._aggregator = KVEventAggregator(num_workers)
def add_events(self, events: list[KVCacheEvent]) -> None:
self._aggregator.add_events(events)
def aggregate(self) -> "LMCacheKVEvents":
"""
Aggregate KV events and retain only common events.
"""
common_events = self._aggregator.get_common_events()
self._aggregator.clear_events()
self._aggregator.add_events(common_events)
self._aggregator.reset_workers()
return self
def increment_workers(self, count: int = 1) -> None:
self._aggregator.increment_workers(count)
def get_all_events(self) -> list[KVCacheEvent]:
return self._aggregator.get_all_events()
def get_number_of_workers(self) -> int:
return self._aggregator.get_number_of_workers()
def clear_events(self) -> None:
self._aggregator.clear_events()
self._aggregator.reset_workers()
def __repr__(self) -> str:
return f"<LMCacheKVEvents events={self.get_all_events()}>"
class LMCacheConnectorV1(KVConnectorBase_V1):
def __init__(
self,
@ -50,10 +93,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
cls = _adapter.LMCacheConnectorV1Impl
else:
logger.info("Initializing latest dev LMCache connector")
# lazy import
from lmcache.integration.vllm.vllm_v1_adapter import (
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
)
cls = LMCacheConnectorLatestImpl
self._lmcache_engine = cls(vllm_config, role, self)
self._kv_cache_events: LMCacheKVEvents | None = None
# ==============================
# Worker-side methods
# ==============================
@ -151,6 +201,31 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
# Fallback for older versions that don't support this method
return set()
def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None:
"""
Get the KV connector kv cache events collected during the last interval.
"""
events = self._lmcache_engine.get_kv_events() # type: ignore [attr-defined]
if not events:
return None
blocks: list[BlockStored] = [
BlockStored(
block_hashes=e.block_hashes,
parent_block_hash=e.parent_block_hash,
token_ids=e.token_ids,
lora_id=e.lora_id,
block_size=e.block_size,
medium=e.medium,
)
for e in events
]
lmcache_kv_events = LMCacheKVEvents(num_workers=1)
lmcache_kv_events.add_events(blocks)
return lmcache_kv_events
# ==============================
# Scheduler-side methods
# ==============================
@ -198,6 +273,28 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
"""
return self._lmcache_engine.build_connector_meta(scheduler_output)
def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
# Get the KV events
kv_cache_events = connector_output.kv_cache_events
if not kv_cache_events or not isinstance(kv_cache_events, LMCacheKVEvents):
return
if self._kv_cache_events is None:
self._kv_cache_events = kv_cache_events
else:
self._kv_cache_events.add_events(kv_cache_events.get_all_events())
self._kv_cache_events.increment_workers(
kv_cache_events.get_number_of_workers()
)
return
def request_finished(
self,
request: "Request",
@ -214,3 +311,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
returned by the engine.
"""
return self._lmcache_engine.request_finished(request, block_ids)
def take_events(self) -> Iterable["KVCacheEvent"]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
if self._kv_cache_events is not None:
self._kv_cache_events.aggregate()
kv_cache_events = self._kv_cache_events.get_all_events()
yield from kv_cache_events
self._kv_cache_events.clear_events()
self._kv_cache_events = None

View File

@ -27,7 +27,14 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
LMCacheAsyncLookupServer,
)
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
from lmcache.v1.plugin.runtime_plugin_launcher import RuntimePluginLauncher
try:
from lmcache.v1.plugin.runtime_plugin_launcher import RuntimePluginLauncher
except ImportError:
# Backwards compatibility for lmcache <= 0.3.10-post1
from lmcache.v1.plugin.plugin_launcher import (
PluginLauncher as RuntimePluginLauncher,
)
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig

View File

@ -259,6 +259,12 @@ class MultiConnector(KVConnectorBase_V1):
agg_block_ids |= c.get_block_ids_with_load_errors()
return agg_block_ids
# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events' method
# for the MultiConnector. It should be able to get events from multiple
# connectors, handling the case where only a subset of the requested connectors
# implements the 'get_kv_connector_kv_cache_events'
# Follow on PR from https://github.com/vllm-project/vllm/pull/28309#pullrequestreview-3566351082
# ==============================
# Scheduler-side methods
# ==============================

View File

@ -1649,7 +1649,13 @@ class EngineArgs:
"attention_backend and attention_config.backend "
"are mutually exclusive"
)
attention_config.backend = self.attention_backend
# Convert string to enum if needed (CLI parsing returns a string)
if isinstance(self.attention_backend, str):
attention_config.backend = AttentionBackendEnum[
self.attention_backend.upper()
]
else:
attention_config.backend = self.attention_backend
load_config = self.create_load_config()
@ -1783,6 +1789,7 @@ class EngineArgs:
except Exception:
# This is only used to set default_max_num_batched_tokens
device_memory = 0
device_name = ""
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
# throughput, see PR #17885 for more details.
@ -1847,16 +1854,6 @@ class EngineArgs:
default_chunked_prefill = model_config.is_chunked_prefill_supported
default_prefix_caching = model_config.is_prefix_caching_supported
if self.prefill_context_parallel_size > 1:
default_chunked_prefill = False
default_prefix_caching = False
logger.warning_once(
"--prefill-context-parallel-size > 1 is not compatible with "
"chunked prefill and prefix caching now. Chunked prefill "
"and prefix caching have been disabled by default.",
scope="local",
)
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = default_chunked_prefill
@ -2042,11 +2039,13 @@ def human_readable_int(value):
"k": 10**3,
"m": 10**6,
"g": 10**9,
"t": 10**12,
}
binary_multiplier = {
"K": 2**10,
"M": 2**20,
"G": 2**30,
"T": 2**40,
}
number, suffix = match.groups()

View File

@ -324,12 +324,12 @@ class AnthropicServingMessages(OpenAIServingChat):
id=origin_chunk.id,
content=[],
model=origin_chunk.model,
),
usage=AnthropicUsage(
input_tokens=origin_chunk.usage.prompt_tokens
if origin_chunk.usage
else 0,
output_tokens=0,
usage=AnthropicUsage(
input_tokens=origin_chunk.usage.prompt_tokens
if origin_chunk.usage
else 0,
output_tokens=0,
),
),
)
first_item = False

View File

@ -176,7 +176,7 @@ class FrontendArgs:
enable_force_include_usage: bool = False
"""If set to True, including usage on every request."""
enable_tokenizer_info_endpoint: bool = False
"""Enable the /get_tokenizer_info endpoint. May expose chat
"""Enable the `/tokenizer_info` endpoint. May expose chat
templates and other tokenizer configuration."""
enable_log_outputs: bool = False
"""If True, log model outputs (generations).

View File

@ -232,7 +232,177 @@ def parse_response_input(
return msg
def parse_chat_inputs_to_harmony_messages(chat_msgs: list) -> list[Message]:
"""
Parse a list of messages from request.messages in the Chat Completion API to
Harmony messages.
"""
msgs: list[Message] = []
tool_id_names: dict[str, str] = {}
# Collect tool id to name mappings for tool response recipient values
for chat_msg in chat_msgs:
for tool_call in chat_msg.get("tool_calls", []):
tool_id_names[tool_call.get("id")] = tool_call.get("function", {}).get(
"name"
)
for chat_msg in chat_msgs:
msgs.extend(parse_chat_input_to_harmony_message(chat_msg, tool_id_names))
msgs = auto_drop_analysis_messages(msgs)
return msgs
def auto_drop_analysis_messages(msgs: list[Message]) -> list[Message]:
"""
Harmony models expect the analysis messages (representing raw chain of thought) to
be dropped after an assistant message to the final channel is produced from the
reasoning of those messages.
The openai-harmony library does this if the very last assistant message is to the
final channel, but it does not handle the case where we're in longer multi-turn
conversations and the client gave us reasoning content from previous turns of
the conversation with multiple assistant messages to the final channel in the
conversation.
So, we find the index of the last assistant message to the final channel and drop
all analysis messages that precede it, leaving only the analysis messages that
are relevant to the current part of the conversation.
"""
last_assistant_final_index = -1
for i in range(len(msgs) - 1, -1, -1):
msg = msgs[i]
if msg.author.role == "assistant" and msg.channel == "final":
last_assistant_final_index = i
break
cleaned_msgs: list[Message] = []
for i, msg in enumerate(msgs):
if i < last_assistant_final_index and msg.channel == "analysis":
continue
cleaned_msgs.append(msg)
return cleaned_msgs
def flatten_chat_text_content(content: str | list | None) -> str | None:
"""
Extract the text parts from a chat message content field and flatten them
into a single string.
"""
if isinstance(content, list):
return "".join(
item.get("text", "")
for item in content
if isinstance(item, dict) and item.get("type") == "text"
)
return content
def parse_chat_input_to_harmony_message(
chat_msg, tool_id_names: dict[str, str] | None = None
) -> list[Message]:
"""
Parse a message from request.messages in the Chat Completion API to
Harmony messages.
"""
tool_id_names = tool_id_names or {}
if not isinstance(chat_msg, dict):
# Handle Pydantic models
chat_msg = chat_msg.model_dump(exclude_none=True)
role = chat_msg.get("role")
msgs: list[Message] = []
# Assistant message with tool calls
tool_calls = chat_msg.get("tool_calls", [])
if role == "assistant" and tool_calls:
content = flatten_chat_text_content(chat_msg.get("content"))
if content:
commentary_msg = Message.from_role_and_content(Role.ASSISTANT, content)
commentary_msg = commentary_msg.with_channel("commentary")
msgs.append(commentary_msg)
reasoning_content = chat_msg.get("reasoning") or chat_msg.get(
"reasoning_content"
)
if reasoning_content:
analysis_msg = Message.from_role_and_content(
Role.ASSISTANT, reasoning_content
)
analysis_msg = analysis_msg.with_channel("analysis")
msgs.append(analysis_msg)
for call in tool_calls:
func = call.get("function", {})
name = func.get("name", "")
arguments = func.get("arguments", "") or ""
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{name}")
# Officially, this should be `<|constrain|>json` but there is not clear
# evidence that improves accuracy over `json` and some anecdotes to the
# contrary. Further testing of the different content_types is needed.
msg = msg.with_content_type("json")
msgs.append(msg)
return msgs
# Tool role message (tool output)
if role == "tool":
tool_call_id = chat_msg.get("tool_call_id", "")
name = tool_id_names.get(tool_call_id, "")
content = chat_msg.get("content", "") or ""
content = flatten_chat_text_content(content)
msg = (
Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{name}"), content
)
.with_channel("commentary")
.with_recipient("assistant")
)
return [msg]
# Non-tool reasoning content
reasoning_content = chat_msg.get("reasoning") or chat_msg.get("reasoning_content")
if role == "assistant" and reasoning_content:
analysis_msg = Message.from_role_and_content(Role.ASSISTANT, reasoning_content)
analysis_msg = analysis_msg.with_channel("analysis")
msgs.append(analysis_msg)
# Default: user/assistant/system messages with content
content = chat_msg.get("content") or ""
if content is None:
content = ""
if isinstance(content, str):
contents = [TextContent(text=content)]
else:
# TODO: Support refusal.
contents = [TextContent(text=c.get("text", "")) for c in content]
# Only add assistant messages if they have content, as reasoning or tool calling
# assistant messages were already added above.
if role == "assistant" and contents and contents[0].text:
msg = Message.from_role_and_contents(role, contents)
# Send non-tool assistant messages to the final channel
msg = msg.with_channel("final")
msgs.append(msg)
# For user/system/developer messages, add them directly even if no content.
elif role != "assistant":
msg = Message.from_role_and_contents(role, contents)
msgs.append(msg)
return msgs
def parse_input_to_harmony_message(chat_msg) -> list[Message]:
"""
Parse a message from request.previous_input_messages in the Responsees API to
Harmony messages.
"""
if not isinstance(chat_msg, dict):
# Handle Pydantic models
chat_msg = chat_msg.model_dump(exclude_none=True)
@ -258,14 +428,7 @@ def parse_input_to_harmony_message(chat_msg) -> list[Message]:
if role == "tool":
name = chat_msg.get("name", "")
content = chat_msg.get("content", "") or ""
if isinstance(content, list):
# Handle array format for tool message content
# by concatenating all text parts.
content = "".join(
item.get("text", "")
for item in content
if isinstance(item, dict) and item.get("type") == "text"
)
content = flatten_chat_text_content(content)
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{name}"), content
@ -623,20 +786,40 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
def parse_chat_output(
token_ids: Sequence[int],
) -> tuple[str | None, str | None, bool]:
"""
Parse the output of a Harmony chat completion into reasoning and final content.
Note that when the `openai` tool parser is used, serving_chat only uses this
for the reasoning content and gets the final content from the tool call parser.
When the `openai` tool parser is not enabled, or when `GptOssReasoningParser` is
in use,this needs to return the final content without any tool calls parsed.
Empty reasoning or final content is returned as None instead of an empty string.
"""
parser = parse_output_into_messages(token_ids)
output_msgs = parser.messages
is_tool_call = False # TODO: update this when tool call is supported
if len(output_msgs) == 0:
# The generation has stopped during reasoning.
reasoning = parser.current_content
final_content = None
elif len(output_msgs) == 1:
# The generation has stopped during final message.
reasoning = output_msgs[0].content[0].text
final_content = parser.current_content
else:
reasoning_msg = output_msgs[:-1]
final_msg = output_msgs[-1]
reasoning = "\n".join([msg.content[0].text for msg in reasoning_msg])
final_content = final_msg.content[0].text
# Get completed messages from the parser
reasoning_texts = [
msg.content[0].text for msg in output_msgs if msg.channel == "analysis"
]
final_texts = [
msg.content[0].text for msg in output_msgs if msg.channel != "analysis"
]
# Extract partial messages from the parser
if parser.current_channel == "analysis" and parser.current_content:
reasoning_texts.append(parser.current_content)
elif parser.current_channel != "analysis" and parser.current_content:
final_texts.append(parser.current_content)
# Flatten multiple messages into a single string
reasoning: str | None = "\n".join(reasoning_texts)
final_content: str | None = "\n".join(final_texts)
# Return None instead of empty string since existing callers check for None
reasoning = reasoning or None
final_content = final_content or None
return reasoning, final_content, is_tool_call

View File

@ -27,8 +27,8 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
get_stop_tokens_for_assistant_actions,
get_streamable_parser_for_assistant,
get_system_message,
parse_chat_inputs_to_harmony_messages,
parse_chat_output,
parse_input_to_harmony_message,
render_for_completion,
)
from vllm.entrypoints.openai.protocol import (
@ -822,6 +822,9 @@ class OpenAIServingChat(OpenAIServing):
if delta_message is not None:
harmony_tools_streamed[i] = True
elif cur_channel == "commentary":
# Tool call preambles meant to be shown to the user
delta_message = DeltaMessage(content=delta_text)
else:
delta_message = None
# handle streaming deltas for tools with named tool_choice
@ -1770,6 +1773,11 @@ class OpenAIServingChat(OpenAIServing):
):
messages: list[OpenAIMessage] = []
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request)
# Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default
# if the model supports it. TODO: Support browsing.
@ -1788,8 +1796,7 @@ class OpenAIServingChat(OpenAIServing):
messages.append(dev_msg)
# Add user message.
for chat_msg in request.messages:
messages.extend(parse_input_to_harmony_message(chat_msg))
messages.extend(parse_chat_inputs_to_harmony_messages(request.messages))
# Render prompt token ids.
prompt_token_ids = render_for_completion(messages)

View File

@ -43,6 +43,7 @@ class OpenAIToolParser(ToolParser):
parser = parse_output_into_messages(token_ids)
tool_calls = []
final_content = None
commentary_content = None
if len(parser.messages) > 0:
for msg in parser.messages:
@ -75,11 +76,15 @@ class OpenAIToolParser(ToolParser):
)
elif msg.channel == "final":
final_content = msg_text
elif msg.channel == "commentary" and not msg.recipient:
commentary_content = msg_text
return ExtractedToolCallInformation(
tools_called=len(tool_calls) > 0,
tool_calls=tool_calls,
content=final_content,
# prefer final content over commentary content if both are present
# commentary content is tool call preambles meant to be shown to the user
content=final_content or commentary_content,
)
def extract_tool_calls_streaming(

View File

@ -239,6 +239,7 @@ if TYPE_CHECKING:
VLLM_NCCL_INCLUDE_PATH: str | None = None
VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = ""
VLLM_DEBUG_WORKSPACE: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
@ -1537,6 +1538,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with
# top 5 collected objects
"VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""),
# Debug workspace allocations.
# logging of workspace resize operations.
"VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))),
# Disables parallel execution of shared_experts via separate cuda stream
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))

View File

@ -229,6 +229,11 @@ def suppress_logging(level: int = logging.INFO) -> Generator[None, Any, None]:
# guaranteed by the Python GIL.
_configure_vllm_root_logger()
# Transformers uses httpx to access the Hugging Face Hub. httpx is quite verbose,
# so we set its logging level to WARNING when vLLM's logging level is INFO.
if envs.VLLM_LOGGING_LEVEL == "INFO":
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = init_logger(__name__)

Some files were not shown because too many files have changed in this diff Show More