mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-27 08:16:58 +08:00
Merge branch 'main' into imarkov/eplb_optimizations
This commit is contained in:
commit
11c492ae81
@ -15,6 +15,21 @@ steps:
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
- label: "Build arm64 wheel - CUDA 13.0"
|
||||
depends_on: ~
|
||||
id: build-wheel-arm64-cuda-13-0
|
||||
agents:
|
||||
queue: arm64_cpu_queue_postmerge
|
||||
commands:
|
||||
# #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here:
|
||||
# https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
# aarch64 build
|
||||
- label: "Build arm64 CPU wheel"
|
||||
depends_on: ~
|
||||
@ -25,7 +40,7 @@ steps:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_BUILD_ACL=ON --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
@ -39,7 +54,7 @@ steps:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_31"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
@ -52,7 +67,21 @@ steps:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
# x86 CPU wheel build
|
||||
- label: "Build x86 CPU wheel"
|
||||
depends_on: ~
|
||||
id: build-wheel-x86-cpu
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --build-arg VLLM_CPU_AMXBF16=true --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
|
||||
@ -372,6 +372,17 @@ if __name__ == "__main__":
|
||||
|
||||
print(f"Found {len(wheel_files)} wheel files for version {version}: {wheel_files}")
|
||||
|
||||
# keep only "official" files for a non-nightly version (specifed by cli args)
|
||||
PY_VERSION_RE = re.compile(r"^\d+\.\d+\.\d+([a-zA-Z0-9.+-]*)?$")
|
||||
if PY_VERSION_RE.match(version):
|
||||
# upload-wheels.sh ensures no "dev" is in args.version
|
||||
wheel_files = list(
|
||||
filter(lambda x: version in x and "dev" not in x, wheel_files)
|
||||
)
|
||||
print(f"Non-nightly version detected, wheel files used: {wheel_files}")
|
||||
else:
|
||||
print("Nightly version detected, keeping all wheel files.")
|
||||
|
||||
# Generate index and metadata, assuming wheels and indices are stored as:
|
||||
# s3://vllm-wheels/{version}/<wheel files>
|
||||
# s3://vllm-wheels/<anything>/<index files>
|
||||
|
||||
@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euxo pipefail
|
||||
|
||||
# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
|
||||
THRESHOLD=${1:-0.25}
|
||||
NUM_Q=${2:-1319}
|
||||
PORT=${3:-8040}
|
||||
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
|
||||
mkdir -p "${OUT_DIR}"
|
||||
|
||||
wait_for_server() {
|
||||
local port=$1
|
||||
timeout 600 bash -c '
|
||||
until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do
|
||||
sleep 1
|
||||
done'
|
||||
}
|
||||
|
||||
MODEL="Qwen/Qwen3-Next-80B-A3B-Instruct"
|
||||
|
||||
# Set BACKENDS based on platform
|
||||
if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
|
||||
# ROCm platform
|
||||
BACKENDS=("allgather_reducescatter")
|
||||
# Disable MOE padding for ROCm since it is causing eplb to fail
|
||||
export VLLM_ROCM_MOE_PADDING=0
|
||||
else
|
||||
# Non-ROCm platform (CUDA/other)
|
||||
BACKENDS=("deepep_high_throughput" "deepep_low_latency")
|
||||
fi
|
||||
|
||||
cleanup() {
|
||||
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
|
||||
kill "${SERVER_PID}" 2>/dev/null || true
|
||||
for _ in {1..20}; do
|
||||
kill -0 "${SERVER_PID}" 2>/dev/null || break
|
||||
sleep 0.5
|
||||
done
|
||||
kill -9 "${SERVER_PID}" 2>/dev/null || true
|
||||
fi
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
for BACK in "${BACKENDS[@]}"; do
|
||||
VLLM_DEEP_GEMM_WARMUP=skip \
|
||||
VLLM_ALL2ALL_BACKEND=$BACK \
|
||||
vllm serve "$MODEL" \
|
||||
--enforce-eager \
|
||||
--tensor-parallel-size 4 \
|
||||
--enable-expert-parallel \
|
||||
--enable-eplb \
|
||||
--eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
|
||||
--speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' \
|
||||
--trust-remote-code \
|
||||
--max-model-len 2048 \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--port $PORT &
|
||||
SERVER_PID=$!
|
||||
wait_for_server $PORT
|
||||
|
||||
TAG=$(echo "$MODEL" | tr '/: \\n' '_____')
|
||||
OUT="${OUT_DIR}/${TAG}_${BACK}.json"
|
||||
python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port $PORT --num-questions ${NUM_Q} --save-results ${OUT}
|
||||
python3 - <<PY
|
||||
import json; acc=json.load(open('${OUT}'))['accuracy']
|
||||
print(f"${MODEL} ${BACK}: accuracy {acc:.3f}")
|
||||
assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}"
|
||||
PY
|
||||
|
||||
cleanup
|
||||
SERVER_PID=
|
||||
sleep 1
|
||||
PORT=$((PORT+1))
|
||||
done
|
||||
@ -34,9 +34,10 @@ if [[ ${#wheel_files[@]} -ne 1 ]]; then
|
||||
fi
|
||||
wheel="${wheel_files[0]}"
|
||||
|
||||
# current build image uses ubuntu 20.04, which corresponds to manylinux_2_31
|
||||
# default build image uses ubuntu 20.04, which corresponds to manylinux_2_31
|
||||
# we also accept params as manylinux tag
|
||||
# refer to https://github.com/mayeut/pep600_compliance?tab=readme-ov-file#acceptable-distros-to-build-wheels
|
||||
manylinux_version="manylinux_2_31"
|
||||
manylinux_version="${1:-manylinux_2_31}"
|
||||
|
||||
# Rename 'linux' to the appropriate manylinux version in the wheel filename
|
||||
if [[ "$wheel" != *"linux"* ]]; then
|
||||
@ -96,8 +97,11 @@ if [[ "$BUILDKITE_BRANCH" == "main" && "$BUILDKITE_PULL_REQUEST" == "false" ]];
|
||||
aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/nightly/"
|
||||
fi
|
||||
|
||||
# copy to /<pure_version>/ only if it does not have "dev" in the version
|
||||
# re-generate and copy to /<pure_version>/ only if it does not have "dev" in the version
|
||||
if [[ "$version" != *"dev"* ]]; then
|
||||
echo "Uploading indices to overwrite /$pure_version/"
|
||||
echo "Re-generating indices for /$pure_version/"
|
||||
rm -rf "$INDICES_OUTPUT_DIR/*"
|
||||
mkdir -p "$INDICES_OUTPUT_DIR"
|
||||
$PYTHON .buildkite/scripts/generate-nightly-index.py --version "$pure_version" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" --comment "version $pure_version" $alias_arg
|
||||
aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/$pure_version/"
|
||||
fi
|
||||
|
||||
@ -326,10 +326,10 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
|
||||
|
||||
- label: V1 Test e2e + engine # 30min
|
||||
timeout_in_minutes: 45
|
||||
- label: V1 Test e2e + engine # 65min
|
||||
timeout_in_minutes: 90
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
agent_pool: mi325_4
|
||||
# grade: Blocking
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -435,7 +435,7 @@ steps:
|
||||
|
||||
- label: Examples Test # 30min
|
||||
timeout_in_minutes: 45
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
@ -455,7 +455,6 @@ steps:
|
||||
# for multi-modal models
|
||||
- python3 offline_inference/audio_language.py --seed 0
|
||||
- python3 offline_inference/vision_language.py --seed 0
|
||||
- python3 offline_inference/vision_language_pooling.py --seed 0
|
||||
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
||||
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
|
||||
# for pooling models
|
||||
@ -1630,7 +1629,6 @@ steps:
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_4
|
||||
# grade: Blocking
|
||||
gpu: h100
|
||||
optional: true
|
||||
num_gpus: 4
|
||||
working_dir: "/vllm-workspace"
|
||||
|
||||
@ -692,6 +692,7 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/test_initialization.py
|
||||
- tests/models/registry.py
|
||||
commands:
|
||||
# Run a subset of model initialization tests
|
||||
- pytest -v -s models/test_initialization.py::test_can_initialize_small_subset
|
||||
@ -704,6 +705,7 @@ steps:
|
||||
- vllm/model_executor/models/
|
||||
- vllm/transformers_utils/
|
||||
- tests/models/test_initialization.py
|
||||
- tests/models/registry.py
|
||||
commands:
|
||||
# Only when vLLM model source is modified - test initialization of a large
|
||||
# subset of supported models (the complement of the small subset in the above
|
||||
@ -836,7 +838,7 @@ steps:
|
||||
- tests/models/multimodal
|
||||
no_gpu: true
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- "pip install git+https://github.com/TIGER-AI-Lab/Mantis.git || echo 'Mantis installation skipped (decord not available on CPU-only environment)'"
|
||||
- pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
|
||||
|
||||
- label: Multi-Modal Processor Test
|
||||
@ -1346,6 +1348,7 @@ steps:
|
||||
- label: Prime-RL Integration Test # 15min
|
||||
timeout_in_minutes: 30
|
||||
optional: true
|
||||
soft_fail: true
|
||||
num_gpus: 2
|
||||
working_dir: "/vllm-workspace"
|
||||
source_file_dependencies:
|
||||
@ -1379,4 +1382,4 @@ steps:
|
||||
num_gpus: 2
|
||||
working_dir: "/vllm-workspace"
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
|
||||
|
||||
@ -384,7 +384,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=$PYTHONPATH
|
||||
PYTHONPATH=$ENV{PYTHONPATH}
|
||||
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
|
||||
RESULT_VARIABLE marlin_generation_result
|
||||
OUTPUT_VARIABLE marlin_generation_result
|
||||
@ -822,7 +822,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
|
||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$ENV{PYTHONPATH}
|
||||
${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT}
|
||||
RESULT_VARIABLE machete_generation_result
|
||||
OUTPUT_VARIABLE machete_generation_output
|
||||
@ -1004,7 +1004,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH})
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=$PYTHONPATH
|
||||
PYTHONPATH=$ENV{PYTHONPATH}
|
||||
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR}
|
||||
RESULT_VARIABLE moe_marlin_generation_result
|
||||
OUTPUT_VARIABLE moe_marlin_generation_output
|
||||
|
||||
@ -143,11 +143,13 @@ Compute Resources:
|
||||
- Databricks
|
||||
- DeepInfra
|
||||
- Google Cloud
|
||||
- IBM
|
||||
- Intel
|
||||
- Lambda Lab
|
||||
- Nebius
|
||||
- Novita AI
|
||||
- NVIDIA
|
||||
- Red Hat
|
||||
- Replicate
|
||||
- Roblox
|
||||
- RunPod
|
||||
|
||||
@ -620,7 +620,7 @@ def get_tokenizer(
|
||||
kwargs["use_fast"] = False
|
||||
if tokenizer_mode == "mistral":
|
||||
try:
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"MistralTokenizer requires vllm package.\n"
|
||||
|
||||
@ -99,7 +99,6 @@ def benchmark_mrope(
|
||||
# the parameters to compute the q k v size based on tp_size
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=rope_parameters,
|
||||
|
||||
@ -32,8 +32,8 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device):
|
||||
def benchmark(batch_size, seq_len, num_heads, provider):
|
||||
dtype = torch.bfloat16
|
||||
max_position = 8192
|
||||
base = 10000
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
|
||||
rope_parameters = {"partial_rotary_factor": rotary_dim / head_size}
|
||||
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
|
||||
rope = rope.to(dtype=dtype, device=device)
|
||||
cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)
|
||||
|
||||
|
||||
@ -140,16 +140,21 @@ function(vllm_prepare_torch_gomp_shim TORCH_GOMP_SHIM_DIR)
|
||||
run_python(_VLLM_TORCH_GOMP_PATH
|
||||
"
|
||||
import os, glob
|
||||
try:
|
||||
import torch
|
||||
torch_pkg = os.path.dirname(torch.__file__)
|
||||
site_root = os.path.dirname(torch_pkg)
|
||||
torch_libs = os.path.join(site_root, 'torch.libs')
|
||||
print(glob.glob(os.path.join(torch_libs, 'libgomp-*.so*'))[0])
|
||||
except:
|
||||
print('')
|
||||
import torch
|
||||
torch_pkg = os.path.dirname(torch.__file__)
|
||||
site_root = os.path.dirname(torch_pkg)
|
||||
|
||||
# Search both torch.libs and torch/lib
|
||||
roots = [os.path.join(site_root, 'torch.libs'), os.path.join(torch_pkg, 'lib')]
|
||||
candidates = []
|
||||
for root in roots:
|
||||
if not os.path.isdir(root):
|
||||
continue
|
||||
candidates.extend(glob.glob(os.path.join(root, 'libgomp*.so*')))
|
||||
|
||||
print(candidates[0] if candidates else '')
|
||||
"
|
||||
"failed to probe torch.libs for libgomp")
|
||||
"failed to probe for libgomp")
|
||||
|
||||
if(_VLLM_TORCH_GOMP_PATH STREQUAL "" OR NOT EXISTS "${_VLLM_TORCH_GOMP_PATH}")
|
||||
return()
|
||||
|
||||
12
csrc/cache.h
12
csrc/cache.h
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
@ -58,6 +59,15 @@ void cp_gather_cache(
|
||||
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
|
||||
|
||||
// Gather and upconvert FP8 KV cache to BF16 workspace
|
||||
void cp_gather_and_upconvert_fp8_kv_cache(
|
||||
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
|
||||
torch::Tensor const& dst, // [TOT_TOKENS, 576]
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& seq_lens, // [BATCH]
|
||||
torch::Tensor const& workspace_starts, // [BATCH]
|
||||
int64_t batch_size);
|
||||
|
||||
// Indexer K quantization and cache function
|
||||
void indexer_k_quant_and_cache(
|
||||
torch::Tensor& k, // [num_tokens, head_dim]
|
||||
@ -72,4 +82,4 @@ void cp_gather_indexer_k_quant_cache(
|
||||
torch::Tensor& dst_k, // [num_tokens, head_dim]
|
||||
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
|
||||
const torch::Tensor& block_table, // [batch_size, num_blocks]
|
||||
const torch::Tensor& cu_seq_lens); // [batch_size + 1]
|
||||
const torch::Tensor& cu_seq_lens); // [batch_size + 1]
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "cuda_compat.h"
|
||||
@ -514,7 +515,8 @@ __global__ void indexer_k_quant_and_cache_kernel(
|
||||
const int quant_block_size, // quantization block size
|
||||
const int cache_block_size, // cache block size
|
||||
const int cache_stride, // stride for each token in kv_cache
|
||||
const bool use_ue8m0 // use ue8m0 scale format
|
||||
|
||||
const bool use_ue8m0 // use ue8m0 scale format
|
||||
) {
|
||||
constexpr int VEC_SIZE = 4;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
@ -1061,6 +1063,82 @@ void gather_and_maybe_dequant_cache(
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Gather and upconvert FP8 KV cache tokens to BF16 workspace
|
||||
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion
|
||||
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
|
||||
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
|
||||
__nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576]
|
||||
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
|
||||
const int32_t* __restrict__ seq_lens, // [BATCH]
|
||||
const int32_t* __restrict__ workspace_starts, // [BATCH]
|
||||
const int32_t block_size, const int32_t head_dim,
|
||||
const int64_t block_table_stride, const int64_t cache_block_stride,
|
||||
const int64_t cache_entry_stride, const int64_t dst_entry_stride) {
|
||||
const int64_t bid = blockIdx.x; // Batch ID
|
||||
const int32_t num_splits = gridDim.y;
|
||||
const int32_t split = blockIdx.y;
|
||||
const int32_t seq_start = workspace_starts[bid];
|
||||
const int32_t seq_len = seq_lens[bid];
|
||||
const int32_t tot_slots = seq_len;
|
||||
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);
|
||||
|
||||
const int32_t split_start = split * split_slots;
|
||||
const int32_t split_end = min((split + 1) * split_slots, tot_slots);
|
||||
|
||||
const bool is_active_split = (split_start < tot_slots);
|
||||
|
||||
if (!is_active_split) return;
|
||||
|
||||
// Adjust the pointer for the block_table for this batch
|
||||
const int32_t batch_offset = bid * block_table_stride;
|
||||
int32_t offset = split_start;
|
||||
int32_t offset_div = offset / block_size;
|
||||
offset = offset % block_size;
|
||||
const int32_t* batch_block_table = block_table + batch_offset;
|
||||
|
||||
// Adjust dst pointer based on the cumulative sequence lengths
|
||||
dst += seq_start * dst_entry_stride;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
// Process each token in this split
|
||||
for (int pid = split_start; pid < split_end; ++pid) {
|
||||
auto block_id = batch_block_table[offset_div];
|
||||
const uint8_t* token_ptr =
|
||||
src_cache + block_id * cache_block_stride + offset * cache_entry_stride;
|
||||
__nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride;
|
||||
|
||||
// FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16)
|
||||
const uint8_t* no_pe_ptr = token_ptr;
|
||||
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
|
||||
const __nv_bfloat16* rope_ptr =
|
||||
reinterpret_cast<const __nv_bfloat16*>(token_ptr + 512 + 16);
|
||||
|
||||
// Parallelize fp8 dequant (512 elements) and rope copy (64 elements)
|
||||
if (tid < 512) {
|
||||
// FP8 dequantization
|
||||
const int tile = tid >> 7; // each tile is 128 elements
|
||||
const float scale = scales_ptr[tile];
|
||||
const uint8_t val = no_pe_ptr[tid];
|
||||
dst_ptr[tid] =
|
||||
fp8::scaled_convert<__nv_bfloat16, uint8_t,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale);
|
||||
} else if (tid < 576) {
|
||||
// Rope copy (64 bf16 elements)
|
||||
const int rope_idx = tid - 512;
|
||||
dst_ptr[512 + rope_idx] = rope_ptr[rope_idx];
|
||||
}
|
||||
|
||||
// Move to next token
|
||||
offset += 1;
|
||||
if (offset == block_size) {
|
||||
offset_div += 1;
|
||||
offset = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
|
||||
// block_size.
|
||||
@ -1202,6 +1280,57 @@ void cp_gather_cache(
|
||||
}
|
||||
}
|
||||
|
||||
void cp_gather_and_upconvert_fp8_kv_cache(
|
||||
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
|
||||
torch::Tensor const& dst, // [TOT_TOKENS, 576]
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& seq_lens, // [BATCH]
|
||||
torch::Tensor const& workspace_starts, // [BATCH]
|
||||
int64_t batch_size) {
|
||||
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
int32_t block_size = src_cache.size(1);
|
||||
int32_t head_dim = dst.size(1);
|
||||
|
||||
TORCH_CHECK(block_table.dtype() == torch::kInt32,
|
||||
"block_table must be int32");
|
||||
TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32");
|
||||
TORCH_CHECK(workspace_starts.dtype() == torch::kInt32,
|
||||
"workspace_starts must be int32");
|
||||
|
||||
TORCH_CHECK(src_cache.device() == dst.device(),
|
||||
"src_cache and dst must be on the same device");
|
||||
TORCH_CHECK(src_cache.device() == block_table.device(),
|
||||
"src_cache and block_table must be on the same device");
|
||||
TORCH_CHECK(src_cache.device() == seq_lens.device(),
|
||||
"src_cache and seq_lens must be on the same device");
|
||||
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
|
||||
"src_cache and workspace_starts must be on the same device");
|
||||
|
||||
TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8");
|
||||
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
|
||||
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
|
||||
|
||||
int64_t block_table_stride = block_table.stride(0);
|
||||
int64_t cache_block_stride = src_cache.stride(0);
|
||||
int64_t cache_entry_stride = src_cache.stride(1);
|
||||
int64_t dst_entry_stride = dst.stride(0);
|
||||
|
||||
// Decide on the number of splits based on the batch size
|
||||
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
|
||||
dim3 grid(batch_size, num_splits);
|
||||
dim3 block(576);
|
||||
|
||||
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
|
||||
src_cache.data_ptr<uint8_t>(),
|
||||
reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
|
||||
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
|
||||
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
|
||||
block_table_stride, cache_block_stride, cache_entry_stride,
|
||||
dst_entry_stride);
|
||||
}
|
||||
|
||||
// Macro to dispatch the kernel based on the data type.
|
||||
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
|
||||
@ -481,8 +481,6 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
|
||||
largest = value;
|
||||
}
|
||||
}
|
||||
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
// Get the top2 warpwise
|
||||
T max1 = cg::reduce(tile, largest, cg::greater<T>());
|
||||
|
||||
@ -589,7 +587,6 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
int pre_count_equal_to_top_value = 0;
|
||||
// Use loop to find the largset top_group
|
||||
while (count_equal_to_top_value < target_num_min) {
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||
if (value == topk_group_value) {
|
||||
value = neg_inf<T>();
|
||||
@ -644,10 +641,8 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
}
|
||||
}
|
||||
queue.done();
|
||||
__syncwarp();
|
||||
// Get the topk_idx
|
||||
queue.dumpIdx(s_topk_idx);
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Load the valid score value
|
||||
|
||||
@ -860,4 +860,4 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
|
||||
}
|
||||
}
|
||||
@ -22,6 +22,62 @@ __device__ __forceinline__ float GroupReduceMax(float val) {
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T, bool SCALE_UE8M0>
|
||||
__device__ __forceinline__ float ComputeGroupScale(
|
||||
const T* __restrict__ group_input, T* __restrict__ smem_group,
|
||||
const int group_size, const int lane_id, const int threads_per_group,
|
||||
const float eps, const float max_8bit) {
|
||||
float local_absmax = eps;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
|
||||
// copy global -> shared & compute absmax
|
||||
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
|
||||
float abs_v = fabsf(static_cast<float>(src));
|
||||
local_absmax = fmaxf(local_absmax, abs_v);
|
||||
dst = src;
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
group_input, // in
|
||||
smem_group, // out (shared)
|
||||
group_size, // elements per group
|
||||
lane_id, // thread id
|
||||
threads_per_group, // stride in group
|
||||
scalar_op_cache); // scalar handler
|
||||
|
||||
local_absmax = GroupReduceMax(local_absmax);
|
||||
|
||||
float y_s = local_absmax / max_8bit;
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
|
||||
}
|
||||
|
||||
return y_s;
|
||||
}
|
||||
|
||||
template <typename T, typename DST_DTYPE>
|
||||
__device__ __forceinline__ void QuantizeGroup(
|
||||
const T* __restrict__ smem_group, DST_DTYPE* __restrict__ group_output,
|
||||
const int group_size, const int lane_id, const int threads_per_group,
|
||||
const float y_s, const float min_8bit, const float max_8bit) {
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
|
||||
// quantize shared -> global 8-bit
|
||||
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
|
||||
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
|
||||
dst = DST_DTYPE(q);
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
smem_group, // in (shared)
|
||||
group_output, // out (global quant tensor)
|
||||
group_size, // elements
|
||||
lane_id, // tid
|
||||
threads_per_group, // stride
|
||||
scalar_op_quant); // scalar handler
|
||||
}
|
||||
|
||||
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
|
||||
bool SCALE_UE8M0 = false, typename scale_packed_t = float>
|
||||
__global__ void per_token_group_quant_8bit_kernel(
|
||||
@ -38,8 +94,6 @@ __global__ void per_token_group_quant_8bit_kernel(
|
||||
const int64_t global_group_id = block_group_id + local_group_id;
|
||||
const int64_t block_group_offset = global_group_id * group_size;
|
||||
|
||||
float local_absmax = eps;
|
||||
|
||||
using scale_element_t = float;
|
||||
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
|
||||
|
||||
@ -68,30 +122,9 @@ __global__ void per_token_group_quant_8bit_kernel(
|
||||
T* smem = reinterpret_cast<T*>(smem_raw);
|
||||
T* smem_group = smem + local_group_id * group_size;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
using vec_t = vllm::vec_n_t<T, vec_size>;
|
||||
|
||||
// copy global -> shared & compute absmax
|
||||
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
|
||||
float abs_v = fabsf(static_cast<float>(src));
|
||||
local_absmax = fmaxf(local_absmax, abs_v);
|
||||
dst = src;
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
group_input, // in
|
||||
smem_group, // out (shared)
|
||||
group_size, // elements per group
|
||||
lane_id, // thread id
|
||||
threads_per_group, // stride in group
|
||||
scalar_op_cache); // scalar handler
|
||||
|
||||
local_absmax = GroupReduceMax(local_absmax);
|
||||
|
||||
float y_s = local_absmax / max_8bit;
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
|
||||
}
|
||||
const float y_s = ComputeGroupScale<T, SCALE_UE8M0>(
|
||||
group_input, smem_group, group_size, lane_id, threads_per_group, eps,
|
||||
max_8bit);
|
||||
|
||||
scale_element_t y_s_quant = y_s;
|
||||
|
||||
@ -101,19 +134,24 @@ __global__ void per_token_group_quant_8bit_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// quantize shared -> global 8-bit
|
||||
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
|
||||
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
|
||||
dst = DST_DTYPE(q);
|
||||
};
|
||||
QuantizeGroup<T, DST_DTYPE>(smem_group, group_output, group_size, lane_id,
|
||||
threads_per_group, y_s, min_8bit, max_8bit);
|
||||
}
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
smem_group, // in (shared)
|
||||
group_output, // out (global quant tensor)
|
||||
group_size, // elements
|
||||
lane_id, // tid
|
||||
threads_per_group, // stride
|
||||
scalar_op_quant); // scalar handler
|
||||
inline int GetGroupsPerBlock(int64_t num_groups) {
|
||||
if (num_groups % 16 == 0) {
|
||||
return 16;
|
||||
}
|
||||
if (num_groups % 8 == 0) {
|
||||
return 8;
|
||||
}
|
||||
if (num_groups % 4 == 0) {
|
||||
return 4;
|
||||
}
|
||||
if (num_groups % 2 == 0) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
void per_token_group_quant_8bit(const torch::Tensor& input,
|
||||
@ -133,17 +171,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
|
||||
|
||||
constexpr int THREADS_PER_GROUP = 16;
|
||||
|
||||
int groups_per_block = 1;
|
||||
|
||||
if (num_groups % 16 == 0) {
|
||||
groups_per_block = 16;
|
||||
} else if (num_groups % 8 == 0) {
|
||||
groups_per_block = 8;
|
||||
} else if (num_groups % 4 == 0) {
|
||||
groups_per_block = 4;
|
||||
} else if (num_groups % 2 == 0) {
|
||||
groups_per_block = 2;
|
||||
}
|
||||
const int groups_per_block = GetGroupsPerBlock(num_groups);
|
||||
|
||||
auto dst_type = output_q.scalar_type();
|
||||
const int num_blocks = num_groups / groups_per_block;
|
||||
@ -225,8 +253,6 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
|
||||
|
||||
const int64_t block_group_offset = global_group_id * group_size;
|
||||
|
||||
float local_absmax = eps;
|
||||
|
||||
const T* group_input = input + block_group_offset;
|
||||
DST_DTYPE* group_output =
|
||||
static_cast<DST_DTYPE*>(output_q) + block_group_offset;
|
||||
@ -235,29 +261,9 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
|
||||
extern __shared__ __align__(16) char smem_raw[];
|
||||
T* smem = reinterpret_cast<T*>(smem_raw);
|
||||
T* smem_group = smem + local_group_id * group_size;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
using vec_t = vllm::vec_n_t<T, vec_size>;
|
||||
|
||||
// copy global -> shared & compute absmax
|
||||
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
|
||||
float abs_v = fabsf(static_cast<float>(src));
|
||||
local_absmax = fmaxf(local_absmax, abs_v);
|
||||
dst = src;
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
group_input, // in
|
||||
smem_group, // out (shared)
|
||||
group_size, // elements per group
|
||||
lane_id, // thread id
|
||||
threads_per_group, // stride in group
|
||||
scalar_op_cache); // scalar handler
|
||||
|
||||
local_absmax = GroupReduceMax(local_absmax);
|
||||
|
||||
float y_s = local_absmax / max_8bit;
|
||||
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
|
||||
const float y_s =
|
||||
ComputeGroupScale<T, true>(group_input, smem_group, group_size, lane_id,
|
||||
threads_per_group, eps, max_8bit);
|
||||
|
||||
// pack 4 scales into a uint32
|
||||
if (lane_id == 0) {
|
||||
@ -284,19 +290,8 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// quantize shared -> global 8-bit
|
||||
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
|
||||
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
|
||||
dst = DST_DTYPE(q);
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
smem_group, // in (shared)
|
||||
group_output, // out (global quant tensor)
|
||||
group_size, // elements
|
||||
lane_id, // tid
|
||||
threads_per_group, // stride
|
||||
scalar_op_quant); // scalar handler
|
||||
QuantizeGroup<T, DST_DTYPE>(smem_group, group_output, group_size, lane_id,
|
||||
threads_per_group, y_s, min_8bit, max_8bit);
|
||||
}
|
||||
|
||||
void per_token_group_quant_8bit_packed(const torch::Tensor& input,
|
||||
@ -337,17 +332,7 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
|
||||
|
||||
constexpr int THREADS_PER_GROUP = 16;
|
||||
|
||||
int groups_per_block = 1;
|
||||
|
||||
if (num_groups % 16 == 0) {
|
||||
groups_per_block = 16;
|
||||
} else if (num_groups % 8 == 0) {
|
||||
groups_per_block = 8;
|
||||
} else if (num_groups % 4 == 0) {
|
||||
groups_per_block = 4;
|
||||
} else if (num_groups % 2 == 0) {
|
||||
groups_per_block = 2;
|
||||
}
|
||||
const int groups_per_block = GetGroupsPerBlock(num_groups);
|
||||
|
||||
auto dst_type = output_q.scalar_type();
|
||||
const int num_blocks = num_groups / groups_per_block;
|
||||
|
||||
@ -754,6 +754,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
|
||||
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
|
||||
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
|
||||
"batch_size) -> ()");
|
||||
cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
|
||||
&cp_gather_and_upconvert_fp8_kv_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
|
||||
"slot_mapping, "
|
||||
|
||||
@ -76,6 +76,9 @@ RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
ENV NIXL_VERSION=0.7.0
|
||||
RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py
|
||||
|
||||
# PyJWT-2.7.0 will influence some wheel behaviors, remove its dist-info to avoid conflicts
|
||||
RUN rm /usr/lib/python3/dist-packages/PyJWT-2.7.0.dist-info/ -rf
|
||||
|
||||
# remove torch bundled oneccl to avoid conflicts
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip uninstall oneccl oneccl-devel -y
|
||||
|
||||
@ -24,11 +24,13 @@ Compute Resources:
|
||||
- Databricks
|
||||
- DeepInfra
|
||||
- Google Cloud
|
||||
- IBM
|
||||
- Intel
|
||||
- Lambda Lab
|
||||
- Nebius
|
||||
- Novita AI
|
||||
- NVIDIA
|
||||
- Red Hat
|
||||
- Replicate
|
||||
- Roblox
|
||||
- RunPod
|
||||
|
||||
@ -7,7 +7,7 @@ This guide covers optimization strategies and performance tuning for vLLM V1.
|
||||
|
||||
## Preemption
|
||||
|
||||
Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests.
|
||||
Due to the autoregressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests.
|
||||
In such cases, vLLM can preempt requests to free up KV cache space for other requests. Preempted requests are recomputed when sufficient KV cache space becomes
|
||||
available again. When this occurs, you may see the following warning:
|
||||
|
||||
|
||||
@ -82,7 +82,7 @@ DOCKER_BUILDKIT=1 docker build . \
|
||||
|
||||
## Building for Arm64/aarch64
|
||||
|
||||
A docker container can be built for aarch64 systems such as the Nvidia Grace-Hopper. At time of this writing, this should be considered **experimental**. Using the flag `--platform "linux/arm64"` will attempt to build for arm64.
|
||||
A docker container can be built for aarch64 systems such as the Nvidia Grace-Hopper and Grace-Blackwell. Using the flag `--platform "linux/arm64"` will build for arm64.
|
||||
|
||||
!!! note
|
||||
Multiple modules must be compiled, so this process can take a while. Recommend using `--build-arg max_jobs=` & `--build-arg nvcc_threads=`
|
||||
@ -104,6 +104,25 @@ A docker container can be built for aarch64 systems such as the Nvidia Grace-Hop
|
||||
--build-arg RUN_WHEEL_CHECK=false
|
||||
```
|
||||
|
||||
For (G)B300, we recommend using CUDA 13, as shown in the following command.
|
||||
|
||||
??? console "Command"
|
||||
|
||||
```bash
|
||||
DOCKER_BUILDKIT=1 docker build \
|
||||
--build-arg CUDA_VERSION=13.0.1 \
|
||||
--build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 \
|
||||
--build-arg max_jobs=256 \
|
||||
--build-arg nvcc_threads=2 \
|
||||
--build-arg RUN_WHEEL_CHECK=false \
|
||||
--build-arg torch_cuda_arch_list='9.0 10.0+PTX' \
|
||||
--platform "linux/arm64" \
|
||||
--tag vllm/vllm-gb300-openai:latest \
|
||||
--target vllm-openai \
|
||||
-f docker/Dockerfile \
|
||||
.
|
||||
```
|
||||
|
||||
!!! note
|
||||
If you are building the `linux/arm64` image on a non-ARM host (e.g., an x86_64 machine), you need to ensure your system is set up for cross-compilation using QEMU. This allows your host machine to emulate ARM64 execution.
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine le
|
||||
|
||||
* **Upstream vLLM compatibility** – It wraps around upstream vLLM without modifying its code.
|
||||
* **Ease of use** – Simplified deployment via Helm charts and observability through Grafana dashboards.
|
||||
* **High performance** – Optimized for LLM workloads with features like multi-model support, model-aware and prefix-aware routing, fast vLLM bootstrapping, and KV cache offloading with [LMCache](https://github.com/LMCache/LMCache), among others.
|
||||
* **High performance** – Optimized for LLM workloads with features like multimodel support, model-aware and prefix-aware routing, fast vLLM bootstrapping, and KV cache offloading with [LMCache](https://github.com/LMCache/LMCache), among others.
|
||||
|
||||
If you are new to Kubernetes, don't worry: in the vLLM production stack [repo](https://github.com/vllm-project/production-stack), we provide a step-by-step [guide](https://github.com/vllm-project/production-stack/blob/main/tutorials/00-install-kubernetes-env.md) and a [short video](https://www.youtube.com/watch?v=EsTJbQtzj0g) to set up everything and get started in **4 minutes**!
|
||||
|
||||
|
||||
@ -41,7 +41,7 @@ These features allow the most flexibility for cudagraph capture and compilation
|
||||
* `NONE` — turn CUDA Graphs off. Good for debugging.
|
||||
* `PIECEWISE` — a single-mode strategy (and past default). It is the most flexible: attention or other CUDA Graphs-incompatible operations stay eager, everything else goes into CUDA Graphs. Requires piecewise compilation.
|
||||
* `FULL` — a single-mode strategy, which only captures full CUDA Graphs for non-uniform batches, then uniform-decode batches reuse the CUDA Graph of non-uniform batch of the same batch_size, since they are compatible; can be good for small models or workloads with small prompts.
|
||||
* `FULL_DECODE_ONLY` — full CUDA Graph for uniform decode, no cudagraph for prefill/mixed etc; suitable for decode instances in a P/D setup where prefill is not as important, this way we can save the memory needed for `PIECEWISE` CUDA Graphs.
|
||||
* `FULL_DECODE_ONLY` — full CUDA Graph for uniform decode, no cudagraph for prefill/mixed etc.; suitable for decode instances in a P/D setup where prefill is not as important, this way we can save the memory needed for `PIECEWISE` CUDA Graphs.
|
||||
* `FULL_AND_PIECEWISE` — (default mode) full CUDA Graph for uniform decode, piecewise CUDA Graphs for others; generally the most performant setting, especially for low latency with small models or MoEs, but also requires the most memory and takes the longest to capture.
|
||||
|
||||
Defaults: If you’re on v1 with piecewise compilation, we default to `FULL_AND_PIECEWISE` for better performance, (for pooling models, it's still `PIECEWISE`). Otherwise, e.g. if piecewise compilation unavailable, we default to `NONE`.
|
||||
@ -49,7 +49,7 @@ Defaults: If you’re on v1 with piecewise compilation, we default to `FULL_AND_
|
||||
While `NONE` , `PIECEWISE`, and `FULL` are single-mode configurations and simply equivalent to past implementations of eager execution, piecewise CUDA Graphs, and full CUDA Graphs respectively, `FULL_DECODE_ONLY` and `FULL_AND_PIECEWISE` are newly appended dual-mode configurations, which require dispatching to switch between concrete runtime modes according to runtime batches dynamically.
|
||||
|
||||
!!! note
|
||||
Here, the single-modes `NONE`, `PIECEWISE`, and `FULL` are treated as the runtime modes for CUDA Graphs dispatching. If using a dual-mode, the dispatcher will always dispatch to one of its member modes (plus a potantial `NONE` if no suitable CUDA Graph available), depending on the batch composition.
|
||||
Here, the single-modes `NONE`, `PIECEWISE`, and `FULL` are treated as the runtime modes for CUDA Graphs dispatching. If using a dual-mode, the dispatcher will always dispatch to one of its member modes (plus a potential `NONE` if no suitable CUDA Graph available), depending on the batch composition.
|
||||
|
||||
While cascade attention is not cudagraph compatible, it is now compatible with all possible cudagraph mode configurations. If a batch uses cascade attention, it always gets dispatched to `PIECEWISE` mode if available (otherwise `NONE`).
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
## Overview
|
||||
|
||||
vLLM now supports optimization levels (`-O0`, `-O1`, `-O2`, `-O3`). Optimization levels provide an intuitive mechnaism for users to trade startup time for performance. Higher levels have better performance but worse startup time. These optimization levels have associated defaults to help users get desired out of the box performance. Importantly, defaults set by optimization levels are purely defaults; explicit user settings will not be overwritten.
|
||||
vLLM now supports optimization levels (`-O0`, `-O1`, `-O2`, `-O3`). Optimization levels provide an intuitive mechanism for users to trade startup time for performance. Higher levels have better performance but worse startup time. These optimization levels have associated defaults to help users get desired out-of-the-box performance. Importantly, defaults set by optimization levels are purely defaults; explicit user settings will not be overwritten.
|
||||
|
||||
## Level Summaries and Usage Examples
|
||||
```bash
|
||||
|
||||
@ -36,7 +36,7 @@ the input pointers `q`, `k_cache`, and `v_cache`, which point
|
||||
to query, key, and value data on global memory that need to be read
|
||||
and processed. The output pointer `out` points to global memory
|
||||
where the result should be written. These four pointers actually
|
||||
refer to multi-dimensional arrays, but each thread only accesses the
|
||||
refer to multidimensional arrays, but each thread only accesses the
|
||||
portion of data assigned to it. I have omitted all other runtime
|
||||
parameters here for simplicity.
|
||||
|
||||
@ -229,7 +229,7 @@ manner.
|
||||
|
||||
## QK
|
||||
|
||||
As shown the pseudo code below, before the entire for loop block, we
|
||||
As shown the pseudocode below, before the entire for loop block, we
|
||||
fetch the query data for one token and store it in `q_vecs`. Then,
|
||||
in the outer for loop, we iterate through different `k_ptrs` that
|
||||
point to different tokens and prepare the `k_vecs` in the inner for
|
||||
@ -403,7 +403,7 @@ for ... { // Iteration over different blocks.
|
||||
}
|
||||
```
|
||||
|
||||
As shown in the above pseudo code, in the outer loop, similar to
|
||||
As shown in the above pseudocode, in the outer loop, similar to
|
||||
`k_ptr`, `logits_vec` iterates over different blocks and reads
|
||||
`V_VEC_SIZE` elements from `logits`. In the inner loop, each
|
||||
thread reads `V_VEC_SIZE` elements from the same tokens as a
|
||||
|
||||
@ -29,8 +29,27 @@ uv pip install --pre vllm==<version>+cpu --extra-index-url https://wheels.vllm.a
|
||||
|
||||
The `uv` approach works for vLLM `v0.6.6` and later. A unique feature of `uv` is that packages in `--extra-index-url` have [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). If the latest public release is `v0.6.6.post1`, `uv`'s behavior allows installing a commit before `v0.6.6.post1` by specifying the `--extra-index-url`. In contrast, `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version.
|
||||
|
||||
!!! note
|
||||
Nightly wheels are currently unsupported for this architecture. (e.g. to bisect the behavior change, performance regression).
|
||||
**Install the latest code**
|
||||
|
||||
LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides working pre-built Arm CPU wheels for every commit since `v0.11.2` on <https://wheels.vllm.ai/nightly>. For native CPU wheels, this index should be used:
|
||||
|
||||
* `https://wheels.vllm.ai/nightly/cpu/vllm`
|
||||
|
||||
To install from nightly index, copy the link address of the `*.whl` under this index to run, for example:
|
||||
|
||||
```bash
|
||||
uv pip install -U https://wheels.vllm.ai/c756fb678184b867ed94e5613a529198f1aee423/vllm-0.13.0rc2.dev11%2Bgc756fb678.cpu-cp38-abi3-manylinux_2_31_aarch64.whl # current nightly build (the filename will change!)
|
||||
```
|
||||
|
||||
**Install specific revisions**
|
||||
|
||||
If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), specify the full commit hash in the index:
|
||||
https://wheels.vllm.ai/${VLLM_COMMIT}/cpu/vllm .
|
||||
Then, copy the link address of the `*.whl` under this index to run:
|
||||
|
||||
```bash
|
||||
uv pip install -U <wheel-url>
|
||||
```
|
||||
|
||||
# --8<-- [end:pre-built-wheels]
|
||||
# --8<-- [start:build-wheel-from-source]
|
||||
@ -81,7 +100,23 @@ Testing has been conducted on AWS Graviton3 instances for compatibility.
|
||||
# --8<-- [end:build-wheel-from-source]
|
||||
# --8<-- [start:pre-built-images]
|
||||
|
||||
Currently, there are no pre-built Arm CPU images.
|
||||
See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image.
|
||||
|
||||
Stable vLLM Docker images are being pre-built for Arm from version 0.12.0. Available image tags are here: [https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo).
|
||||
Please replace `<version>` in the command below with a specific version string (e.g., `0.12.0`).
|
||||
|
||||
```bash
|
||||
docker pull public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:v<version>
|
||||
```
|
||||
|
||||
You can also access the latest code with Docker images. These are not intended for production use and are meant for CI and testing only. They will expire after several days.
|
||||
|
||||
The latest code can contain bugs and may not be stable. Please use it with caution.
|
||||
|
||||
```bash
|
||||
export VLLM_COMMIT=6299628d326f429eba78736acb44e76749b281f5 # use full commit hash from the main branch
|
||||
docker pull public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:${VLLM_COMMIT}-arm64-cpu
|
||||
```
|
||||
|
||||
# --8<-- [end:pre-built-images]
|
||||
# --8<-- [start:build-image-from-source]
|
||||
|
||||
@ -281,17 +281,27 @@ Alternatively, you can use the `openai` Python package:
|
||||
|
||||
Currently, vLLM supports multiple backends for efficient Attention computation across different platforms and accelerator architectures. It automatically selects the most performant backend compatible with your system and model specifications.
|
||||
|
||||
If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options:
|
||||
If desired, you can also manually set the backend of your choice using the `--attention-backend` CLI argument:
|
||||
|
||||
```bash
|
||||
# For online serving
|
||||
vllm serve Qwen/Qwen2.5-1.5B-Instruct --attention-backend FLASH_ATTN
|
||||
|
||||
# For offline inference
|
||||
python script.py --attention-backend FLASHINFER
|
||||
```
|
||||
|
||||
Some of the available backend options include:
|
||||
|
||||
- On NVIDIA CUDA: `FLASH_ATTN` or `FLASHINFER`.
|
||||
- On AMD ROCm: `TRITON_ATTN`, `ROCM_ATTN`, `ROCM_AITER_FA` or `ROCM_AITER_UNIFIED_ATTN`.
|
||||
|
||||
For AMD ROCm, you can further control the specific Attention implementation using the following variables:
|
||||
For AMD ROCm, you can further control the specific Attention implementation using the following options:
|
||||
|
||||
- Triton Unified Attention: `VLLM_ROCM_USE_AITER=0 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0`
|
||||
- AITER Unified Attention: `VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0`
|
||||
- Triton Prefill-Decode Attention: `VLLM_ROCM_USE_AITER=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 VLLM_ROCM_USE_AITER_MHA=0`
|
||||
- AITER Multi-head Attention: `VLLM_ROCM_USE_AITER=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=1`
|
||||
- Triton Unified Attention: Set the environment variables `VLLM_ROCM_USE_AITER=0 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument.
|
||||
- AITER Unified Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument.
|
||||
- Triton Prefill-Decode Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=true` as a CLI argument.
|
||||
- AITER Multi-head Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=1` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument.
|
||||
|
||||
!!! warning
|
||||
There are no pre-built vllm wheels containing Flash Infer, so you must install it in your environment first. Refer to the [Flash Infer official docs](https://docs.flashinfer.ai/) or see [docker/Dockerfile](../../docker/Dockerfile) for instructions on how to install it.
|
||||
|
||||
@ -568,7 +568,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
||||
```
|
||||
|
||||
!!! note
|
||||
Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/qwen3_reranker.py](../../examples/pooling/score/qwen3_reranker.py).
|
||||
Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/offline_reranker.py](../../examples/pooling/score/offline_reranker.py).
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
|
||||
@ -659,7 +659,9 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
||||
|--------------|--------|--------|-------------------|----------------------|---------------------------|
|
||||
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | |
|
||||
| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A<sup>+</sup> | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-hf` | ✅︎ | ✅︎ |
|
||||
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereLabs/aya-vision-8b`, `CohereLabs/aya-vision-32b`, etc. | | ✅︎ |
|
||||
| `BagelForConditionalGeneration` | BAGEL | T + I<sup>+</sup> | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ |
|
||||
| `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
|
||||
| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ |
|
||||
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ |
|
||||
@ -743,7 +745,7 @@ Some models are supported only via the [Transformers modeling backend](#transfor
|
||||
- There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups.
|
||||
|
||||
!!! note
|
||||
For `InternVLChatModel`, only InternVL2.5 with Qwen2.5 text backbone (`OpenGVLab/InternVL2.5-1B` etc), InternVL3 and InternVL3.5 have video inputs support currently.
|
||||
For `InternVLChatModel`, only InternVL2.5 with Qwen2.5 text backbone (`OpenGVLab/InternVL2.5-1B` etc.), InternVL3 and InternVL3.5 have video inputs support currently.
|
||||
|
||||
!!! note
|
||||
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
|
||||
|
||||
@ -8,11 +8,11 @@ For MoE models, particularly those like DeepSeek that employ MLA (Multi-head Lat
|
||||
|
||||
In these cases, the data parallel ranks are not completely independent. Forward passes must be aligned, and expert layers across all ranks are required to synchronize during every forward pass, even when there are fewer requests to be processed than DP ranks.
|
||||
|
||||
The expert layers will by default form a (DP x TP) sized tensor parallel group. To enable expert parallelism, include the `--enable-expert-parallel` CLI arg (on all nodes in the multi-node case).
|
||||
By default, expert layers form a tensor parallel group of size `DP × TP`. To use expert parallelism instead, include the `--enable-expert-parallel` CLI arg (on all nodes in the multi-node case). See [Expert Parallel Deployment](expert_parallel_deployment.md) for details on how attention and expert layers behave differently with EP enabled.
|
||||
|
||||
In vLLM, each DP rank is deployed as a separate "core engine" process that communicates with front-end process(es) via ZMQ sockets. Data Parallel attention can be combined with Tensor Parallel attention, in which case each DP engine owns a number of per-GPU worker processes equal to the configured TP size.
|
||||
|
||||
For MoE models, when any requests are in progress in any rank, we must ensure that empty "dummy" forward passes are performed in all ranks that don't currently have any requests scheduled. This is handled via a separate DP Coordinator process that communicates with all ranks, and a collective operation performed every N steps to determine when all ranks become idle and can be paused. When TP is used in conjunction with DP, expert layers form an EP or TP group of size (DP x TP).
|
||||
For MoE models, when any requests are in progress in any rank, we must ensure that empty "dummy" forward passes are performed in all ranks that don't currently have any requests scheduled. This is handled via a separate DP Coordinator process that communicates with all ranks, and a collective operation performed every N steps to determine when all ranks become idle and can be paused. When TP is used in conjunction with DP, expert layers form a group of size `DP × TP` (using either tensor parallelism by default, or expert parallelism if `--enable-expert-parallel` is set).
|
||||
|
||||
In all cases, it is beneficial to load-balance requests between DP ranks. For online deployments, this balancing can be optimized by taking into account the state of each DP engine - in particular its currently scheduled and waiting (queued) requests, and KV cache state. Each DP engine has an independent KV cache, and the benefit of prefix caching can be maximized by directing prompts intelligently.
|
||||
|
||||
|
||||
@ -44,7 +44,27 @@ Where:
|
||||
- `DP_SIZE`: Data parallel size
|
||||
- `EP_SIZE`: Expert parallel size (computed automatically)
|
||||
|
||||
When EP is enabled, MoE layers use expert parallelism instead of tensor parallelism, while attention layers continue to use tensor parallelism if `TP_SIZE > 1`.
|
||||
### Layer Behavior with EP Enabled
|
||||
|
||||
When EP is enabled, different layers in MoE models behave differently:
|
||||
|
||||
| Layer Type | Behavior | Parallelism Used |
|
||||
|------------|----------|------------------|
|
||||
| **Expert (MoE) Layers** | Sharded across all EP ranks | Expert Parallel (EP) of size `TP × DP` |
|
||||
| **Attention Layers** | Behavior depends on TP size | See below |
|
||||
|
||||
**Attention layer parallelism:**
|
||||
|
||||
- **When `TP = 1`**: Attention weights are **replicated** across all DP ranks (data parallelism)
|
||||
- **When `TP > 1`**: Attention weights are **sharded** using tensor parallelism across TP ranks within each DP group
|
||||
|
||||
For example, with `TP=2, DP=4` (8 GPUs total):
|
||||
|
||||
- Expert layers form an EP group of size 8, with experts distributed across all GPUs
|
||||
- Attention layers use TP=2 within each of the 4 DP groups
|
||||
|
||||
!!! note "Key Difference from Data Parallel Deployment"
|
||||
Without `--enable-expert-parallel`, MoE layers would use tensor parallelism (forming a TP group of size `TP × DP`), similar to dense models. With EP enabled, expert layers switch to expert parallelism, which can provide better efficiency and locality for MoE models.
|
||||
|
||||
### Example Command
|
||||
|
||||
|
||||
@ -851,7 +851,7 @@ endpoints are compatible with both [Jina AI's re-rank API interface](https://jin
|
||||
[Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with
|
||||
popular open-source tools.
|
||||
|
||||
Code example: [examples/pooling/score/jinaai_rerank_client.py](../../examples/pooling/score/jinaai_rerank_client.py)
|
||||
Code example: [examples/pooling/score/openai_reranker.py](../../examples/pooling/score/openai_reranker.py)
|
||||
|
||||
#### Example Request
|
||||
|
||||
|
||||
@ -62,7 +62,7 @@ If a single node lacks sufficient GPUs to hold the model, deploy vLLM across mul
|
||||
|
||||
### What is Ray?
|
||||
|
||||
Ray is a distributed computing framework for scaling Python programs. Multi-node vLLM deployments require Ray as the runtime engine.
|
||||
Ray is a distributed computing framework for scaling Python programs. Multi-node vLLM deployments can use Ray as the runtime engine.
|
||||
|
||||
vLLM uses Ray to manage the distributed execution of tasks across multiple nodes and control where execution happens.
|
||||
|
||||
@ -130,9 +130,31 @@ vllm serve /path/to/the/model/in/the/container \
|
||||
--distributed-executor-backend ray
|
||||
```
|
||||
|
||||
### Running vLLM with MultiProcessing
|
||||
|
||||
Besides Ray, Multi-node vLLM deployments can also use `multiprocessing` as the runtime engine. Here's an example to deploy model across 2 nodes (8 GPUs per node) with `tp_size=8` and `pp_size=2`.
|
||||
|
||||
Choose one node as the head node and run:
|
||||
|
||||
```bash
|
||||
vllm serve /path/to/the/model/in/the/container \
|
||||
--tensor-parallel-size 8 --pipeline-parallel-size 2 \
|
||||
--nnodes 2 --node-rank 0 \
|
||||
--master-addr <HEAD_NODE_IP>
|
||||
```
|
||||
|
||||
On the other worker node, run:
|
||||
|
||||
```bash
|
||||
vllm serve /path/to/the/model/in/the/container \
|
||||
--tensor-parallel-size 8 --pipeline-parallel-size 2 \
|
||||
--nnodes 2 --node-rank 1 \
|
||||
--master-addr <HEAD_NODE_IP> --headless
|
||||
```
|
||||
|
||||
## Optimizing network communication for tensor parallelism
|
||||
|
||||
Efficient tensor parallelism requires fast inter-node communication, preferably through high-speed network adapters such as InfiniBand.
|
||||
Efficient tensor parallelism requires fast internode communication, preferably through high-speed network adapters such as InfiniBand.
|
||||
To set up the cluster to use InfiniBand, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the
|
||||
[examples/online_serving/run_cluster.sh](../../examples/online_serving/run_cluster.sh) helper script.
|
||||
Contact your system administrator for more information about the required flags.
|
||||
|
||||
@ -10,7 +10,7 @@ All communications between nodes in a multi-node vLLM deployment are **insecure
|
||||
|
||||
### Configuration Options for Inter-Node Communications
|
||||
|
||||
The following options control inter-node communications in vLLM:
|
||||
The following options control internode communications in vLLM:
|
||||
|
||||
#### 1. **Environment Variables:**
|
||||
|
||||
@ -28,7 +28,7 @@ The following options control inter-node communications in vLLM:
|
||||
|
||||
### Notes on PyTorch Distributed
|
||||
|
||||
vLLM uses PyTorch's distributed features for some inter-node communication. For
|
||||
vLLM uses PyTorch's distributed features for some internode communication. For
|
||||
detailed information about PyTorch Distributed security considerations, please
|
||||
refer to the [PyTorch Security
|
||||
Guide](https://github.com/pytorch/pytorch/security/policy#using-distributed-features).
|
||||
|
||||
@ -42,60 +42,31 @@ class ModelRequestData(NamedTuple):
|
||||
# Unless specified, these settings have been tested to work on a single L4.
|
||||
|
||||
|
||||
# Voxtral
|
||||
# Make sure to install mistral-common[audio].
|
||||
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
|
||||
from mistral_common.audio import Audio
|
||||
from mistral_common.protocol.instruct.chunk import (
|
||||
AudioChunk,
|
||||
RawAudio,
|
||||
TextChunk,
|
||||
)
|
||||
from mistral_common.protocol.instruct.messages import (
|
||||
UserMessage,
|
||||
)
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
model_name = "mistralai/Voxtral-Mini-3B-2507"
|
||||
tokenizer = MistralTokenizer.from_hf_hub(model_name)
|
||||
|
||||
# AudioFlamingo3
|
||||
def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
|
||||
model_name = "nvidia/audio-flamingo-3-hf"
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
config_format="mistral",
|
||||
load_format="mistral",
|
||||
tokenizer_mode="mistral",
|
||||
enforce_eager=True,
|
||||
enable_chunked_prefill=False,
|
||||
)
|
||||
|
||||
text_chunk = TextChunk(text=question)
|
||||
audios = [
|
||||
Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
|
||||
for i in range(audio_count)
|
||||
]
|
||||
audio_chunks = [
|
||||
AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
|
||||
]
|
||||
# AudioFlamingo3 uses <sound> token for audio
|
||||
audio_placeholder = "<sound>" * audio_count
|
||||
|
||||
messages = [UserMessage(content=[*audio_chunks, text_chunk])]
|
||||
|
||||
req = ChatCompletionRequest(messages=messages, model=model_name)
|
||||
|
||||
tokens = tokenizer.encode_chat_completion(req)
|
||||
prompt_ids, audios = tokens.tokens, tokens.audios
|
||||
|
||||
audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
|
||||
|
||||
multi_modal_data = {"audio": audios_and_sr}
|
||||
prompt = (
|
||||
"<|im_start|>system\n"
|
||||
"You are a helpful assistant.<|im_end|>\n"
|
||||
"<|im_start|>user\n"
|
||||
f"{audio_placeholder}{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt_token_ids=prompt_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
@ -361,6 +332,63 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# Voxtral
|
||||
# Make sure to install mistral-common[audio].
|
||||
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
|
||||
from mistral_common.audio import Audio
|
||||
from mistral_common.protocol.instruct.chunk import (
|
||||
AudioChunk,
|
||||
RawAudio,
|
||||
TextChunk,
|
||||
)
|
||||
from mistral_common.protocol.instruct.messages import (
|
||||
UserMessage,
|
||||
)
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
model_name = "mistralai/Voxtral-Mini-3B-2507"
|
||||
tokenizer = MistralTokenizer.from_hf_hub(model_name)
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
config_format="mistral",
|
||||
load_format="mistral",
|
||||
tokenizer_mode="mistral",
|
||||
enforce_eager=True,
|
||||
enable_chunked_prefill=False,
|
||||
)
|
||||
|
||||
text_chunk = TextChunk(text=question)
|
||||
audios = [
|
||||
Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
|
||||
for i in range(audio_count)
|
||||
]
|
||||
audio_chunks = [
|
||||
AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
|
||||
]
|
||||
|
||||
messages = [UserMessage(content=[*audio_chunks, text_chunk])]
|
||||
|
||||
req = ChatCompletionRequest(messages=messages, model=model_name)
|
||||
|
||||
tokens = tokenizer.encode_chat_completion(req)
|
||||
prompt_ids, audios = tokens.tokens, tokens.audios
|
||||
|
||||
audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
|
||||
|
||||
multi_modal_data = {"audio": audios_and_sr}
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt_token_ids=prompt_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
|
||||
|
||||
# Whisper
|
||||
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
|
||||
assert audio_count == 1, "Whisper only support single audio input per prompt"
|
||||
@ -382,7 +410,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"voxtral": run_voxtral,
|
||||
"audioflamingo3": run_audioflamingo3,
|
||||
"gemma3n": run_gemma3n,
|
||||
"granite_speech": run_granite_speech,
|
||||
"midashenglm": run_midashenglm,
|
||||
@ -392,6 +420,7 @@ model_example_map = {
|
||||
"qwen2_audio": run_qwen2_audio,
|
||||
"qwen2_5_omni": run_qwen2_5_omni,
|
||||
"ultravox": run_ultravox,
|
||||
"voxtral": run_voxtral,
|
||||
"whisper": run_whisper,
|
||||
}
|
||||
|
||||
|
||||
@ -4,6 +4,9 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import AttentionConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@ -20,6 +23,11 @@ def parse_args():
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
if current_platform.is_rocm():
|
||||
args.attention_config = AttentionConfig(
|
||||
backend=AttentionBackendEnum.FLEX_ATTENTION
|
||||
)
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
|
||||
@ -4,6 +4,9 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import AttentionConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@ -20,6 +23,11 @@ def parse_args():
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
if current_platform.is_rocm():
|
||||
args.attention_config = AttentionConfig(
|
||||
backend=AttentionBackendEnum.FLEX_ATTENTION
|
||||
)
|
||||
|
||||
# Sample prompts.
|
||||
text_1 = "What is the capital of France?"
|
||||
texts_2 = [
|
||||
|
||||
@ -33,6 +33,7 @@ import os
|
||||
from time import sleep
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
|
||||
@ -222,6 +223,11 @@ if __name__ == "__main__":
|
||||
|
||||
from multiprocessing import Process
|
||||
|
||||
if current_platform.is_rocm():
|
||||
from multiprocessing import set_start_method
|
||||
|
||||
set_start_method("spawn", force=True)
|
||||
|
||||
procs = []
|
||||
for local_dp_rank, global_dp_rank in enumerate(
|
||||
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
|
||||
|
||||
@ -118,6 +118,32 @@ def run_bee(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def run_bagel(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
model_name = "ByteDance-Seed/BAGEL-7B-MoT"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
|
||||
prompts = [
|
||||
(
|
||||
f"<|im_start|>user\n<|image_pad|>\n{question}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n"
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# BLIP-2
|
||||
def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
@ -1832,6 +1858,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model_example_map = {
|
||||
"aria": run_aria,
|
||||
"aya_vision": run_aya_vision,
|
||||
"bagel": run_bagel,
|
||||
"bee": run_bee,
|
||||
"blip-2": run_blip2,
|
||||
"chameleon": run_chameleon,
|
||||
|
||||
@ -21,7 +21,7 @@
|
||||
# --worker \
|
||||
# /abs/path/to/huggingface/cache \
|
||||
# -e VLLM_HOST_IP=<worker_node_ip>
|
||||
#
|
||||
#
|
||||
# Each worker requires a unique VLLM_HOST_IP value.
|
||||
# Keep each terminal session open. Closing a session stops the associated Ray
|
||||
# node and thereby shuts down the entire cluster.
|
||||
@ -59,6 +59,34 @@ if [ "${NODE_TYPE}" != "--head" ] && [ "${NODE_TYPE}" != "--worker" ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Extract VLLM_HOST_IP from ADDITIONAL_ARGS (e.g. "-e VLLM_HOST_IP=...").
|
||||
VLLM_HOST_IP=""
|
||||
for ((i = 0; i < ${#ADDITIONAL_ARGS[@]}; i++)); do
|
||||
arg="${ADDITIONAL_ARGS[$i]}"
|
||||
case "${arg}" in
|
||||
-e)
|
||||
next="${ADDITIONAL_ARGS[$((i + 1))]:-}"
|
||||
if [[ "${next}" == VLLM_HOST_IP=* ]]; then
|
||||
VLLM_HOST_IP="${next#VLLM_HOST_IP=}"
|
||||
break
|
||||
fi
|
||||
;;
|
||||
-eVLLM_HOST_IP=* | VLLM_HOST_IP=*)
|
||||
VLLM_HOST_IP="${arg#*=}"
|
||||
break
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# For the head node, HEAD_NODE_ADDRESS and VLLM_HOST_IP should be consistent.
|
||||
if [[ "${NODE_TYPE}" == "--head" && -n "${VLLM_HOST_IP}" ]]; then
|
||||
if [[ "${VLLM_HOST_IP}" != "${HEAD_NODE_ADDRESS}" ]]; then
|
||||
echo "Warning: VLLM_HOST_IP (${VLLM_HOST_IP}) differs from head_node_ip (${HEAD_NODE_ADDRESS})."
|
||||
echo "Using VLLM_HOST_IP as the head node address."
|
||||
HEAD_NODE_ADDRESS="${VLLM_HOST_IP}"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Generate a unique container name with random suffix.
|
||||
# Docker container names must be unique on each host.
|
||||
# The random suffix allows multiple Ray containers to run simultaneously on the same machine,
|
||||
@ -74,36 +102,17 @@ cleanup() {
|
||||
trap cleanup EXIT
|
||||
|
||||
# Build the Ray start command based on the node role.
|
||||
# The head node manages the cluster and accepts connections on port 6379,
|
||||
# The head node manages the cluster and accepts connections on port 6379,
|
||||
# while workers connect to the head's address.
|
||||
RAY_START_CMD="ray start --block"
|
||||
if [ "${NODE_TYPE}" == "--head" ]; then
|
||||
RAY_START_CMD+=" --head --port=6379"
|
||||
RAY_START_CMD+=" --head --node-ip-address=${HEAD_NODE_ADDRESS} --port=6379"
|
||||
else
|
||||
|
||||
RAY_START_CMD+=" --address=${HEAD_NODE_ADDRESS}:6379"
|
||||
fi
|
||||
|
||||
# Parse VLLM_HOST_IP from additional args if present.
|
||||
# This is needed for multi-NIC configurations where Ray needs explicit IP bindings.
|
||||
VLLM_HOST_IP=""
|
||||
for arg in "${ADDITIONAL_ARGS[@]}"; do
|
||||
if [[ $arg == "-e" ]]; then
|
||||
continue
|
||||
if [ -n "${VLLM_HOST_IP}" ]; then
|
||||
RAY_START_CMD+=" --node-ip-address=${VLLM_HOST_IP}"
|
||||
fi
|
||||
if [[ $arg == VLLM_HOST_IP=* ]]; then
|
||||
VLLM_HOST_IP="${arg#VLLM_HOST_IP=}"
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
# Build Ray IP environment variables if VLLM_HOST_IP is set.
|
||||
# These variables ensure Ray binds to the correct network interface on multi-NIC systems.
|
||||
RAY_IP_VARS=()
|
||||
if [ -n "${VLLM_HOST_IP}" ]; then
|
||||
RAY_IP_VARS=(
|
||||
-e "RAY_NODE_IP_ADDRESS=${VLLM_HOST_IP}"
|
||||
-e "RAY_OVERRIDE_NODE_IP_ADDRESS=${VLLM_HOST_IP}"
|
||||
)
|
||||
fi
|
||||
|
||||
# Launch the container with the assembled parameters.
|
||||
@ -118,6 +127,5 @@ docker run \
|
||||
--shm-size 10.24g \
|
||||
--gpus all \
|
||||
-v "${PATH_TO_HF_HOME}:/root/.cache/huggingface" \
|
||||
"${RAY_IP_VARS[@]}" \
|
||||
"${ADDITIONAL_ARGS[@]}" \
|
||||
"${DOCKER_IMAGE}" -c "${RAY_START_CMD}"
|
||||
|
||||
@ -112,7 +112,7 @@ PARAMS: dict[ConstraintsFormat, dict[str, Any]] = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Generate an SQL query to show the 'username' and 'email'from the 'users' table.",
|
||||
"content": "Generate an SQL query to show the 'username' and 'email' from the 'users' table.",
|
||||
}
|
||||
],
|
||||
"extra_body": {
|
||||
|
||||
@ -50,4 +50,5 @@ ijson # Required for mistral streaming tool parser
|
||||
setproctitle # Used to set process names for better debugging and monitoring
|
||||
openai-harmony >= 0.0.3 # Required for gpt-oss
|
||||
anthropic == 0.71.0
|
||||
model-hosting-container-standards >= 0.1.9, < 1.0.0
|
||||
model-hosting-container-standards >= 0.1.9, < 1.0.0
|
||||
mcp
|
||||
@ -1,2 +1,2 @@
|
||||
lmcache >= 0.3.10.post1
|
||||
lmcache
|
||||
nixl >= 0.7.1 # Required for disaggregated prefill
|
||||
|
||||
@ -20,13 +20,14 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from ...utils import flat_product, multi_gpu_test
|
||||
|
||||
is_blackwell = lambda: current_platform.is_device_capability(100)
|
||||
is_blackwell = lambda: current_platform.is_device_capability_family(100)
|
||||
"""Are we running on Blackwell, a lot of tests depend on it"""
|
||||
|
||||
|
||||
class Matches(NamedTuple):
|
||||
attention_fusion: int = 0
|
||||
allreduce_fusion: int = 0
|
||||
rms_quant_norm_fusion: int = 0
|
||||
sequence_parallel: int = 0
|
||||
async_tp: int = 0
|
||||
|
||||
@ -40,6 +41,7 @@ class ModelBackendTestCase(NamedTuple):
|
||||
|
||||
MODELS_FP8: list[ModelBackendTestCase] = []
|
||||
MODELS_FP4: list[ModelBackendTestCase] = []
|
||||
MODELS_GROUP_FP8: list[ModelBackendTestCase] = []
|
||||
MODELS: list[ModelBackendTestCase] = [] # tp-only
|
||||
|
||||
if current_platform.is_cuda():
|
||||
@ -138,6 +140,17 @@ elif current_platform.is_rocm():
|
||||
CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
|
||||
|
||||
|
||||
def has_cuda_graph_wrapper_metadata() -> bool:
|
||||
from importlib import import_module
|
||||
|
||||
try:
|
||||
module = import_module("torch._inductor.utils")
|
||||
module.CUDAGraphWrapperMetadata # noqa B018
|
||||
except AttributeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, model_kwargs, backend, matches, custom_ops",
|
||||
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
|
||||
@ -145,7 +158,20 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
|
||||
# quant_fp4 only has the custom impl
|
||||
+ list(flat_product(MODELS_FP4, [""])),
|
||||
)
|
||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"inductor_graph_partition",
|
||||
[
|
||||
pytest.param(
|
||||
True,
|
||||
marks=pytest.mark.skipif(
|
||||
not has_cuda_graph_wrapper_metadata(),
|
||||
reason="This test requires"
|
||||
"torch._inductor.utils.CUDAGraphWrapperMetadata to run",
|
||||
),
|
||||
),
|
||||
False,
|
||||
],
|
||||
)
|
||||
def test_attn_quant(
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
@ -474,3 +500,79 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
|
||||
compilation_config.compile_ranges_split_points = (
|
||||
llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
|
||||
)
|
||||
|
||||
|
||||
if current_platform.is_cuda():
|
||||
MODELS_GROUP_FP8 = [
|
||||
ModelBackendTestCase(
|
||||
model_name="Qwen/Qwen3-30B-A3B-FP8",
|
||||
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||
matches=Matches(
|
||||
rms_quant_norm_fusion=48,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, model_kwargs, backend, matches, custom_ops",
|
||||
# Test rms norm+group quant_fp8 fusion
|
||||
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
|
||||
)
|
||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||
def test_rms_group_quant(
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
backend: AttentionBackendEnum,
|
||||
matches: Matches,
|
||||
custom_ops: str,
|
||||
inductor_graph_partition: bool,
|
||||
caplog_mp_spawn,
|
||||
monkeypatch,
|
||||
):
|
||||
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("Inductor graph partition requires torch>=2.9")
|
||||
|
||||
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||
|
||||
if inductor_graph_partition:
|
||||
mode = CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
splitting_ops: list[str] | None = None
|
||||
else:
|
||||
mode = CUDAGraphMode.FULL_DECODE_ONLY
|
||||
splitting_ops = []
|
||||
|
||||
# Disable, compile cache to make sure custom passes run.
|
||||
# Otherwise, we can't verify fusion happened through the logs.
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
# To capture subprocess logs, we need to know whether spawn or fork is used.
|
||||
# Force spawn as it is more general.
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
compilation_config = CompilationConfig(
|
||||
# Testing properties
|
||||
custom_ops=custom_ops_list,
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
cudagraph_mode=mode,
|
||||
splitting_ops=splitting_ops,
|
||||
# Common
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(eliminate_noops=True, enable_fusion=True),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
)
|
||||
|
||||
with caplog_mp_spawn(logging.DEBUG) as log_holder:
|
||||
run_model(compilation_config, model_name, **model_kwargs)
|
||||
|
||||
log_matches = re.findall(
|
||||
r"\[fusion.py:\d+] Replaced (\d+) patterns",
|
||||
log_holder.text,
|
||||
)
|
||||
assert len(log_matches) == 1, log_holder.text
|
||||
assert int(log_matches[0]) == matches.rms_quant_norm_fusion
|
||||
|
||||
@ -36,7 +36,7 @@ def get_test_models():
|
||||
DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("use_aot_compile", ["0"])
|
||||
@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
|
||||
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
|
||||
@pytest.mark.parametrize("evaluate_guards", [False, True])
|
||||
@pytest.mark.skipif(
|
||||
@ -54,6 +54,12 @@ def test_dynamic_shapes_compilation(
|
||||
if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
|
||||
pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")
|
||||
|
||||
if evaluate_guards and shapes_type == DynamicShapesType.UNBACKED:
|
||||
pytest.skip("unbacked dynamic shapes do not add guards")
|
||||
|
||||
if evaluate_guards and use_aot_compile:
|
||||
pytest.skip("evaluate_guards requires use_aot_compile=0")
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
|
||||
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
|
||||
|
||||
@ -120,7 +126,7 @@ def test_model_specialization_with_evaluate_guards(
|
||||
and dynamic_shapes_type == DynamicShapesType.BACKED
|
||||
and evaluate_guards
|
||||
):
|
||||
pytest.skip("evaluate_guards for backed does not work with aot_compile =1")
|
||||
pytest.skip("evaluate_guards for backed does not work with aot_compile=1")
|
||||
|
||||
@support_torch_compile
|
||||
class ModelWithSizeCheck(torch.nn.Module):
|
||||
|
||||
@ -128,14 +128,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
|
||||
|
||||
class TestRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
|
||||
def __init__(self, head_dim=64, max_position=2048, base=10000):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.rotary_dim = rotary_dim or head_dim
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.rotary_dim,
|
||||
max_position=max_position,
|
||||
rope_parameters={"rope_type": "default", "rope_theta": base},
|
||||
)
|
||||
@ -170,7 +168,6 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position,
|
||||
rope_parameters={"rope_type": "default", "rope_theta": base},
|
||||
)
|
||||
|
||||
@ -202,6 +202,27 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workspace_init():
|
||||
"""Initialize the workspace manager for tests that need it.
|
||||
|
||||
This fixture initializes the workspace manager with a CUDA device
|
||||
if available, and resets it after the test completes. Tests that
|
||||
create a full vLLM engine should NOT use this fixture as the engine
|
||||
will initialize the workspace manager itself.
|
||||
"""
|
||||
from vllm.v1.worker.workspace import (
|
||||
init_workspace_manager,
|
||||
reset_workspace_manager,
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda:0")
|
||||
init_workspace_manager(device)
|
||||
yield
|
||||
reset_workspace_manager()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def dynamo_reset():
|
||||
yield
|
||||
@ -681,10 +702,16 @@ class HfRunner:
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Encoder-decoder models return decoder_hidden_states instead of
|
||||
# hidden_states
|
||||
hidden_states = (
|
||||
getattr(output, "hidden_states", None) or output.decoder_hidden_states
|
||||
)
|
||||
|
||||
(
|
||||
seq_logprobs_lst,
|
||||
output_len,
|
||||
) = self._hidden_states_to_logprobs(output.hidden_states, num_logprobs)
|
||||
) = self._hidden_states_to_logprobs(hidden_states, num_logprobs)
|
||||
|
||||
all_logprobs.append(seq_logprobs_lst)
|
||||
seq_ids = output.sequences[0]
|
||||
|
||||
276
tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
Normal file
276
tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
Normal file
@ -0,0 +1,276 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Test that the interaction between EPLB and FusedMoE Layer is okay for DP w/ NVFP4
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
get_dp_group,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptNvFp4Config,
|
||||
ModelOptNvFp4FusedMoE,
|
||||
)
|
||||
|
||||
from .eplb_utils import distributed_run, set_env_vars_and_device
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestConfig:
|
||||
num_layers: int
|
||||
num_experts: int
|
||||
num_local_experts: int
|
||||
num_topk: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
num_tokens: int
|
||||
|
||||
|
||||
def make_fused_moe_layer(
|
||||
rank: int,
|
||||
layer_idx: int,
|
||||
test_config: TestConfig,
|
||||
) -> FusedMoE:
|
||||
quant_config = None
|
||||
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
|
||||
quant_config = ModelOptNvFp4Config(
|
||||
is_checkpoint_nvfp4_serialized=True,
|
||||
kv_cache_quant_algo=None,
|
||||
exclude_modules=[],
|
||||
)
|
||||
|
||||
fml = FusedMoE(
|
||||
num_experts=test_config.num_experts,
|
||||
top_k=test_config.num_topk,
|
||||
hidden_size=test_config.hidden_size,
|
||||
intermediate_size=test_config.intermediate_size,
|
||||
prefix=f"dummy_layer_{layer_idx}",
|
||||
activation="silu",
|
||||
is_act_and_mul=True,
|
||||
params_dtype=torch.bfloat16,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml)
|
||||
nvfp4_fused_moe.create_weights(
|
||||
fml,
|
||||
test_config.num_local_experts,
|
||||
test_config.hidden_size,
|
||||
test_config.intermediate_size,
|
||||
params_dtype=torch.uint8,
|
||||
global_num_experts=test_config.num_experts,
|
||||
)
|
||||
|
||||
fml = fml.to(device)
|
||||
w1_q, w2_q, quant_config = make_test_quant_config(
|
||||
test_config.num_local_experts,
|
||||
test_config.intermediate_size,
|
||||
test_config.hidden_size,
|
||||
in_dtype=torch.bfloat16,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
fml.w13_weight.data = w1_q
|
||||
fml.w2_weight.data = w2_q
|
||||
|
||||
fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5
|
||||
fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5
|
||||
fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5
|
||||
fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5
|
||||
fml.w2_weight_scale.data = (
|
||||
torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5
|
||||
).to(fml.w2_weight_scale.data.dtype)
|
||||
fml.w13_weight_scale.data = (
|
||||
torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5
|
||||
).to(fml.w13_weight_scale.data.dtype)
|
||||
|
||||
nvfp4_fused_moe.process_weights_after_loading(fml)
|
||||
|
||||
fml.maybe_init_modular_kernel()
|
||||
|
||||
return fml
|
||||
|
||||
|
||||
def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
|
||||
set_env_vars_and_device(env)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.parallel_config.data_parallel_size = world_size
|
||||
vllm_config.parallel_config.enable_expert_parallel = True
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size=1, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group = get_dp_group().cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
fml_layers = [
|
||||
make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device)
|
||||
for layer_idx in range(test_config.num_layers)
|
||||
]
|
||||
rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]
|
||||
|
||||
hidden_states = []
|
||||
router_logits = []
|
||||
for layer_idx in range(test_config.num_layers):
|
||||
hidden_states.append(
|
||||
torch.randn(
|
||||
(test_config.num_tokens, test_config.hidden_size),
|
||||
dtype=torch.bfloat16,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
router_logits.append(
|
||||
torch.randn(
|
||||
(test_config.num_tokens, test_config.num_experts),
|
||||
dtype=torch.bfloat16,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
out_before_shuffle = []
|
||||
with set_forward_context(
|
||||
{},
|
||||
num_tokens=test_config.num_tokens,
|
||||
num_tokens_across_dp=torch.tensor(
|
||||
[test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
|
||||
),
|
||||
vllm_config=vllm_config,
|
||||
):
|
||||
for lidx, fml in enumerate(fml_layers):
|
||||
out_before_shuffle.append(
|
||||
fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
|
||||
)
|
||||
|
||||
indices = torch.zeros(
|
||||
test_config.num_layers, test_config.num_experts, dtype=torch.long
|
||||
)
|
||||
for lidx in range(test_config.num_layers):
|
||||
indices[lidx] = torch.Tensor(range(test_config.num_experts))
|
||||
|
||||
shuffled_indices = torch.zeros_like(indices)
|
||||
for lidx in range(test_config.num_layers):
|
||||
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
|
||||
|
||||
rearrange_expert_weights_inplace(
|
||||
indices,
|
||||
shuffled_indices,
|
||||
rank_expert_weights,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
)
|
||||
|
||||
num_global_experts = test_config.num_experts
|
||||
|
||||
logical_to_physical_map_list = []
|
||||
for lidx, fml in enumerate(fml_layers):
|
||||
physical_to_logical_map = shuffled_indices[lidx].to(device)
|
||||
logical_to_physical_map = torch.empty(
|
||||
(num_global_experts,), dtype=torch.int32, device=device
|
||||
)
|
||||
logical_to_physical_map[physical_to_logical_map] = torch.arange(
|
||||
0, num_global_experts, dtype=torch.int32, device=device
|
||||
)
|
||||
logical_to_physical_map_list.append(
|
||||
logical_to_physical_map.reshape(num_global_experts, 1)
|
||||
)
|
||||
|
||||
logical_to_physical_map = torch.stack(logical_to_physical_map_list)
|
||||
|
||||
for lidx, fml in enumerate(fml_layers):
|
||||
logical_replica_count = torch.ones(
|
||||
(test_config.num_layers, num_global_experts),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
fml.enable_eplb = True
|
||||
fml.set_eplb_state(
|
||||
lidx,
|
||||
torch.zeros(
|
||||
(test_config.num_layers, num_global_experts),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
),
|
||||
logical_to_physical_map,
|
||||
logical_replica_count,
|
||||
)
|
||||
|
||||
out_after_shuffle = []
|
||||
with set_forward_context(
|
||||
{},
|
||||
num_tokens=test_config.num_tokens,
|
||||
num_tokens_across_dp=torch.tensor(
|
||||
[test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
|
||||
),
|
||||
vllm_config=vllm_config,
|
||||
):
|
||||
for lidx, fml in enumerate(fml_layers):
|
||||
out_after_shuffle.append(
|
||||
fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
|
||||
)
|
||||
|
||||
for lidx in range(test_config.num_layers):
|
||||
torch.testing.assert_close(
|
||||
out_before_shuffle[lidx], out_after_shuffle[lidx], atol=1e-1, rtol=1e-1
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
@pytest.mark.parametrize("num_layers", [8])
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("hidden_size", [256])
|
||||
@pytest.mark.parametrize("intermediate_size", [256])
|
||||
@pytest.mark.parametrize("num_tokens", [256])
|
||||
@pytest.mark.parametrize("backend", ["latency", "throughput"])
|
||||
def test_eplb_fml(
|
||||
world_size: int,
|
||||
num_layers: int,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_tokens: int,
|
||||
backend: str,
|
||||
monkeypatch,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
|
||||
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend)
|
||||
|
||||
if torch.cuda.device_count() < world_size:
|
||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||
|
||||
num_local_experts = num_experts // world_size
|
||||
num_topk = 4
|
||||
|
||||
test_config = TestConfig(
|
||||
num_layers=num_layers,
|
||||
num_experts=num_experts,
|
||||
num_local_experts=num_local_experts,
|
||||
num_topk=num_topk,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_tokens=num_tokens,
|
||||
)
|
||||
|
||||
distributed_run(
|
||||
_test_eplb_fml,
|
||||
world_size,
|
||||
test_config,
|
||||
)
|
||||
@ -1,21 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from openai.types.responses import ResponseFunctionToolCall, ResponseReasoningItem
|
||||
from openai.types.responses.response_output_item import McpCall
|
||||
from openai_harmony import Author, Message, Role, TextContent
|
||||
|
||||
from tests.entrypoints.openai.utils import verify_harmony_messages
|
||||
from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
auto_drop_analysis_messages,
|
||||
get_encoding,
|
||||
has_custom_tools,
|
||||
parse_chat_input_to_harmony_message,
|
||||
parse_chat_output,
|
||||
parse_input_to_harmony_message,
|
||||
parse_output_message,
|
||||
)
|
||||
|
||||
|
||||
class TestParseInputToHarmonyMessage:
|
||||
"""Tests for parse_input_to_harmony_message function."""
|
||||
class TestCommonParseInputToHarmonyMessage:
|
||||
"""
|
||||
Tests for scenarios that are common to both Chat Completion
|
||||
parse_chat_input_to_harmony_message and Responsees API
|
||||
parse_input_to_harmony_message functions.
|
||||
"""
|
||||
|
||||
def test_assistant_message_with_tool_calls(self):
|
||||
@pytest.fixture(
|
||||
params=[parse_chat_input_to_harmony_message, parse_input_to_harmony_message]
|
||||
)
|
||||
def parse_function(self, request):
|
||||
return request.param
|
||||
|
||||
def test_assistant_message_with_tool_calls(self, parse_function):
|
||||
"""Test parsing assistant message with tool calls."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
@ -35,7 +51,7 @@ class TestParseInputToHarmonyMessage:
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 2
|
||||
|
||||
@ -53,7 +69,7 @@ class TestParseInputToHarmonyMessage:
|
||||
assert messages[1].recipient == "functions.search_web"
|
||||
assert messages[1].content_type == "json"
|
||||
|
||||
def test_assistant_message_with_empty_tool_call_arguments(self):
|
||||
def test_assistant_message_with_empty_tool_call_arguments(self, parse_function):
|
||||
"""Test parsing assistant message with tool call having None arguments."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
@ -67,12 +83,152 @@ class TestParseInputToHarmonyMessage:
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].content[0].text == ""
|
||||
assert messages[0].recipient == "functions.get_current_time"
|
||||
|
||||
def test_system_message(self, parse_function):
|
||||
"""Test parsing system message."""
|
||||
chat_msg = {
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant",
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
# System messages are converted using Message.from_dict
|
||||
# which should preserve the role
|
||||
assert messages[0].author.role == Role.SYSTEM
|
||||
|
||||
def test_developer_message(self, parse_function):
|
||||
"""Test parsing developer message."""
|
||||
chat_msg = {
|
||||
"role": "developer",
|
||||
"content": "Use concise language",
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.DEVELOPER
|
||||
|
||||
def test_user_message_with_string_content(self, parse_function):
|
||||
"""Test parsing user message with string content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": "What's the weather in San Francisco?",
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert messages[0].content[0].text == "What's the weather in San Francisco?"
|
||||
|
||||
def test_user_message_with_array_content(self, parse_function):
|
||||
"""Test parsing user message with array content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "What's in this image? "},
|
||||
{"text": "Please describe it."},
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert len(messages[0].content) == 2
|
||||
assert messages[0].content[0].text == "What's in this image? "
|
||||
assert messages[0].content[1].text == "Please describe it."
|
||||
|
||||
def test_assistant_message_with_string_content(self, parse_function):
|
||||
"""Test parsing assistant message with string content (no tool calls)."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you today?",
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.ASSISTANT
|
||||
assert messages[0].content[0].text == "Hello! How can I help you today?"
|
||||
|
||||
def test_pydantic_model_input(self, parse_function):
|
||||
"""Test parsing Pydantic model input (has model_dump method)."""
|
||||
|
||||
class MockPydanticModel:
|
||||
def model_dump(self, exclude_none=True):
|
||||
return {
|
||||
"role": "user",
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
chat_msg = MockPydanticModel()
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert messages[0].content[0].text == "Test message"
|
||||
|
||||
def test_tool_call_with_missing_function_fields(self, parse_function):
|
||||
"""Test parsing tool call with missing name or arguments."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {} # Missing both name and arguments
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].recipient == "functions."
|
||||
assert messages[0].content[0].text == ""
|
||||
|
||||
def test_array_content_with_missing_text(self, parse_function):
|
||||
"""Test parsing array content where text field is missing."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{}, # Missing text field
|
||||
{"text": "actual text"},
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_function(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert len(messages[0].content) == 2
|
||||
assert messages[0].content[0].text == ""
|
||||
assert messages[0].content[1].text == "actual text"
|
||||
|
||||
|
||||
class TestParseInputToHarmonyMessage:
|
||||
"""
|
||||
Tests for scenarios that are specific to the Responses API
|
||||
parse_input_to_harmony_message function.
|
||||
"""
|
||||
|
||||
def test_message_with_empty_content(self):
|
||||
"""Test parsing message with empty string content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].content[0].text == ""
|
||||
|
||||
def test_tool_message_with_string_content(self):
|
||||
"""Test parsing tool message with string content."""
|
||||
chat_msg = {
|
||||
@ -111,6 +267,7 @@ class TestParseInputToHarmonyMessage:
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.TOOL
|
||||
assert messages[0].author.name == "functions.search_results"
|
||||
assert messages[0].content[0].text == "Result 1: Result 2: Result 3"
|
||||
|
||||
def test_tool_message_with_empty_content(self):
|
||||
@ -124,140 +281,564 @@ class TestParseInputToHarmonyMessage:
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.TOOL
|
||||
assert messages[0].author.name == "functions.empty_tool"
|
||||
assert messages[0].content[0].text == ""
|
||||
|
||||
def test_system_message(self):
|
||||
"""Test parsing system message."""
|
||||
chat_msg = {
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
class TestParseChatInputToHarmonyMessage:
|
||||
"""
|
||||
Tests for scenarios that are specific to the Chat Completion API
|
||||
parse_chat_input_to_harmony_message function.
|
||||
"""
|
||||
|
||||
assert len(messages) == 1
|
||||
# System messages are converted using Message.from_dict
|
||||
# which should preserve the role
|
||||
assert messages[0].author.role == Role.SYSTEM
|
||||
|
||||
def test_developer_message(self):
|
||||
"""Test parsing developer message."""
|
||||
chat_msg = {
|
||||
"role": "developer",
|
||||
"content": "Use concise language",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.DEVELOPER
|
||||
|
||||
def test_user_message_with_string_content(self):
|
||||
"""Test parsing user message with string content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": "What's the weather in San Francisco?",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert messages[0].content[0].text == "What's the weather in San Francisco?"
|
||||
|
||||
def test_user_message_with_array_content(self):
|
||||
"""Test parsing user message with array content."""
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "What's in this image? "},
|
||||
{"text": "Please describe it."},
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert len(messages[0].content) == 2
|
||||
assert messages[0].content[0].text == "What's in this image? "
|
||||
assert messages[0].content[1].text == "Please describe it."
|
||||
|
||||
def test_assistant_message_with_string_content(self):
|
||||
"""Test parsing assistant message with string content (no tool calls)."""
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you today?",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.ASSISTANT
|
||||
assert messages[0].content[0].text == "Hello! How can I help you today?"
|
||||
|
||||
def test_pydantic_model_input(self):
|
||||
"""Test parsing Pydantic model input (has model_dump method)."""
|
||||
|
||||
class MockPydanticModel:
|
||||
def model_dump(self, exclude_none=True):
|
||||
return {
|
||||
"role": "user",
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
chat_msg = MockPydanticModel()
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].author.role == Role.USER
|
||||
assert messages[0].content[0].text == "Test message"
|
||||
|
||||
def test_message_with_empty_content(self):
|
||||
"""Test parsing message with empty string content."""
|
||||
def test_user_message_with_empty_content(self):
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].content[0].text == ""
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_tool_call_with_missing_function_fields(self):
|
||||
"""Test parsing tool call with missing name or arguments."""
|
||||
def test_user_message_with_none_content(self):
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"content": None,
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_empty_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 0
|
||||
|
||||
def test_assistant_message_with_none_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 0
|
||||
|
||||
def test_assistant_message_with_content_but_empty_reasoning(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"content": "The answer is 4.",
|
||||
"reasoning": "",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "final",
|
||||
"content": "The answer is 4.",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_reasoning_but_empty_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"reasoning": "I'm thinking about the user's question.",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": "I'm thinking about the user's question.",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_reasoning_but_none_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"reasoning": "I'm thinking about the user's question.",
|
||||
"content": None,
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": "I'm thinking about the user's question.",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_tool_calls_but_no_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {} # Missing both name and arguments
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}',
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0].recipient == "functions."
|
||||
assert messages[0].content[0].text == ""
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": '{"location": "San Francisco"}',
|
||||
"content_type": "json",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_array_content_with_missing_text(self):
|
||||
"""Test parsing array content where text field is missing."""
|
||||
def test_assistant_message_with_tool_calls_and_content(self):
|
||||
chat_msg = {
|
||||
"role": "user",
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}',
|
||||
}
|
||||
}
|
||||
],
|
||||
"content": "I'll call the tool.",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"content": "I'll call the tool.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": '{"location": "San Francisco"}',
|
||||
"content_type": "json",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_tool_calls_and_reasoning(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}',
|
||||
}
|
||||
}
|
||||
],
|
||||
"reasoning": "I should use the get_weather tool.",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": "I should use the get_weather tool.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": '{"location": "San Francisco"}',
|
||||
"content_type": "json",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_assistant_message_with_tool_calls_and_reasoning_and_content(self):
|
||||
chat_msg = {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}',
|
||||
}
|
||||
}
|
||||
],
|
||||
"reasoning": "I should use the get_weather tool.",
|
||||
"content": "I'll call the tool.",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(chat_msg)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"content": "I'll call the tool.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": "I should use the get_weather tool.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": '{"location": "San Francisco"}',
|
||||
"content_type": "json",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_tool_message_with_string_content(self):
|
||||
tool_id_names = {
|
||||
"call_123": "get_weather",
|
||||
}
|
||||
chat_msg = {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"content": "The weather in San Francisco is sunny, 72°F",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(
|
||||
chat_msg, tool_id_names=tool_id_names
|
||||
)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "functions.get_weather",
|
||||
"content": "The weather in San Francisco is sunny, 72°F",
|
||||
"channel": "commentary",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_tool_message_with_array_content(self):
|
||||
tool_id_names = {
|
||||
"call_123": "search_results",
|
||||
}
|
||||
chat_msg = {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"content": [
|
||||
{}, # Missing text field
|
||||
{"text": "actual text"},
|
||||
{"type": "text", "text": "Result 1: "},
|
||||
{"type": "text", "text": "Result 2: "},
|
||||
{
|
||||
"type": "image",
|
||||
"url": "http://example.com/img.png",
|
||||
}, # Should be ignored
|
||||
{"type": "text", "text": "Result 3"},
|
||||
],
|
||||
}
|
||||
|
||||
messages = parse_input_to_harmony_message(chat_msg)
|
||||
messages = parse_chat_input_to_harmony_message(
|
||||
chat_msg, tool_id_names=tool_id_names
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert len(messages[0].content) == 2
|
||||
assert messages[0].content[0].text == ""
|
||||
assert messages[0].content[1].text == "actual text"
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "functions.search_results",
|
||||
"content": "Result 1: Result 2: Result 3",
|
||||
"channel": "commentary",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_tool_message_with_empty_content(self):
|
||||
tool_id_names = {
|
||||
"call_123": "empty_tool",
|
||||
}
|
||||
chat_msg = {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"content": "",
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(
|
||||
chat_msg, tool_id_names=tool_id_names
|
||||
)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "functions.empty_tool",
|
||||
"content": "",
|
||||
"channel": "commentary",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
def test_tool_message_with_none_content(self):
|
||||
tool_id_names = {
|
||||
"call_123": "empty_tool",
|
||||
}
|
||||
chat_msg = {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"content": None,
|
||||
}
|
||||
|
||||
messages = parse_chat_input_to_harmony_message(
|
||||
chat_msg, tool_id_names=tool_id_names
|
||||
)
|
||||
|
||||
verify_harmony_messages(
|
||||
messages,
|
||||
[
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "functions.empty_tool",
|
||||
"content": "",
|
||||
"channel": "commentary",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestAutoDropAnalysisMessages:
|
||||
def test_no_analysis_messages(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "The answer is 4."
|
||||
).with_channel("final"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
assert cleaned_messages == messages
|
||||
|
||||
def test_only_analysis_message(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking about the user's question."
|
||||
).with_channel("analysis"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
assert cleaned_messages == messages
|
||||
|
||||
def test_multiple_analysis_messages_without_final_message(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking about the user's question."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking more."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking even more."
|
||||
).with_channel("analysis"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
assert cleaned_messages == messages
|
||||
|
||||
def test_only_final_message(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "The answer is 4."
|
||||
).with_channel("final"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
assert cleaned_messages == messages
|
||||
|
||||
def test_drops_one_analysis_messages_before_final_message(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking about the user's question."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "The answer is 4."
|
||||
).with_channel("final"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I should think harder."
|
||||
).with_channel("analysis"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
# Should have dropped the first analysis message
|
||||
assert cleaned_messages == messages[1:]
|
||||
|
||||
def test_drops_all_analysis_messages_before_final_message(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking about the user's question."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking more."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking even more."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "The answer is 4."
|
||||
).with_channel("final"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I should think harder."
|
||||
).with_channel("analysis"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
# Should have dropped the first 3 analysis messages
|
||||
assert cleaned_messages == messages[3:]
|
||||
|
||||
def test_multiple_analysis_messages_with_multiple_final_messages(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking about the user's question."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking more."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I'm thinking even more."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "The answer is 4."
|
||||
).with_channel("final"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "I should think harder."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "The answer is 5."
|
||||
).with_channel("final"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
# Should have dropped all those analysis messages
|
||||
assert len(cleaned_messages) == 2
|
||||
assert cleaned_messages[0].content[0].text == "The answer is 4."
|
||||
assert cleaned_messages[1].content[0].text == "The answer is 5."
|
||||
|
||||
def test_drops_non_assistant_analysis_messages(self) -> None:
|
||||
messages = [
|
||||
Message.from_role_and_content(
|
||||
Role.TOOL, "The tool thinks we should think harder."
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "The answer is 4."
|
||||
).with_channel("final"),
|
||||
]
|
||||
cleaned_messages = auto_drop_analysis_messages(messages)
|
||||
# Should have dropped the analysis message
|
||||
assert cleaned_messages == messages[1:]
|
||||
|
||||
|
||||
class TestParseChatOutput:
|
||||
def test_parse_chat_output_interrupted_first_message(self) -> None:
|
||||
harmony_str = "<|channel|>final<|message|>I'm in the middle of answering"
|
||||
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
reasoning, final_content, _ = parse_chat_output(token_ids)
|
||||
assert reasoning is None
|
||||
assert final_content == "I'm in the middle of answering"
|
||||
|
||||
def test_parse_chat_output_interrupted_reasoning_first_message(self) -> None:
|
||||
harmony_str = "<|channel|>analysis<|message|>I'm in the middle of thinking"
|
||||
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
reasoning, final_content, _ = parse_chat_output(token_ids)
|
||||
assert reasoning == "I'm in the middle of thinking"
|
||||
assert final_content is None
|
||||
|
||||
def test_parse_chat_output_complete_reasoning_interrupted_content(self) -> None:
|
||||
harmony_str = (
|
||||
"<|channel|>analysis<|message|>I'm thinking.<|end|>"
|
||||
"<|start|>assistant<|channel|>final"
|
||||
"<|message|>I'm in the middle of answering"
|
||||
)
|
||||
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
reasoning, final_content, _ = parse_chat_output(token_ids)
|
||||
assert reasoning == "I'm thinking."
|
||||
assert final_content == "I'm in the middle of answering"
|
||||
|
||||
def test_parse_chat_output_complete_content(self) -> None:
|
||||
harmony_str = "<|channel|>final<|message|>The answer is 4.<|end|>"
|
||||
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
reasoning, final_content, _ = parse_chat_output(token_ids)
|
||||
assert reasoning is None
|
||||
assert final_content == "The answer is 4."
|
||||
|
||||
def test_parse_chat_output_complete_commentary(self) -> None:
|
||||
harmony_str = (
|
||||
"<|channel|>commentary<|message|>I need to call some tools.<|end|>"
|
||||
)
|
||||
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
reasoning, final_content, _ = parse_chat_output(token_ids)
|
||||
assert reasoning is None
|
||||
assert final_content == "I need to call some tools."
|
||||
|
||||
def test_parse_chat_output_complete_reasoning(self) -> None:
|
||||
harmony_str = (
|
||||
"<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
|
||||
)
|
||||
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
reasoning, final_content, _ = parse_chat_output(token_ids)
|
||||
assert reasoning == "I've thought hard about this."
|
||||
assert final_content is None
|
||||
|
||||
def test_parse_chat_output_complete_reasoning_and_content(self) -> None:
|
||||
harmony_str = (
|
||||
"<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
|
||||
"<|start|>assistant<|channel|>final<|message|>The answer is 4.<|end|>"
|
||||
)
|
||||
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
reasoning, final_content, _ = parse_chat_output(token_ids)
|
||||
assert reasoning == "I've thought hard about this."
|
||||
assert final_content == "The answer is 4."
|
||||
|
||||
|
||||
class TestParseOutputMessage:
|
||||
|
||||
@ -80,10 +80,9 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
|
||||
return dict(engine_prompt), {}
|
||||
|
||||
async def _fake_preprocess_chat(*args, **kwargs):
|
||||
# return conversation, request_prompts, engine_prompts
|
||||
# return conversation, engine_prompts
|
||||
return (
|
||||
[{"role": "user", "content": "Test"}],
|
||||
[[1, 2, 3]],
|
||||
[{"prompt_token_ids": [1, 2, 3]}],
|
||||
)
|
||||
|
||||
|
||||
@ -79,9 +79,12 @@ async def test_anthropic_streaming(client: anthropic.AsyncAnthropic):
|
||||
|
||||
assert chunk_count > 0
|
||||
assert first_chunk is not None, "message_start chunk was never observed"
|
||||
assert first_chunk.usage is not None, "first chunk should include usage stats"
|
||||
assert first_chunk.usage["output_tokens"] == 0
|
||||
assert first_chunk.usage["input_tokens"] > 5
|
||||
assert first_chunk.message is not None, "first chunk should include message"
|
||||
assert first_chunk.message.usage is not None, (
|
||||
"first chunk should include usage stats"
|
||||
)
|
||||
assert first_chunk.message.usage.output_tokens == 0
|
||||
assert first_chunk.message.usage.input_tokens > 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -11,13 +11,25 @@ import pytest_asyncio
|
||||
from openai import OpenAI
|
||||
|
||||
from vllm.config.multimodal import MultiModalConfig
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
RequestResponseMetadata,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from .utils import (
|
||||
accumulate_streaming_response,
|
||||
verify_chat_response,
|
||||
verify_harmony_messages,
|
||||
)
|
||||
|
||||
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
|
||||
|
||||
@ -728,3 +740,635 @@ async def test_serving_chat_data_parallel_rank_extraction():
|
||||
# Verify that data_parallel_rank defaults to None
|
||||
assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
|
||||
assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None
|
||||
|
||||
|
||||
class TestServingChatWithHarmony:
|
||||
"""
|
||||
These tests ensure Chat Completion requests are being properly converted into
|
||||
Harmony messages and Harmony response messages back into Chat Completion responses.
|
||||
These tests are not exhaustive, but each one was created to cover a specific case
|
||||
that we got wrong but is now fixed.
|
||||
|
||||
Any changes to the tests and their expectations may result in changes to the
|
||||
accuracy of model prompting and responses generated. It is suggested to run
|
||||
an evaluation or benchmarking suite (such as bfcl multi_turn) to understand
|
||||
any impact of changes in how we prompt Harmony models.
|
||||
"""
|
||||
|
||||
@pytest.fixture(params=[False, True], ids=["non_streaming", "streaming"])
|
||||
def stream(self, request) -> bool:
|
||||
"""Parameterize tests to run in both non-streaming and streaming modes."""
|
||||
return request.param
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_engine(self) -> AsyncLLM:
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.input_processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
return mock_engine
|
||||
|
||||
@pytest.fixture()
|
||||
def serving_chat(self, mock_engine) -> OpenAIServingChat:
|
||||
chat = _build_serving_chat(mock_engine)
|
||||
chat.use_harmony = True
|
||||
chat.tool_parser = ToolParserManager.get_tool_parser("openai")
|
||||
return chat
|
||||
|
||||
def mock_request_output_from_req_and_token_ids(
|
||||
self, req: ChatCompletionRequest, token_ids: list[int], finished: bool = False
|
||||
) -> RequestOutput:
|
||||
# Our tests don't use most fields, so just get the token ids correct
|
||||
completion_output = CompletionOutput(
|
||||
index=0,
|
||||
text="",
|
||||
token_ids=token_ids,
|
||||
cumulative_logprob=0.0,
|
||||
logprobs=None,
|
||||
)
|
||||
return RequestOutput(
|
||||
request_id=req.request_id,
|
||||
prompt=[],
|
||||
prompt_token_ids=[],
|
||||
prompt_logprobs=None,
|
||||
outputs=[completion_output],
|
||||
finished=finished,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def weather_tools(self) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string"},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def weather_messages_start(self) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Paris today?",
|
||||
},
|
||||
]
|
||||
|
||||
async def generate_response_from_harmony_str(
|
||||
self,
|
||||
serving_chat: OpenAIServingChat,
|
||||
req: ChatCompletionRequest,
|
||||
harmony_str: str,
|
||||
stream: bool = False,
|
||||
) -> ChatCompletionResponse:
|
||||
harmony_token_ids = get_encoding().encode(harmony_str, allowed_special="all")
|
||||
|
||||
async def result_generator():
|
||||
if stream:
|
||||
for token_id in harmony_token_ids:
|
||||
yield self.mock_request_output_from_req_and_token_ids(
|
||||
req, [token_id]
|
||||
)
|
||||
yield self.mock_request_output_from_req_and_token_ids(
|
||||
req, [], finished=True
|
||||
)
|
||||
else:
|
||||
yield self.mock_request_output_from_req_and_token_ids(
|
||||
req, harmony_token_ids, finished=True
|
||||
)
|
||||
|
||||
generator_func = (
|
||||
serving_chat.chat_completion_stream_generator
|
||||
if stream
|
||||
else serving_chat.chat_completion_full_generator
|
||||
)
|
||||
|
||||
result = generator_func(
|
||||
request=req,
|
||||
result_generator=result_generator(),
|
||||
request_id=req.request_id,
|
||||
model_name=req.model,
|
||||
conversation=[],
|
||||
tokenizer=get_tokenizer(req.model),
|
||||
request_metadata=RequestResponseMetadata(
|
||||
request_id=req.request_id,
|
||||
model_name=req.model,
|
||||
),
|
||||
)
|
||||
|
||||
if stream:
|
||||
return await accumulate_streaming_response(result)
|
||||
return await result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_chat(self, serving_chat, stream):
|
||||
messages = [{"role": "user", "content": "what is 1+1?"}]
|
||||
|
||||
# Test the Harmony messages for the first turn's input
|
||||
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages, _ = serving_chat._make_request_with_harmony(req)
|
||||
verify_harmony_messages(
|
||||
input_messages,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user", "content": messages[0]["content"]},
|
||||
],
|
||||
)
|
||||
|
||||
# Test the Chat Completion response for the first turn's output
|
||||
reasoning_str = "We need to think really hard about this."
|
||||
final_str = "The answer is 2."
|
||||
response_str = (
|
||||
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
|
||||
f"<|start|>assistant<|channel|>final<|message|>{final_str}<|end|>"
|
||||
)
|
||||
response = await self.generate_response_from_harmony_str(
|
||||
serving_chat, req, response_str, stream=stream
|
||||
)
|
||||
verify_chat_response(response, content=final_str, reasoning=reasoning_str)
|
||||
|
||||
# Add the output messages from the first turn as input to the second turn
|
||||
for choice in response.choices:
|
||||
messages.append(choice.message.model_dump(exclude_none=True))
|
||||
|
||||
# Test the Harmony messages for the second turn's input
|
||||
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
|
||||
verify_harmony_messages(
|
||||
input_messages_2,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user"},
|
||||
# The analysis message should be dropped on subsequent inputs because
|
||||
# of the subsequent assistant message to the final channel.
|
||||
{"role": "assistant", "channel": "final", "content": final_str},
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_response_with_content(
|
||||
self, serving_chat, stream, weather_tools, weather_messages_start
|
||||
):
|
||||
tools = weather_tools
|
||||
messages = list(weather_messages_start)
|
||||
|
||||
# Test the Harmony messages for the first turn's input
|
||||
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
|
||||
input_messages, _ = serving_chat._make_request_with_harmony(req)
|
||||
verify_harmony_messages(
|
||||
input_messages,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer", "tool_definitions": ["get_weather"]},
|
||||
{"role": "user", "content": messages[0]["content"]},
|
||||
],
|
||||
)
|
||||
|
||||
# Test the Chat Completion response for the first turn's output
|
||||
commentary_str = "We'll call get_weather."
|
||||
tool_args_str = '{"location": "Paris"}'
|
||||
response_str = (
|
||||
f"<|channel|>commentary<|message|>{commentary_str}<|end|>"
|
||||
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
|
||||
f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
|
||||
)
|
||||
response = await self.generate_response_from_harmony_str(
|
||||
serving_chat, req, response_str, stream=stream
|
||||
)
|
||||
verify_chat_response(
|
||||
response,
|
||||
content=commentary_str,
|
||||
tool_calls=[("get_weather", tool_args_str)],
|
||||
)
|
||||
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
|
||||
# Add the output messages from the first turn as input to the second turn
|
||||
for choice in response.choices:
|
||||
messages.append(choice.message.model_dump(exclude_none=True))
|
||||
|
||||
# Add our tool output message
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": "20 degrees Celsius",
|
||||
},
|
||||
)
|
||||
|
||||
# Test the Harmony messages for the second turn's input
|
||||
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
|
||||
verify_harmony_messages(
|
||||
input_messages_2,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"content": commentary_str,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": tool_args_str,
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"author_name": "functions.get_weather",
|
||||
"channel": "commentary",
|
||||
"recipient": "assistant",
|
||||
"content": "20 degrees Celsius",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tools_and_reasoning(
|
||||
self, serving_chat, stream, weather_tools, weather_messages_start
|
||||
):
|
||||
tools = weather_tools
|
||||
messages = list(weather_messages_start)
|
||||
|
||||
# Test the Harmony messages for the first turn's input
|
||||
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
|
||||
input_messages, _ = serving_chat._make_request_with_harmony(req)
|
||||
verify_harmony_messages(
|
||||
input_messages,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer", "tool_definitions": ["get_weather"]},
|
||||
{"role": "user", "content": messages[0]["content"]},
|
||||
],
|
||||
)
|
||||
|
||||
# Test the Chat Completion response for the first turn's output
|
||||
reasoning_str = "I'll call get_weather."
|
||||
tool_args_str = '{"location": "Paris"}'
|
||||
response_str = (
|
||||
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
|
||||
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
|
||||
f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
|
||||
)
|
||||
response = await self.generate_response_from_harmony_str(
|
||||
serving_chat, req, response_str, stream=stream
|
||||
)
|
||||
verify_chat_response(
|
||||
response,
|
||||
reasoning=reasoning_str,
|
||||
tool_calls=[("get_weather", tool_args_str)],
|
||||
)
|
||||
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
|
||||
# Add the output messages from the first turn as input to the second turn
|
||||
for choice in response.choices:
|
||||
messages.append(choice.message.model_dump(exclude_none=True))
|
||||
|
||||
# Add our tool output message
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": "20 degrees Celsius",
|
||||
},
|
||||
)
|
||||
|
||||
# Test the Harmony messages for the second turn's input
|
||||
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
|
||||
verify_harmony_messages(
|
||||
input_messages_2,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": reasoning_str,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": tool_args_str,
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"author_name": "functions.get_weather",
|
||||
"channel": "commentary",
|
||||
"recipient": "assistant",
|
||||
"content": "20 degrees Celsius",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_turn_tools_and_reasoning(
|
||||
self, serving_chat, stream, weather_tools, weather_messages_start
|
||||
):
|
||||
tools = weather_tools
|
||||
messages = list(weather_messages_start)
|
||||
|
||||
# Test the Harmony messages for the first turn's input
|
||||
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
|
||||
input_messages, _ = serving_chat._make_request_with_harmony(req)
|
||||
verify_harmony_messages(
|
||||
input_messages,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer", "tool_definitions": ["get_weather"]},
|
||||
{"role": "user", "content": messages[0]["content"]},
|
||||
],
|
||||
)
|
||||
|
||||
# Test the Chat Completion response for the first turn's output
|
||||
reasoning_str = "I'll call get_weather."
|
||||
paris_tool_args_str = '{"location": "Paris"}'
|
||||
response_str = (
|
||||
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
|
||||
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
|
||||
f"<|constrain|>json<|message|>{paris_tool_args_str}<|call|>"
|
||||
)
|
||||
response = await self.generate_response_from_harmony_str(
|
||||
serving_chat, req, response_str, stream=stream
|
||||
)
|
||||
verify_chat_response(
|
||||
response,
|
||||
reasoning=reasoning_str,
|
||||
tool_calls=[("get_weather", paris_tool_args_str)],
|
||||
)
|
||||
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
|
||||
# Add the output messages from the first turn as input to the second turn
|
||||
for choice in response.choices:
|
||||
messages.append(choice.message.model_dump(exclude_none=True))
|
||||
|
||||
# Add our tool output message
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": "20 degrees Celsius",
|
||||
},
|
||||
)
|
||||
|
||||
# Test the Harmony messages for the second turn's input
|
||||
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
|
||||
verify_harmony_messages(
|
||||
input_messages_2,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": reasoning_str,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": paris_tool_args_str,
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"author_name": "functions.get_weather",
|
||||
"channel": "commentary",
|
||||
"recipient": "assistant",
|
||||
"content": "20 degrees Celsius",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Test the Chat Completion response for the second turn's output
|
||||
paris_weather_str = "The weather in Paris today is 20 degrees Celsius."
|
||||
response_str = f"<|channel|>final<|message|>{paris_weather_str}<|end|>"
|
||||
response_2 = await self.generate_response_from_harmony_str(
|
||||
serving_chat, req_2, response_str, stream=stream
|
||||
)
|
||||
verify_chat_response(response_2, content=paris_weather_str)
|
||||
|
||||
# Add the output messages from the second turn as input to the third turn
|
||||
for choice in response_2.choices:
|
||||
messages.append(choice.message.model_dump(exclude_none=True))
|
||||
|
||||
# Add a new user message for the third turn
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Boston today?",
|
||||
},
|
||||
)
|
||||
|
||||
# Test the Harmony messages for the third turn's input
|
||||
req_3 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages_3, _ = serving_chat._make_request_with_harmony(req_3)
|
||||
verify_harmony_messages(
|
||||
input_messages_3,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": paris_tool_args_str,
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"author_name": "functions.get_weather",
|
||||
"channel": "commentary",
|
||||
"recipient": "assistant",
|
||||
"content": "20 degrees Celsius",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "final",
|
||||
"content": paris_weather_str,
|
||||
},
|
||||
{"role": "user", "content": messages[-1]["content"]},
|
||||
],
|
||||
)
|
||||
|
||||
# Test the Chat Completion response for the third turn's output
|
||||
reasoning_str = "I'll call get_weather."
|
||||
boston_tool_args_str = '{"location": "Boston"}'
|
||||
response_str = (
|
||||
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
|
||||
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
|
||||
f"<|constrain|>json<|message|>{boston_tool_args_str}<|call|>"
|
||||
)
|
||||
response_3 = await self.generate_response_from_harmony_str(
|
||||
serving_chat, req, response_str, stream=stream
|
||||
)
|
||||
verify_chat_response(
|
||||
response_3,
|
||||
reasoning=reasoning_str,
|
||||
tool_calls=[("get_weather", boston_tool_args_str)],
|
||||
)
|
||||
|
||||
tool_call = response_3.choices[0].message.tool_calls[0]
|
||||
|
||||
# Add the output messages from the third turn as input to the fourth turn
|
||||
for choice in response_3.choices:
|
||||
messages.append(choice.message.model_dump(exclude_none=True))
|
||||
|
||||
# Add our tool output message
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": "10 degrees Celsius",
|
||||
},
|
||||
)
|
||||
|
||||
# Test the Harmony messages for the fourth turn's input
|
||||
req_4 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages_4, _ = serving_chat._make_request_with_harmony(req_4)
|
||||
verify_harmony_messages(
|
||||
input_messages_4,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user"},
|
||||
{"role": "assistant"},
|
||||
{"role": "tool"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "final",
|
||||
},
|
||||
{"role": "user"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": reasoning_str,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "commentary",
|
||||
"recipient": "functions.get_weather",
|
||||
"content": boston_tool_args_str,
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"author_name": "functions.get_weather",
|
||||
"channel": "commentary",
|
||||
"recipient": "assistant",
|
||||
"content": "10 degrees Celsius",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_tool_reasoning(self, serving_chat):
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's 2+2?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
|
||||
"content": "4",
|
||||
},
|
||||
]
|
||||
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages, _ = serving_chat._make_request_with_harmony(req)
|
||||
|
||||
verify_harmony_messages(
|
||||
input_messages,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user", "content": messages[0]["content"]},
|
||||
# The reasoning that would have resulted in an analysis message is
|
||||
# dropped because of a later assistant message to the final channel.
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "final",
|
||||
"content": messages[1]["content"],
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_tool_reasoning_empty_content(self, serving_chat):
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's 2+2?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
|
||||
"content": "",
|
||||
},
|
||||
]
|
||||
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages, _ = serving_chat._make_request_with_harmony(req)
|
||||
|
||||
verify_harmony_messages(
|
||||
input_messages,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user", "content": messages[0]["content"]},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": messages[1]["reasoning"],
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_tool_reasoning_empty_content_list(self, serving_chat):
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's 2+2?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
|
||||
"content": [],
|
||||
},
|
||||
]
|
||||
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
|
||||
input_messages, _ = serving_chat._make_request_with_harmony(req)
|
||||
|
||||
verify_harmony_messages(
|
||||
input_messages,
|
||||
[
|
||||
{"role": "system"},
|
||||
{"role": "developer"},
|
||||
{"role": "user", "content": messages[0]["content"]},
|
||||
{
|
||||
"role": "assistant",
|
||||
"channel": "analysis",
|
||||
"content": messages[1]["reasoning"],
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
@ -10,7 +10,7 @@ import pytest
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
@ -21,7 +21,7 @@ from vllm.entrypoints.openai.serving_responses import (
|
||||
extract_tool_types,
|
||||
)
|
||||
from vllm.entrypoints.tool_server import ToolServer
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
|
||||
|
||||
class MockConversationContext(ConversationContext):
|
||||
@ -237,7 +237,7 @@ class TestValidateGeneratorInput:
|
||||
"""Test _validate_generator_input with valid prompt length"""
|
||||
# Create an engine prompt with valid length (less than max_model_len)
|
||||
valid_prompt_token_ids = list(range(5)) # 5 tokens < 100 max_model_len
|
||||
engine_prompt = EngineTokensPrompt(prompt_token_ids=valid_prompt_token_ids)
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=valid_prompt_token_ids)
|
||||
|
||||
# Call the method
|
||||
result = serving_responses_instance._validate_generator_input(engine_prompt)
|
||||
@ -247,7 +247,7 @@ class TestValidateGeneratorInput:
|
||||
|
||||
# create an invalid engine prompt
|
||||
invalid_prompt_token_ids = list(range(200)) # 100 tokens >= 100 max_model_len
|
||||
engine_prompt = EngineTokensPrompt(prompt_token_ids=invalid_prompt_token_ids)
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=invalid_prompt_token_ids)
|
||||
|
||||
# Call the method
|
||||
result = serving_responses_instance._validate_generator_input(engine_prompt)
|
||||
|
||||
342
tests/entrypoints/openai/test_sparse_tensor_validation.py
Normal file
342
tests/entrypoints/openai/test_sparse_tensor_validation.py
Normal file
@ -0,0 +1,342 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Sparse tensor validation in embedding APIs.
|
||||
|
||||
Tests verify that malicious sparse tensors are rejected before they can trigger
|
||||
out-of-bounds memory writes during to_dense() operations.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.entrypoints.renderer import CompletionRenderer
|
||||
from vllm.multimodal.audio import AudioEmbeddingMediaIO
|
||||
from vllm.multimodal.image import ImageEmbeddingMediaIO
|
||||
|
||||
|
||||
def _encode_tensor(tensor: torch.Tensor) -> bytes:
|
||||
"""Helper to encode a tensor as base64 bytes."""
|
||||
buffer = io.BytesIO()
|
||||
torch.save(tensor, buffer)
|
||||
buffer.seek(0)
|
||||
return base64.b64encode(buffer.read())
|
||||
|
||||
|
||||
def _create_malicious_sparse_tensor() -> torch.Tensor:
|
||||
"""
|
||||
Create a malicious sparse COO tensor with out-of-bounds indices.
|
||||
|
||||
This tensor has indices that point beyond the declared shape, which would
|
||||
cause an out-of-bounds write when converted to dense format without
|
||||
validation.
|
||||
"""
|
||||
# Create a 3x3 sparse tensor but with indices pointing to (10, 10)
|
||||
indices = torch.tensor([[10], [10]]) # Out of bounds for 3x3 shape
|
||||
values = torch.tensor([1.0])
|
||||
shape = (3, 3)
|
||||
|
||||
# Create sparse tensor (this will be invalid)
|
||||
sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
|
||||
return sparse_tensor
|
||||
|
||||
|
||||
def _create_valid_sparse_tensor() -> torch.Tensor:
|
||||
"""Create a valid sparse COO tensor for baseline testing."""
|
||||
indices = torch.tensor([[0, 1, 2], [0, 1, 2]])
|
||||
values = torch.tensor([1.0, 2.0, 3.0])
|
||||
shape = (3, 3)
|
||||
|
||||
sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
|
||||
return sparse_tensor
|
||||
|
||||
|
||||
def _create_valid_dense_tensor() -> torch.Tensor:
|
||||
"""Create a valid dense tensor for baseline testing."""
|
||||
return torch.randn(10, 768, dtype=torch.float32) # (seq_len, hidden_size)
|
||||
|
||||
|
||||
class TestPromptEmbedsValidation:
|
||||
"""Test sparse tensor validation in prompt embeddings (Completions API)."""
|
||||
|
||||
def test_valid_dense_tensor_accepted(self, model_config):
|
||||
"""Baseline: Valid dense tensors should work normally."""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
valid_tensor = _create_valid_dense_tensor()
|
||||
encoded = _encode_tensor(valid_tensor)
|
||||
|
||||
# Should not raise any exception
|
||||
result = renderer.load_prompt_embeds(encoded)
|
||||
assert len(result) == 1
|
||||
assert result[0]["prompt_embeds"].shape == valid_tensor.shape
|
||||
|
||||
def test_valid_sparse_tensor_accepted(self):
|
||||
"""Baseline: Valid sparse tensors should load successfully."""
|
||||
io_handler = ImageEmbeddingMediaIO()
|
||||
|
||||
valid_sparse = _create_valid_sparse_tensor()
|
||||
encoded = _encode_tensor(valid_sparse)
|
||||
|
||||
# Should not raise any exception (sparse tensors remain sparse)
|
||||
result = io_handler.load_base64("", encoded.decode("utf-8"))
|
||||
assert result.shape == valid_sparse.shape
|
||||
|
||||
def test_malicious_sparse_tensor_rejected(self, model_config):
|
||||
"""Security: Malicious sparse tensors should be rejected."""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
malicious_tensor = _create_malicious_sparse_tensor()
|
||||
encoded = _encode_tensor(malicious_tensor)
|
||||
|
||||
# Should raise RuntimeError due to invalid sparse tensor
|
||||
with pytest.raises((RuntimeError, ValueError)) as exc_info:
|
||||
renderer.load_prompt_embeds(encoded)
|
||||
|
||||
# Error should indicate sparse tensor validation failure
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg
|
||||
|
||||
def test_extremely_large_indices_rejected(self, model_config):
|
||||
"""Security: Sparse tensors with extremely large indices should be rejected."""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
# Create tensor with indices far beyond reasonable bounds
|
||||
indices = torch.tensor([[999999], [999999]])
|
||||
values = torch.tensor([1.0])
|
||||
shape = (10, 10)
|
||||
|
||||
malicious_tensor = torch.sparse_coo_tensor(
|
||||
indices, values, shape, dtype=torch.float32
|
||||
)
|
||||
encoded = _encode_tensor(malicious_tensor)
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
renderer.load_prompt_embeds(encoded)
|
||||
|
||||
def test_negative_indices_rejected(self, model_config):
|
||||
"""Security: Sparse tensors with negative indices should be rejected."""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
# Create tensor with negative indices
|
||||
indices = torch.tensor([[-1], [-1]])
|
||||
values = torch.tensor([1.0])
|
||||
shape = (10, 10)
|
||||
|
||||
malicious_tensor = torch.sparse_coo_tensor(
|
||||
indices, values, shape, dtype=torch.float32
|
||||
)
|
||||
encoded = _encode_tensor(malicious_tensor)
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
renderer.load_prompt_embeds(encoded)
|
||||
|
||||
|
||||
class TestImageEmbedsValidation:
|
||||
"""Test sparse tensor validation in image embeddings (Chat API)."""
|
||||
|
||||
def test_valid_dense_tensor_accepted(self):
|
||||
"""Baseline: Valid dense tensors should work normally."""
|
||||
io_handler = ImageEmbeddingMediaIO()
|
||||
|
||||
valid_tensor = _create_valid_dense_tensor()
|
||||
encoded = _encode_tensor(valid_tensor)
|
||||
|
||||
# Should not raise any exception
|
||||
result = io_handler.load_base64("", encoded.decode("utf-8"))
|
||||
assert result.shape == valid_tensor.shape
|
||||
|
||||
def test_valid_sparse_tensor_accepted(self):
|
||||
"""Baseline: Valid sparse tensors should load successfully."""
|
||||
io_handler = AudioEmbeddingMediaIO()
|
||||
|
||||
valid_sparse = _create_valid_sparse_tensor()
|
||||
encoded = _encode_tensor(valid_sparse)
|
||||
|
||||
# Should not raise any exception (sparse tensors remain sparse)
|
||||
result = io_handler.load_base64("", encoded.decode("utf-8"))
|
||||
assert result.shape == valid_sparse.shape
|
||||
|
||||
def test_malicious_sparse_tensor_rejected(self):
|
||||
"""Security: Malicious sparse tensors should be rejected."""
|
||||
io_handler = ImageEmbeddingMediaIO()
|
||||
|
||||
malicious_tensor = _create_malicious_sparse_tensor()
|
||||
encoded = _encode_tensor(malicious_tensor)
|
||||
|
||||
# Should raise RuntimeError due to invalid sparse tensor
|
||||
with pytest.raises((RuntimeError, ValueError)) as exc_info:
|
||||
io_handler.load_base64("", encoded.decode("utf-8"))
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg
|
||||
|
||||
def test_load_bytes_validates(self):
|
||||
"""Security: Validation should also work for load_bytes method."""
|
||||
io_handler = ImageEmbeddingMediaIO()
|
||||
|
||||
malicious_tensor = _create_malicious_sparse_tensor()
|
||||
buffer = io.BytesIO()
|
||||
torch.save(malicious_tensor, buffer)
|
||||
buffer.seek(0)
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
io_handler.load_bytes(buffer.read())
|
||||
|
||||
|
||||
class TestAudioEmbedsValidation:
|
||||
"""Test sparse tensor validation in audio embeddings (Chat API)."""
|
||||
|
||||
def test_valid_dense_tensor_accepted(self):
|
||||
"""Baseline: Valid dense tensors should work normally."""
|
||||
io_handler = AudioEmbeddingMediaIO()
|
||||
|
||||
valid_tensor = _create_valid_dense_tensor()
|
||||
encoded = _encode_tensor(valid_tensor)
|
||||
|
||||
# Should not raise any exception
|
||||
result = io_handler.load_base64("", encoded.decode("utf-8"))
|
||||
assert result.shape == valid_tensor.shape
|
||||
|
||||
def test_valid_sparse_tensor_accepted(self):
|
||||
"""Baseline: Valid sparse tensors should be converted successfully."""
|
||||
io_handler = AudioEmbeddingMediaIO()
|
||||
|
||||
valid_sparse = _create_valid_sparse_tensor()
|
||||
encoded = _encode_tensor(valid_sparse)
|
||||
|
||||
# Should not raise any exception
|
||||
result = io_handler.load_base64("", encoded.decode("utf-8"))
|
||||
assert result.is_sparse is False
|
||||
|
||||
def test_malicious_sparse_tensor_rejected(self):
|
||||
"""Security: Malicious sparse tensors should be rejected."""
|
||||
io_handler = AudioEmbeddingMediaIO()
|
||||
|
||||
malicious_tensor = _create_malicious_sparse_tensor()
|
||||
encoded = _encode_tensor(malicious_tensor)
|
||||
|
||||
# Should raise RuntimeError due to invalid sparse tensor
|
||||
with pytest.raises((RuntimeError, ValueError)) as exc_info:
|
||||
io_handler.load_base64("", encoded.decode("utf-8"))
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg
|
||||
|
||||
def test_load_bytes_validates(self):
|
||||
"""Security: Validation should also work for load_bytes method."""
|
||||
io_handler = AudioEmbeddingMediaIO()
|
||||
|
||||
malicious_tensor = _create_malicious_sparse_tensor()
|
||||
buffer = io.BytesIO()
|
||||
torch.save(malicious_tensor, buffer)
|
||||
buffer.seek(0)
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
io_handler.load_bytes(buffer.read())
|
||||
|
||||
|
||||
class TestSparseTensorValidationIntegration:
|
||||
"""
|
||||
These tests verify the complete attack chain is blocked at all entry points.
|
||||
"""
|
||||
|
||||
def test_attack_scenario_completions_api(self, model_config):
|
||||
"""
|
||||
Simulate a complete attack through the Completions API.
|
||||
|
||||
Attack scenario:
|
||||
1. Attacker crafts malicious sparse tensor
|
||||
2. Encodes it as base64
|
||||
3. Sends to /v1/completions with prompt_embeds parameter
|
||||
4. Server should reject before memory corruption occurs
|
||||
"""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
# Step 1-2: Attacker creates malicious payload
|
||||
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
|
||||
|
||||
# Step 3-4: Server processes and should reject
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
renderer.load_prompt_embeds(attack_payload)
|
||||
|
||||
def test_attack_scenario_chat_api_image(self):
|
||||
"""
|
||||
Simulate attack through Chat API with image_embeds.
|
||||
|
||||
Verifies the image embeddings path is protected.
|
||||
"""
|
||||
io_handler = ImageEmbeddingMediaIO()
|
||||
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
io_handler.load_base64("", attack_payload.decode("utf-8"))
|
||||
|
||||
def test_attack_scenario_chat_api_audio(self):
|
||||
"""
|
||||
Simulate attack through Chat API with audio_embeds.
|
||||
|
||||
Verifies the audio embeddings path is protected.
|
||||
"""
|
||||
io_handler = AudioEmbeddingMediaIO()
|
||||
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
io_handler.load_base64("", attack_payload.decode("utf-8"))
|
||||
|
||||
def test_multiple_valid_embeddings_in_batch(self, model_config):
|
||||
"""
|
||||
Regression test: Multiple valid embeddings should still work.
|
||||
|
||||
Ensures the fix doesn't break legitimate batch processing.
|
||||
"""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
valid_tensors = [
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
]
|
||||
|
||||
# Should process all without error
|
||||
result = renderer.load_prompt_embeds(valid_tensors)
|
||||
assert len(result) == 3
|
||||
|
||||
def test_mixed_valid_and_malicious_rejected(self, model_config):
|
||||
"""
|
||||
Security: Batch with one malicious tensor should be rejected.
|
||||
|
||||
Even if most tensors are valid, a single malicious one should
|
||||
cause rejection of the entire batch.
|
||||
"""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
mixed_batch = [
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
_encode_tensor(_create_malicious_sparse_tensor()), # Malicious
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
]
|
||||
|
||||
# Should fail on the malicious tensor
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
renderer.load_prompt_embeds(mixed_batch)
|
||||
|
||||
|
||||
# Pytest fixtures
|
||||
@pytest.fixture
|
||||
def model_config():
|
||||
"""Mock ModelConfig for testing."""
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
return ModelConfig(
|
||||
model="facebook/opt-125m",
|
||||
tokenizer="facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float32",
|
||||
seed=0,
|
||||
enable_prompt_embeds=True, # Required for prompt embeds tests
|
||||
)
|
||||
190
tests/entrypoints/openai/utils.py
Normal file
190
tests/entrypoints/openai/utils.py
Normal file
@ -0,0 +1,190 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatMessage,
|
||||
UsageInfo,
|
||||
)
|
||||
|
||||
|
||||
async def accumulate_streaming_response(
|
||||
stream_generator: AsyncGenerator[str, None],
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Accumulate streaming SSE chunks into a complete ChatCompletionResponse.
|
||||
|
||||
This helper parses the SSE format and builds up the complete response
|
||||
by combining all the delta chunks.
|
||||
"""
|
||||
accumulated_content = ""
|
||||
accumulated_reasoning = None
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
role = None
|
||||
finish_reason = None
|
||||
response_id = None
|
||||
created = None
|
||||
model = None
|
||||
index = 0
|
||||
|
||||
async for chunk_str in stream_generator:
|
||||
# Skip empty lines and [DONE] marker
|
||||
if not chunk_str.strip() or chunk_str.strip() == "data: [DONE]":
|
||||
continue
|
||||
|
||||
# Parse SSE format: "data: {json}\n\n"
|
||||
if chunk_str.startswith("data: "):
|
||||
json_str = chunk_str[6:].strip()
|
||||
try:
|
||||
chunk_data = json.loads(json_str)
|
||||
# print(f"DEBUG: Parsed chunk_data: {chunk_data}")
|
||||
chunk = ChatCompletionStreamResponse(**chunk_data)
|
||||
|
||||
# Store metadata from first chunk
|
||||
if response_id is None:
|
||||
response_id = chunk.id
|
||||
created = chunk.created
|
||||
model = chunk.model
|
||||
|
||||
# Process each choice in the chunk
|
||||
for choice in chunk.choices:
|
||||
if choice.delta.role:
|
||||
role = choice.delta.role
|
||||
if choice.delta.content:
|
||||
accumulated_content += choice.delta.content
|
||||
if choice.delta.reasoning:
|
||||
if accumulated_reasoning is None:
|
||||
accumulated_reasoning = ""
|
||||
accumulated_reasoning += choice.delta.reasoning
|
||||
if choice.delta.tool_calls:
|
||||
# Accumulate tool calls
|
||||
for tool_call_delta in choice.delta.tool_calls:
|
||||
# Find or create the tool call at this index
|
||||
while len(accumulated_tool_calls) <= tool_call_delta.index:
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
)
|
||||
|
||||
if tool_call_delta.id:
|
||||
accumulated_tool_calls[tool_call_delta.index]["id"] = (
|
||||
tool_call_delta.id
|
||||
)
|
||||
if tool_call_delta.function:
|
||||
if tool_call_delta.function.name:
|
||||
accumulated_tool_calls[tool_call_delta.index][
|
||||
"function"
|
||||
]["name"] += tool_call_delta.function.name
|
||||
if tool_call_delta.function.arguments:
|
||||
accumulated_tool_calls[tool_call_delta.index][
|
||||
"function"
|
||||
]["arguments"] += tool_call_delta.function.arguments
|
||||
|
||||
if choice.finish_reason:
|
||||
finish_reason = choice.finish_reason
|
||||
if choice.index is not None:
|
||||
index = choice.index
|
||||
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Build the final message
|
||||
message_kwargs = {
|
||||
"role": role or "assistant",
|
||||
"content": accumulated_content if accumulated_content else None,
|
||||
"reasoning": accumulated_reasoning,
|
||||
}
|
||||
|
||||
# Only include tool_calls if there are any
|
||||
if accumulated_tool_calls:
|
||||
message_kwargs["tool_calls"] = [
|
||||
{"id": tc["id"], "type": tc["type"], "function": tc["function"]}
|
||||
for tc in accumulated_tool_calls
|
||||
]
|
||||
|
||||
message = ChatMessage(**message_kwargs)
|
||||
|
||||
# Build the final response
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=index,
|
||||
message=message,
|
||||
finish_reason=finish_reason or "stop",
|
||||
)
|
||||
|
||||
# Create usage info (with dummy values for tests)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
)
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
id=response_id or "chatcmpl-test",
|
||||
object="chat.completion",
|
||||
created=created or 0,
|
||||
model=model or "test-model",
|
||||
choices=[choice],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def verify_harmony_messages(
|
||||
messages: list[Any], expected_messages: list[dict[str, Any]]
|
||||
):
|
||||
assert len(messages) == len(expected_messages)
|
||||
for msg, expected in zip(messages, expected_messages):
|
||||
if "role" in expected:
|
||||
assert msg.author.role == expected["role"]
|
||||
if "author_name" in expected:
|
||||
assert msg.author.name == expected["author_name"]
|
||||
if "channel" in expected:
|
||||
assert msg.channel == expected["channel"]
|
||||
if "recipient" in expected:
|
||||
assert msg.recipient == expected["recipient"]
|
||||
if "content" in expected:
|
||||
assert msg.content[0].text == expected["content"]
|
||||
if "content_type" in expected:
|
||||
assert msg.content_type == expected["content_type"]
|
||||
if "tool_definitions" in expected:
|
||||
# Check that the tool definitions match the expected list of tool names
|
||||
actual_tools = [t.name for t in msg.content[0].tools["functions"].tools]
|
||||
assert actual_tools == expected["tool_definitions"]
|
||||
|
||||
|
||||
def verify_chat_response(
|
||||
response: ChatCompletionResponse,
|
||||
content: str | None = None,
|
||||
reasoning: str | None = None,
|
||||
tool_calls: list[tuple[str, str]] | None = None,
|
||||
):
|
||||
assert len(response.choices) == 1
|
||||
message = response.choices[0].message
|
||||
|
||||
if content is not None:
|
||||
assert message.content == content
|
||||
else:
|
||||
assert not message.content
|
||||
|
||||
if reasoning is not None:
|
||||
assert message.reasoning == reasoning
|
||||
else:
|
||||
assert not message.reasoning
|
||||
|
||||
if tool_calls:
|
||||
assert message.tool_calls is not None
|
||||
assert len(message.tool_calls) == len(tool_calls)
|
||||
for tc, (expected_name, expected_args) in zip(message.tool_calls, tool_calls):
|
||||
assert tc.function.name == expected_name
|
||||
assert tc.function.arguments == expected_args
|
||||
else:
|
||||
assert not message.tool_calls
|
||||
@ -29,7 +29,8 @@ from vllm.multimodal.utils import (
|
||||
encode_image_base64,
|
||||
encode_video_base64,
|
||||
)
|
||||
from vllm.tokenizers import MistralTokenizer, get_tokenizer
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils.serial_utils import tensor2base64
|
||||
|
||||
from ..models.registry import HF_EXAMPLE_MODELS
|
||||
@ -796,9 +797,13 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
|
||||
"content": "<|image_1|>\nWhat's in this image?",
|
||||
}
|
||||
]
|
||||
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert mm_data["image"] is None
|
||||
assert isinstance(mm_data["image"], list)
|
||||
assert len(mm_data["image"]) == 1
|
||||
assert mm_data["image"][0] is None
|
||||
|
||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
|
||||
|
||||
|
||||
@ -825,10 +830,11 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
|
||||
# Should have audio in mm_data as None (UUID provided)
|
||||
assert mm_data is not None
|
||||
assert "audio" in mm_data
|
||||
assert mm_data["audio"] is None
|
||||
assert isinstance(mm_data["audio"], list)
|
||||
assert len(mm_data["audio"]) == 1
|
||||
assert mm_data["audio"][0] is None
|
||||
|
||||
# UUID should be recorded
|
||||
assert mm_uuids is not None
|
||||
assert "audio" in mm_uuids
|
||||
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[uuid])
|
||||
|
||||
|
||||
@ -1121,10 +1127,105 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
|
||||
mm_data = await mm_future
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert mm_data["image"] is None
|
||||
assert isinstance(mm_data["image"], list)
|
||||
assert len(mm_data["image"]) == 1
|
||||
assert mm_data["image"][0] is None
|
||||
|
||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
|
||||
|
||||
|
||||
def test_parse_chat_messages_empty_dict_image_embeds(
|
||||
phi3v_model_config_image_embeds,
|
||||
):
|
||||
"""Test that empty dictionary for image_embeds is handled without errors."""
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_embeds", "image_embeds": {}},
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
],
|
||||
}
|
||||
],
|
||||
phi3v_model_config_image_embeds,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
# Verify conversation structure
|
||||
assert conversation == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<|image_1|>\nWhat's in this image?",
|
||||
}
|
||||
]
|
||||
|
||||
# Verify mm_data contains an empty dictionary of embeddings
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert isinstance(mm_data["image"], dict)
|
||||
assert len(mm_data["image"]) == 0
|
||||
|
||||
# Verify UUIDs (None since we didn't provide any)
|
||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[None])
|
||||
|
||||
|
||||
def test_parse_chat_messages_multiple_dict_image_embeds(
|
||||
phi3v_model_config_image_embeds,
|
||||
):
|
||||
"""Test that multiple dictionaries for image_embeds is handled without errors."""
|
||||
# Create two sample image embedding tensors
|
||||
batch_size = 2
|
||||
image_embedding_1 = torch.randn(batch_size, 256, 1024)
|
||||
image_embedding_2 = torch.randn(batch_size, 3)
|
||||
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_embeds",
|
||||
"image_embeds": {
|
||||
"image_embedding_1": tensor2base64(p),
|
||||
"image_embedding_2": tensor2base64(i),
|
||||
},
|
||||
}
|
||||
for p, i in zip(image_embedding_1, image_embedding_2)
|
||||
]
|
||||
+ [
|
||||
{"type": "text", "text": "Describe these two images."},
|
||||
],
|
||||
}
|
||||
],
|
||||
phi3v_model_config_image_embeds,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
# Verify conversation structure
|
||||
assert conversation == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
|
||||
}
|
||||
]
|
||||
|
||||
# Verify mm_data contains a dictionary of multi-embeddings
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert isinstance(mm_data["image"], dict)
|
||||
assert len(mm_data["image"]) == batch_size
|
||||
|
||||
# Verify each embedding has the correct shape
|
||||
assert isinstance(mm_data["image"]["image_embedding_1"], torch.Tensor)
|
||||
assert mm_data["image"]["image_embedding_1"].shape == image_embedding_1.shape
|
||||
assert isinstance(mm_data["image"]["image_embedding_2"], torch.Tensor)
|
||||
assert mm_data["image"]["image_embedding_2"].shape == image_embedding_2.shape
|
||||
|
||||
# Verify UUIDs (None since we didn't provide any)
|
||||
_assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None, None])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_multiple_images_async(
|
||||
phi3v_model_config,
|
||||
|
||||
@ -32,8 +32,8 @@ def cal_diff(
|
||||
|
||||
|
||||
CUTLASS_MLA_UNSUPPORTED_REASON = (
|
||||
"Cutlass MLA Requires compute capability of 10 or above."
|
||||
if not current_platform.is_device_capability(100)
|
||||
"Cutlass MLA Requires compute capability of 100 or above."
|
||||
if not current_platform.is_device_capability_family(100)
|
||||
else "Cutlass MLA is supported"
|
||||
)
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ from tests.kernels.quantization.nvfp4_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
if not current_platform.is_device_capability(100):
|
||||
if not current_platform.is_device_capability_family(100):
|
||||
pytest.skip(
|
||||
"This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
|
||||
)
|
||||
@ -443,7 +443,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
||||
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||
rtol, atol = 1e-1, 2e-1
|
||||
rtol, atol = 3e-1, 4e-1
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||
rtol, atol = 4e-2, 6e-2
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
|
||||
|
||||
@ -7,6 +7,7 @@ import torch
|
||||
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import next_power_of_2
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
@ -22,6 +23,10 @@ QDTYPES = (
|
||||
# one value small enough to test the schema op check
|
||||
NUM_BLOCKS = [32768, 2048]
|
||||
|
||||
# 0: use 2D kernel for decode
|
||||
# 8: use 3D kernel for decode
|
||||
SEQ_THRESHOLD_3D_VALUES = [0, 8]
|
||||
|
||||
|
||||
def ref_paged_attn(
|
||||
query: torch.Tensor,
|
||||
@ -92,6 +97,7 @@ def ref_paged_attn(
|
||||
@pytest.mark.parametrize("soft_cap", [None, 50.0])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
||||
@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES)
|
||||
@torch.inference_mode()
|
||||
def test_triton_unified_attn(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
@ -103,6 +109,7 @@ def test_triton_unified_attn(
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
q_dtype: torch.dtype | None,
|
||||
seq_threshold_3D: int,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
@ -152,6 +159,21 @@ def test_triton_unified_attn(
|
||||
k_descale = torch.rand(scale_shape, dtype=torch.float32)
|
||||
v_descale = torch.rand(scale_shape, dtype=torch.float32)
|
||||
|
||||
num_par_softmax_segments = 16
|
||||
head_size_padded = next_power_of_2(head_size)
|
||||
softmax_segm_output = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
softmax_segm_max = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
softmax_segm_expsum = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
unified_attention(
|
||||
q=maybe_quantized_query,
|
||||
k=maybe_quantized_key_cache,
|
||||
@ -169,6 +191,11 @@ def test_triton_unified_attn(
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
seq_threshold_3D=seq_threshold_3D,
|
||||
num_par_softmax_segments=num_par_softmax_segments,
|
||||
softmax_segm_output=softmax_segm_output,
|
||||
softmax_segm_max=softmax_segm_max,
|
||||
softmax_segm_expsum=softmax_segm_expsum,
|
||||
)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
|
||||
@ -116,7 +116,6 @@ def test_mrope(
|
||||
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=config.rope_parameters,
|
||||
@ -185,7 +184,6 @@ def test_mrope_torch_compile_tracing(
|
||||
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=config.rope_parameters,
|
||||
|
||||
@ -83,8 +83,12 @@ def test_rotary_embedding(
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
|
||||
rope = get_rope(head_size, rotary_dim, max_position, is_neox_style, rope_parameters)
|
||||
rope_parameters = {
|
||||
"rope_type": "default",
|
||||
"rope_theta": rope_theta,
|
||||
"partial_rotary_factor": rotary_dim / head_size,
|
||||
}
|
||||
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
|
||||
rope = rope.to(dtype=dtype, device=torch.get_default_device())
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
@ -150,9 +154,9 @@ def test_rope_module_cache():
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters["rope_theta"] = rope_theta
|
||||
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
|
||||
rope = get_rope(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
@ -177,9 +181,9 @@ def test_rope_module_cache():
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters["rope_theta"] = rope_theta
|
||||
rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
|
||||
rope = get_rope(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
|
||||
@ -594,7 +594,8 @@ def make_modular_kernel(
|
||||
)
|
||||
|
||||
modular_kernel = mk.FusedMoEModularKernel(
|
||||
prepare_finalize=prepare_finalize, fused_experts=fused_experts
|
||||
prepare_finalize=prepare_finalize,
|
||||
fused_experts=fused_experts,
|
||||
)
|
||||
|
||||
return modular_kernel
|
||||
|
||||
@ -27,7 +27,7 @@ BLOCK_SIZE = [128, 128]
|
||||
@pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert
|
||||
@pytest.mark.parametrize("topk", [2, 4])
|
||||
def test_batched_deepgemm_vs_triton(
|
||||
E: int, T: int, K: int, N: int, topk: int, monkeypatch
|
||||
E: int, T: int, K: int, N: int, topk: int, monkeypatch, workspace_init
|
||||
):
|
||||
"""Compare BatchedDeepGemmExperts to BatchedTritonExperts."""
|
||||
|
||||
|
||||
@ -248,6 +248,7 @@ def test_fused_moe_batched_experts(
|
||||
per_act_token_quant: bool,
|
||||
block_shape: list[int] | None,
|
||||
input_scales: bool,
|
||||
workspace_init,
|
||||
):
|
||||
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
|
||||
and those tests will be skipped on unsupported hardware."""
|
||||
|
||||
@ -137,7 +137,7 @@ def setup_cuda():
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_fused_moe(
|
||||
M, N, K, E, topk, block_size, dtype, seed, monkeypatch
|
||||
M, N, K, E, topk, block_size, dtype, seed, monkeypatch, workspace_init
|
||||
):
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test; topk={topk} > E={E}")
|
||||
|
||||
@ -274,6 +274,7 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
ep_size: int | None = None,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
@ -329,6 +330,7 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
@ -385,9 +387,19 @@ def test_cutlass_moe_8_bit_EP(
|
||||
per_out_channel: bool,
|
||||
ep_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
test_cutlass_moe_8_bit_no_graph(
|
||||
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
e,
|
||||
topk,
|
||||
per_act_token,
|
||||
per_out_channel,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
ep_size,
|
||||
)
|
||||
|
||||
|
||||
@ -419,9 +431,19 @@ def test_cutlass_moe_8_bit_EP_large(
|
||||
per_out_channel: bool,
|
||||
ep_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
test_cutlass_moe_8_bit_no_graph(
|
||||
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
e,
|
||||
topk,
|
||||
per_act_token,
|
||||
per_out_channel,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
ep_size,
|
||||
)
|
||||
|
||||
|
||||
@ -445,6 +467,7 @@ def test_run_cutlass_moe_fp8(
|
||||
per_act_token: bool,
|
||||
per_out_channel: bool,
|
||||
ep_size: int,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
|
||||
@ -29,6 +29,7 @@ from vllm.utils.deep_gemm import (
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
@ -363,6 +364,9 @@ def _test_deepep_deepgemm_moe(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
):
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
w1 = w1.to(device=torch.cuda.current_device())
|
||||
@ -445,6 +449,7 @@ def test_ht_deepep_deepgemm_moe(
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
disable_deepgemm_ue8m0,
|
||||
workspace_init,
|
||||
):
|
||||
"""
|
||||
Tests for High-Throughput DeepEP + DeepGemm integration.
|
||||
@ -518,6 +523,7 @@ def test_ll_deepep_deepgemm_moe(
|
||||
block_size: list[int],
|
||||
world_dp_size: tuple[int, int],
|
||||
disable_deepgemm_ue8m0,
|
||||
workspace_init,
|
||||
):
|
||||
"""
|
||||
Tests for Low-Latency DeepEP + DeepGemm integration.
|
||||
|
||||
@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_deep_ep
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
@ -342,6 +343,9 @@ def _deep_ep_moe(
|
||||
use_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
if not low_latency_mode:
|
||||
assert not use_fp8_dispatch, (
|
||||
"FP8 dispatch interface is available only in low-latency mode"
|
||||
@ -437,6 +441,7 @@ def test_deep_ep_moe(
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
per_act_token_quant: bool,
|
||||
workspace_init,
|
||||
):
|
||||
low_latency_mode = False
|
||||
use_fp8_dispatch = False
|
||||
@ -492,6 +497,7 @@ def test_low_latency_deep_ep_moe(
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
use_fp8_dispatch: bool,
|
||||
workspace_init,
|
||||
):
|
||||
low_latency_mode = True
|
||||
|
||||
|
||||
@ -143,7 +143,7 @@ NUM_EXPERTS = [32]
|
||||
@pytest.mark.parametrize("topk", TOPKS)
|
||||
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
||||
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
|
||||
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
|
||||
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_init):
|
||||
with monkeypatch.context() as mp:
|
||||
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
@ -107,6 +108,19 @@ class TestData:
|
||||
layer.w2_input_scale = a2_scale
|
||||
layer.w13_weight_scale = w13_weight_scale
|
||||
layer.w2_weight_scale = w2_weight_scale
|
||||
# Setup dummy config.
|
||||
layer.moe_parallel_config = mk.FusedMoEParallelConfig(
|
||||
tp_size=1,
|
||||
pcp_size=1,
|
||||
dp_size=1,
|
||||
ep_size=1,
|
||||
tp_rank=1,
|
||||
pcp_rank=1,
|
||||
dp_rank=1,
|
||||
ep_rank=1,
|
||||
use_ep=False,
|
||||
all2all_backend="naive",
|
||||
)
|
||||
|
||||
register_moe_scaling_factors(layer)
|
||||
|
||||
@ -206,6 +220,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
topk: int,
|
||||
activation: str,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
|
||||
@ -51,7 +51,14 @@ MNK_FACTORS = [
|
||||
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_fp4_moe_no_graph(
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, activation: str
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
activation: str,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
|
||||
@ -269,7 +269,7 @@ class Case:
|
||||
)
|
||||
@pytest.mark.parametrize("num_token", [2])
|
||||
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
|
||||
def test_equiv(num_token, a_dtype, w_dtype, tp):
|
||||
def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
|
||||
from triton_kernels.tensor_details import layout
|
||||
|
||||
if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"):
|
||||
|
||||
@ -16,6 +16,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from .modular_kernel_tools.common import (
|
||||
Config,
|
||||
@ -77,6 +78,10 @@ def rank_worker(
|
||||
weights: WeightTensors,
|
||||
verbose: bool,
|
||||
):
|
||||
# Initialize workspace manager in child process
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
@ -300,6 +305,7 @@ def test_modular_kernel_combinations_singlegpu(
|
||||
chunk_size: int | None,
|
||||
world_size: int,
|
||||
pytestconfig,
|
||||
workspace_init,
|
||||
):
|
||||
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
|
||||
and those tests will be skipped on unsupported hardware."""
|
||||
|
||||
@ -209,6 +209,7 @@ def test_oai_triton_moe(
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
unfused: bool,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
(
|
||||
|
||||
@ -231,6 +231,7 @@ def test_fused_moe(
|
||||
padding: bool,
|
||||
chunk_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
|
||||
@ -40,7 +40,7 @@ MNK_FACTORS = [
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@torch.inference_mode()
|
||||
def test_cutlass_fp4_moe_no_graph(
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, workspace_init
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
|
||||
@ -17,7 +17,7 @@ QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
|
||||
) >= version.parse("0.8.99")
|
||||
|
||||
TRTLLM_GEN_MXFP4_AVAILABLE = (
|
||||
current_platform.is_cuda() and current_platform.is_device_capability(100)
|
||||
current_platform.is_cuda() and current_platform.is_device_capability_family(100)
|
||||
)
|
||||
|
||||
HOPPER_MXFP4_BF16_AVAILABLE = (
|
||||
@ -799,7 +799,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
|
||||
@pytest.mark.skipif(
|
||||
not (
|
||||
current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100)
|
||||
and current_platform.is_device_capability_family(100)
|
||||
and has_flashinfer()
|
||||
),
|
||||
reason="NVIDIA GPU sm100 and flashinfer are required for this test",
|
||||
|
||||
@ -46,6 +46,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
@ -181,6 +182,7 @@ def test_fused_moe_batched_experts(
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
@ -863,6 +865,9 @@ def _pplx_test_loop(
|
||||
make_weights: bool,
|
||||
test_fn: Callable,
|
||||
):
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
def format_result(msg, ex=None):
|
||||
if ex is not None:
|
||||
x = str(ex)
|
||||
|
||||
@ -30,16 +30,11 @@ def ref_dynamic_per_token_quant(
|
||||
if quant_dtype == torch.int8
|
||||
else torch.finfo(quant_dtype)
|
||||
)
|
||||
qtype_traits_max = (
|
||||
ROCM_FP8FNUZ_MAX
|
||||
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
|
||||
else qtype_traits.max
|
||||
)
|
||||
qtype_traits_min = (
|
||||
-ROCM_FP8FNUZ_MAX
|
||||
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
|
||||
else qtype_traits.min
|
||||
use_fp8fnuz = (
|
||||
current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype()
|
||||
)
|
||||
qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max
|
||||
qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min
|
||||
qtype_max = as_float32_tensor(qtype_traits_max)
|
||||
s_1 = as_float32_tensor(1.0)
|
||||
s_512 = as_float32_tensor(512.0)
|
||||
|
||||
@ -41,9 +41,9 @@ def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch):
|
||||
qweight = torch.randint(
|
||||
-2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
|
||||
)
|
||||
scales = torch.randint(
|
||||
scales = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
|
||||
qzeros = torch.randint(
|
||||
-2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
|
||||
)
|
||||
qzeros = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
|
||||
split_k_iters = 8
|
||||
opcheck(torch.ops._C.awq_gemm, (input, qweight, qzeros, scales, split_k_iters))
|
||||
opcheck(torch.ops._C.awq_gemm, (input, qweight, scales, qzeros, split_k_iters))
|
||||
|
||||
@ -18,7 +18,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
|
||||
IS_SUPPORTED_BY_GPU = (
|
||||
current_platform.is_cuda() and current_platform.get_device_capability()[0] >= 9
|
||||
)
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -62,7 +62,7 @@ def test_quantfp8_group_functionality(
|
||||
assert scales_col.stride(1) == batch_size
|
||||
|
||||
# Test column-major scales consistency
|
||||
assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
|
||||
torch.testing.assert_close(scales_col, scales_native, rtol=1e-9, atol=1e-8)
|
||||
|
||||
# 3. Test CUDA implementation (only for divisible dimensions)
|
||||
if is_divisible:
|
||||
@ -71,7 +71,7 @@ def test_quantfp8_group_functionality(
|
||||
assert scales_cuda.shape == (batch_size, expected_num_groups)
|
||||
|
||||
# Verify CUDA/native consistency
|
||||
assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8)
|
||||
torch.testing.assert_close(scales_cuda, scales_native, rtol=2e-7, atol=2e-8)
|
||||
|
||||
# Quantized values should mostly match
|
||||
diff_count = (x_quant_cuda != x_quant_native).sum().item()
|
||||
|
||||
@ -0,0 +1,91 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for ScaledMM kernel selection logic (CPU-only)
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_scaled_mm_kernel_selection.py`.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from abc import ABC
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
|
||||
AiterScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
|
||||
CPUScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearKernel,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def test_is_supported_is_abstract():
|
||||
"""Test that is_supported() is properly defined as abstract."""
|
||||
assert issubclass(ScaledMMLinearKernel, ABC)
|
||||
assert hasattr(ScaledMMLinearKernel, "is_supported")
|
||||
|
||||
|
||||
def test_cpu_kernel_implements_is_supported():
|
||||
"""Test that CPUScaledMMLinearKernel implements is_supported() method."""
|
||||
assert hasattr(CPUScaledMMLinearKernel, "is_supported"), (
|
||||
"CPUScaledMMLinearKernel missing is_supported() method"
|
||||
)
|
||||
# Verify it's a classmethod by checking if it can be called with the class
|
||||
# and by checking the method type
|
||||
assert inspect.ismethod(CPUScaledMMLinearKernel.is_supported) or inspect.isfunction(
|
||||
CPUScaledMMLinearKernel.is_supported
|
||||
), "CPUScaledMMLinearKernel.is_supported() should be a classmethod"
|
||||
# Verify it can be called as a classmethod
|
||||
result, reason = CPUScaledMMLinearKernel.is_supported()
|
||||
assert isinstance(result, bool), "is_supported() should return a bool"
|
||||
assert reason is None or isinstance(reason, str), "reason should be str or None"
|
||||
|
||||
|
||||
def test_aiter_kernel_implements_is_supported():
|
||||
"""Test that AiterScaledMMLinearKernel implements is_supported() method."""
|
||||
assert hasattr(AiterScaledMMLinearKernel, "is_supported"), (
|
||||
"AiterScaledMMLinearKernel missing is_supported() method"
|
||||
)
|
||||
# Verify it's a classmethod by checking if it can be called with the class
|
||||
# and by checking the method type
|
||||
assert inspect.ismethod(
|
||||
AiterScaledMMLinearKernel.is_supported
|
||||
) or inspect.isfunction(AiterScaledMMLinearKernel.is_supported), (
|
||||
"AiterScaledMMLinearKernel.is_supported() should be a classmethod"
|
||||
)
|
||||
# Verify it can be called as a classmethod
|
||||
# (will return False on CPU, which is expected)
|
||||
result, reason = AiterScaledMMLinearKernel.is_supported()
|
||||
assert isinstance(result, bool), "is_supported() should return a bool"
|
||||
assert reason is None or isinstance(reason, str), "reason should be str or None"
|
||||
# On CPU, it should return False with a reason about requiring ROCm
|
||||
# This validates the method works correctly even on non-ROCm platforms
|
||||
|
||||
|
||||
def test_cpu_kernel_accepts_all_configs():
|
||||
"""Test that CPUScaledMMLinearKernel accepts all config combinations."""
|
||||
configs = [
|
||||
ScaledMMLinearLayerConfig(
|
||||
is_channelwise=False,
|
||||
is_static_input_scheme=True,
|
||||
input_symmetric=True,
|
||||
),
|
||||
ScaledMMLinearLayerConfig(
|
||||
is_channelwise=True,
|
||||
is_static_input_scheme=False,
|
||||
input_symmetric=False,
|
||||
),
|
||||
]
|
||||
|
||||
for config in configs:
|
||||
can_impl, reason = CPUScaledMMLinearKernel.can_implement(config)
|
||||
assert can_impl, (
|
||||
f"CPUScaledMMLinearKernel should accept config {config}: {reason}"
|
||||
)
|
||||
@ -0,0 +1 @@
|
||||
{"transcriptions": ["There is no clear relationship between the barking and the music, as they seem to be independent of each other.", "(B) To indicate that language cannot express clearly, satirizing the inversion of black and white in the world"], "token_ids": [[3862, 374, 902, 2797, 5025, 1948, 279, 293, 33452, 323, 279, 4627, 11, 438, 807, 2803, 311, 387, 9489, 315, 1817, 1008, 13, 151645], [5349, 8, 2014, 13216, 429, 4128, 4157, 3158, 9355, 11, 7578, 404, 4849, 279, 46488, 315, 3691, 323, 4158, 304, 279, 1879, 151645, 151671]]}
|
||||
@ -0,0 +1 @@
|
||||
{"transcriptions": ["The content of the input audio is 'you can ask why over and over and over again forever even if one day we explain every physical interaction and scientific law and hope and dream and regret with a single elegant equation'."], "token_ids": [[785, 2213, 315, 279, 1946, 7699, 374, 364, 9330, 646, 2548, 3170, 916, 323, 916, 323, 916, 1549, 15683, 1496, 421, 825, 1899, 582, 10339, 1449, 6961, 16230, 323, 12344, 2329, 323, 3900, 323, 7904, 323, 22231, 448, 264, 3175, 25777, 23606, 4427, 151645]]}
|
||||
@ -10,7 +10,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||
MistralToolParser,
|
||||
)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
|
||||
@ -68,3 +68,34 @@ def test_modernbert_models(
|
||||
hf_output = torch.tensor(hf_output).cpu().float()
|
||||
vllm_output = torch.tensor(vllm_output).cpu().float()
|
||||
assert torch.allclose(hf_output, vllm_output, atol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["bd2lcco/Qwen3-0.6B-finetuned"])
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@torch.inference_mode
|
||||
def test_auto_conversion(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.token_classify(example_prompts)
|
||||
|
||||
with hf_runner(
|
||||
model, dtype=dtype, auto_cls=AutoModelForTokenClassification
|
||||
) as hf_model:
|
||||
tokenizer = hf_model.tokenizer
|
||||
hf_outputs = []
|
||||
for prompt in example_prompts:
|
||||
inputs = tokenizer([prompt], return_tensors="pt")
|
||||
inputs = hf_model.wrap_device(inputs)
|
||||
output = hf_model.model(**inputs)
|
||||
hf_outputs.append(softmax(output.logits[0]))
|
||||
|
||||
# check logits difference
|
||||
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
|
||||
hf_output = torch.tensor(hf_output).cpu().float()
|
||||
vllm_output = torch.tensor(vllm_output).cpu().float()
|
||||
assert torch.allclose(hf_output, vllm_output, atol=1e-2)
|
||||
|
||||
142
tests/models/multimodal/generation/test_audioflamingo3.py
Normal file
142
tests/models/multimodal/generation/test_audioflamingo3.py
Normal file
@ -0,0 +1,142 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2025 The vLLM team.
|
||||
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
|
||||
# reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.models.registry import HF_EXAMPLE_MODELS
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
MODEL_NAME = "nvidia/audio-flamingo-3-hf"
|
||||
|
||||
|
||||
def get_fixture_path(filename):
|
||||
return os.path.join(
|
||||
os.path.dirname(__file__), "../../fixtures/audioflamingo3", filename
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def llm():
|
||||
# Check if the model is supported by the current transformers version
|
||||
model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
try:
|
||||
llm = LLM(
|
||||
model=MODEL_NAME,
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
enforce_eager=True,
|
||||
limit_mm_per_prompt={"audio": 1},
|
||||
)
|
||||
return llm
|
||||
except Exception as e:
|
||||
pytest.skip(f"Failed to load model {MODEL_NAME}: {e}")
|
||||
|
||||
|
||||
def test_single_generation(llm):
|
||||
fixture_path = get_fixture_path("expected_results_single.json")
|
||||
if not os.path.exists(fixture_path):
|
||||
pytest.skip(f"Fixture not found: {fixture_path}")
|
||||
|
||||
with open(fixture_path) as f:
|
||||
expected = json.load(f)
|
||||
|
||||
audio_url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Why_do_we_ask_questions_converted.wav"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio_url", "audio_url": {"url": audio_url}},
|
||||
{"type": "text", "text": "Transcribe the input speech."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
|
||||
|
||||
outputs = llm.chat(
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
generated_text = outputs[0].outputs[0].text.strip()
|
||||
|
||||
expected_text = expected["transcriptions"][0]
|
||||
|
||||
assert expected_text in generated_text or generated_text in expected_text
|
||||
|
||||
|
||||
def test_batched_generation(llm):
|
||||
fixture_path = get_fixture_path("expected_results_batched.json")
|
||||
if not os.path.exists(fixture_path):
|
||||
pytest.skip(f"Fixture not found: {fixture_path}")
|
||||
|
||||
with open(fixture_path) as f:
|
||||
expected = json.load(f)
|
||||
|
||||
items = [
|
||||
{
|
||||
"audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/dogs_barking_in_sync_with_the_music.wav",
|
||||
"question": "What is surprising about the relationship "
|
||||
"between the barking and the music?",
|
||||
"expected_idx": 0,
|
||||
},
|
||||
{
|
||||
"audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Ch6Ae9DT6Ko_00-04-03_00-04-31.wav",
|
||||
"question": (
|
||||
"Why is the philosopher's name mentioned in the lyrics? "
|
||||
"(A) To express a sense of nostalgia "
|
||||
"(B) To indicate that language cannot express clearly, "
|
||||
"satirizing the inversion of black and white in the world "
|
||||
"(C) To add depth and complexity to the lyrics "
|
||||
"(D) To showcase the wisdom and influence of the philosopher"
|
||||
),
|
||||
"expected_idx": 1,
|
||||
},
|
||||
]
|
||||
|
||||
conversations = []
|
||||
for item in items:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio_url", "audio_url": {"url": item["audio_url"]}},
|
||||
{"type": "text", "text": item["question"]},
|
||||
],
|
||||
}
|
||||
]
|
||||
conversations.append(messages)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
|
||||
|
||||
outputs = llm.chat(
|
||||
messages=conversations,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
for i, output in enumerate(outputs):
|
||||
generated_text = output.outputs[0].text.strip()
|
||||
expected_text = expected["transcriptions"][i]
|
||||
|
||||
assert expected_text in generated_text or generated_text in expected_text
|
||||
@ -0,0 +1,434 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Consolidated test for ViT attention backend functionality across multiple models.
|
||||
|
||||
This test validates that each multimodal model can successfully generate outputs
|
||||
using different ViT attention backends. Tests are parametrized by model and backend.
|
||||
"""
|
||||
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from vllm import LLM, EngineArgs, SamplingParams
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.multimodal.utils import encode_image_base64
|
||||
from vllm.multimodal.video import sample_frames_from_video
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ....utils import create_new_process_for_each_test
|
||||
from ...utils import dummy_hf_overrides
|
||||
|
||||
# Dots.OCR prompt from official repository
|
||||
# https://github.com/rednote-hilab/dots.ocr/blob/d72d1d8c5bdd0362eb264f714cdbd1e5daa7cdff/dots_ocr/utils/prompts.py#L3
|
||||
# ruff: noqa: E501
|
||||
DOTS_OCR_PROMPT = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
|
||||
|
||||
1. Bbox format: [x1, y1, x2, y2]
|
||||
|
||||
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
|
||||
|
||||
3. Text Extraction & Formatting Rules:
|
||||
- Picture: For the 'Picture' category, the text field should be omitted.
|
||||
- Formula: Format its text as LaTeX.
|
||||
- Table: Format its text as HTML.
|
||||
- All Others (Text, Title, etc.): Format their text as Markdown.
|
||||
|
||||
4. Constraints:
|
||||
- The output text must be the original text from the image, with no translation.
|
||||
- All layout elements must be sorted according to human reading order.
|
||||
|
||||
5. Final Output: The entire output must be a single JSON object.
|
||||
"""
|
||||
|
||||
VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"
|
||||
|
||||
|
||||
# Model configurations
|
||||
MODEL_CONFIGS: dict[str, dict[str, Any]] = {
|
||||
"dots_ocr": {
|
||||
"model_name": "rednote-hilab/dots.ocr",
|
||||
"interface": "llm_chat",
|
||||
"max_model_len": 32768,
|
||||
"max_num_seqs": 1,
|
||||
"limit_mm_per_prompt": {"image": 1},
|
||||
"sampling_params": {
|
||||
"temperature": 0.1,
|
||||
"max_tokens": 16384,
|
||||
"top_p": 0.9,
|
||||
"stop_token_ids": None,
|
||||
},
|
||||
"use_specific_image": "stop_sign",
|
||||
"prompt_builder": "build_dots_ocr_prompt",
|
||||
"output_validator": lambda x: len(x) > 10 and "stop" in x.lower(),
|
||||
},
|
||||
"ernie45_vl": {
|
||||
"model_name": "baidu/ERNIE-4.5-VL-28B-A3B-PT",
|
||||
"interface": "llm_generate",
|
||||
"max_model_len": 16384,
|
||||
"max_num_seqs": 2,
|
||||
"sampling_params": {
|
||||
"temperature": 0.0,
|
||||
"max_tokens": 256,
|
||||
"stop_token_ids": None,
|
||||
},
|
||||
"use_processor": True,
|
||||
"question": "What is the content of each image?",
|
||||
},
|
||||
"glm4_1v": {
|
||||
"model_name": "zai-org/GLM-4.1V-9B-Thinking",
|
||||
"interface": "llm_generate",
|
||||
"max_model_len": 32768,
|
||||
"max_num_seqs": 2,
|
||||
"sampling_params": {
|
||||
"temperature": 0.0,
|
||||
"max_tokens": 256,
|
||||
"stop_token_ids": None,
|
||||
},
|
||||
"use_processor": True,
|
||||
"question": "What is the content of each image?",
|
||||
},
|
||||
"keye_vl": {
|
||||
"model_name": "Kwai-Keye/Keye-VL-8B-Preview",
|
||||
"interface": "llm_generate",
|
||||
"max_model_len": 8192,
|
||||
"max_num_seqs": 5,
|
||||
"sampling_params": {
|
||||
"temperature": 0.0,
|
||||
"max_tokens": 256,
|
||||
"stop_token_ids": None,
|
||||
},
|
||||
"supported_backends": {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
},
|
||||
"use_processor": True,
|
||||
"question": "What is the content of each image?",
|
||||
},
|
||||
"ovis2_5": {
|
||||
"model_name": "AIDC-AI/Ovis2.5-2B",
|
||||
"interface": "llm_generate",
|
||||
"max_model_len": 8192,
|
||||
"max_num_seqs": 2,
|
||||
"sampling_params": {
|
||||
"temperature": 0.0,
|
||||
"max_tokens": 256,
|
||||
"stop_token_ids": None,
|
||||
},
|
||||
"prompt_builder": "build_ovis_prompt",
|
||||
"question": "What is the content of each image?",
|
||||
},
|
||||
"qwen2_5_vl": {
|
||||
"model_name": "Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
"interface": "vllm_runner",
|
||||
"media_type": "video",
|
||||
"max_model_len": 4000,
|
||||
"max_num_seqs": 1,
|
||||
"limit_mm_per_prompt": {"video": 1},
|
||||
"sampling_params": {
|
||||
"max_tokens": 128,
|
||||
},
|
||||
"runner_kwargs": {
|
||||
"runner": "generate",
|
||||
"dtype": "bfloat16",
|
||||
},
|
||||
"video_params": {
|
||||
"num_frames": 16,
|
||||
"pruning_rates": [0.0, 0.75],
|
||||
},
|
||||
},
|
||||
"qwen2_5_omni": {
|
||||
"model_name": "Qwen/Qwen2.5-Omni-3B",
|
||||
"interface": "llm_generate",
|
||||
"max_model_len": 32768,
|
||||
"max_num_seqs": 2,
|
||||
"limit_mm_per_prompt": {"image": 3, "video": 3, "audio": 3},
|
||||
"sampling_params": {
|
||||
"temperature": 0.6,
|
||||
"top_p": 0.95,
|
||||
"top_k": 20,
|
||||
"max_tokens": 16384,
|
||||
},
|
||||
"use_processor": True,
|
||||
"question": "What is the content of each image?",
|
||||
},
|
||||
"qwen3_omni": {
|
||||
"model_name": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
|
||||
"interface": "llm_generate",
|
||||
"max_model_len": 32768,
|
||||
"max_num_seqs": 2,
|
||||
"limit_mm_per_prompt": {"image": 3, "video": 3, "audio": 3},
|
||||
"sampling_params": {
|
||||
"temperature": 0.6,
|
||||
"top_p": 0.95,
|
||||
"top_k": 20,
|
||||
"max_tokens": 16384,
|
||||
},
|
||||
"use_processor": True,
|
||||
"question": "What is the content of each image?",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Prompt builder functions
|
||||
def build_dots_ocr_prompt(images, config):
|
||||
"""Build Dots.OCR specific prompt with OCR instructions."""
|
||||
# Use only stop_sign image for Dots.OCR
|
||||
image = images[0] # Already filtered to stop_sign
|
||||
|
||||
image_url = f"data:image/jpeg;base64,{encode_image_base64(image)}"
|
||||
|
||||
placeholders = [{"type": "image_url", "image_url": {"url": image_url}}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"<|img|><|imgpad|><|endofimg|>{DOTS_OCR_PROMPT}",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def build_processor_prompt(images, config):
|
||||
"""Build prompt using AutoProcessor.apply_chat_template()."""
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
config["model_name"], trust_remote_code=True
|
||||
)
|
||||
|
||||
image_urls = [
|
||||
f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images
|
||||
]
|
||||
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{"type": "text", "text": config["question"]},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
return processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
|
||||
def build_ovis_prompt(images, config):
|
||||
"""Build Ovis2.5 specific prompt with custom format."""
|
||||
image_urls = [
|
||||
f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images
|
||||
]
|
||||
|
||||
placeholders = "\n".join(
|
||||
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
|
||||
)
|
||||
|
||||
return (
|
||||
f"<|im_start|>user\n\n{placeholders}\n{config['question']}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
|
||||
def build_qwen2_5_video_prompt():
|
||||
"""Build Qwen2.5-VL video prompt with EVS placeholder."""
|
||||
return (
|
||||
f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
||||
f"<|im_start|>user\n{VIDEO_PLACEHOLDER}"
|
||||
"Describe this video with a short sentence (no more than 20 words)"
|
||||
"<|im_end|><|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
|
||||
# Handler functions
|
||||
def run_llm_generate_test(config, mm_encoder_attn_backend, image_assets):
|
||||
"""Standard LLM.generate() interface handler."""
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
# Build prompt
|
||||
if config.get("use_processor"):
|
||||
prompt = build_processor_prompt(images, config)
|
||||
else:
|
||||
prompt_builder_name = config.get("prompt_builder", "build_ovis_prompt")
|
||||
prompt_builder = globals()[prompt_builder_name]
|
||||
prompt = prompt_builder(images, config)
|
||||
|
||||
# Determine limit_mm_per_prompt
|
||||
limit_mm_per_prompt = config.get("limit_mm_per_prompt", {"image": len(images)})
|
||||
|
||||
# Create engine
|
||||
engine_args = EngineArgs(
|
||||
model=config["model_name"],
|
||||
trust_remote_code=True,
|
||||
max_model_len=config["max_model_len"],
|
||||
max_num_seqs=config["max_num_seqs"],
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
mm_encoder_attn_backend=mm_encoder_attn_backend,
|
||||
hf_overrides=dummy_hf_overrides,
|
||||
load_format="dummy",
|
||||
)
|
||||
|
||||
engine_dict = asdict(engine_args) | {"seed": 42}
|
||||
llm = LLM(**engine_dict)
|
||||
|
||||
# Generate
|
||||
sampling_params = SamplingParams(**config["sampling_params"])
|
||||
outputs = llm.generate(
|
||||
{
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {"image": images},
|
||||
},
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
# Validate
|
||||
for o in outputs:
|
||||
generated_text = o.outputs[0].text
|
||||
validator = config.get("output_validator", lambda x: len(x) > 10)
|
||||
assert validator(generated_text), (
|
||||
f"Validation failed for {config['model_name']}: {generated_text}"
|
||||
)
|
||||
|
||||
|
||||
def run_llm_chat_test(config, mm_encoder_attn_backend, image_assets):
|
||||
"""LLM.chat() interface handler for Dots.OCR."""
|
||||
# Filter to stop_sign image only
|
||||
stop_sign_image = [
|
||||
asset.pil_image for asset in image_assets if asset.name == "stop_sign"
|
||||
][0]
|
||||
|
||||
# Build messages
|
||||
messages = build_dots_ocr_prompt([stop_sign_image], config)
|
||||
|
||||
# Create engine
|
||||
engine_args = EngineArgs(
|
||||
model=config["model_name"],
|
||||
trust_remote_code=True,
|
||||
max_model_len=config["max_model_len"],
|
||||
max_num_seqs=config["max_num_seqs"],
|
||||
limit_mm_per_prompt=config["limit_mm_per_prompt"],
|
||||
mm_encoder_attn_backend=mm_encoder_attn_backend,
|
||||
hf_overrides=dummy_hf_overrides,
|
||||
load_format="dummy",
|
||||
)
|
||||
|
||||
engine_dict = asdict(engine_args) | {"seed": 42}
|
||||
llm = LLM(**engine_dict)
|
||||
|
||||
# Generate using chat
|
||||
sampling_params = SamplingParams(**config["sampling_params"])
|
||||
outputs = llm.chat(messages=messages, sampling_params=sampling_params)
|
||||
|
||||
# Validate
|
||||
for o in outputs:
|
||||
generated_text = o.outputs[0].text
|
||||
validator = config.get("output_validator", lambda x: len(x) > 10)
|
||||
assert validator(generated_text), (
|
||||
f"Validation failed for {config['model_name']}: {generated_text}"
|
||||
)
|
||||
|
||||
|
||||
def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner):
|
||||
"""Video test with EVS (Efficient Video Sampling) handler."""
|
||||
for pruning_rate in config["video_params"]["pruning_rates"]:
|
||||
num_frames = config["video_params"]["num_frames"]
|
||||
|
||||
# Sample frames from video
|
||||
sampled_vids = [
|
||||
sample_frames_from_video(asset.np_ndarrays, num_frames)
|
||||
for asset in video_assets
|
||||
]
|
||||
|
||||
# Build prompt and prepare video
|
||||
prompt = build_qwen2_5_video_prompt()
|
||||
prompts = [prompt]
|
||||
videos = [sampled_vids[0]]
|
||||
|
||||
# Run with vllm_runner context manager
|
||||
with vllm_runner(
|
||||
config["model_name"],
|
||||
max_model_len=config["max_model_len"],
|
||||
max_num_seqs=config["max_num_seqs"],
|
||||
limit_mm_per_prompt=config["limit_mm_per_prompt"],
|
||||
tensor_parallel_size=1,
|
||||
video_pruning_rate=pruning_rate,
|
||||
mm_encoder_attn_backend=mm_encoder_attn_backend,
|
||||
hf_overrides=dummy_hf_overrides,
|
||||
load_format="dummy",
|
||||
**config["runner_kwargs"],
|
||||
) as vllm_model:
|
||||
outputs = vllm_model.generate_greedy(
|
||||
prompts,
|
||||
config["sampling_params"]["max_tokens"],
|
||||
videos=videos,
|
||||
)
|
||||
|
||||
# Validate output
|
||||
assert len(outputs) == 1, f"Expected 1 output, got {len(outputs)}"
|
||||
output_ids, output_text = outputs[0]
|
||||
assert len(output_ids) > 0, "Generated no output IDs"
|
||||
assert len(output_text) > 0, "Generated empty text"
|
||||
assert isinstance(output_text, str), (
|
||||
f"Output is not string: {type(output_text)}"
|
||||
)
|
||||
|
||||
|
||||
# Main test function
|
||||
@pytest.mark.parametrize("model_key", list(MODEL_CONFIGS.keys()))
|
||||
@pytest.mark.parametrize(
|
||||
"mm_encoder_attn_backend",
|
||||
[None] + current_platform.get_supported_vit_attn_backends(),
|
||||
)
|
||||
@create_new_process_for_each_test()
|
||||
def test_vit_backend_functionality(
|
||||
model_key: str,
|
||||
mm_encoder_attn_backend: AttentionBackendEnum | None,
|
||||
image_assets,
|
||||
video_assets,
|
||||
vllm_runner,
|
||||
request,
|
||||
):
|
||||
"""Test ViT attention backend functionality for multimodal models.
|
||||
|
||||
This test validates that each model can successfully generate outputs
|
||||
using different ViT attention backends. The test:
|
||||
1. Filters unsupported backends per model
|
||||
2. Applies appropriate GPU marks
|
||||
3. Routes to the correct test handler based on interface
|
||||
4. Validates output meets minimum requirements
|
||||
"""
|
||||
config = MODEL_CONFIGS[model_key]
|
||||
|
||||
# Step 1: Backend filtering
|
||||
if (
|
||||
"supported_backends" in config
|
||||
and mm_encoder_attn_backend is not None
|
||||
and mm_encoder_attn_backend not in config["supported_backends"]
|
||||
):
|
||||
pytest.skip(
|
||||
f"{model_key} does not support {mm_encoder_attn_backend} backend now."
|
||||
)
|
||||
|
||||
# Step 2: Apply GPU marks dynamically
|
||||
if "gpu_marks" in config:
|
||||
for mark in config["gpu_marks"]:
|
||||
request.applymarker(mark)
|
||||
|
||||
# Step 3: Route to appropriate handler
|
||||
if config.get("media_type") == "video":
|
||||
run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner)
|
||||
elif config["interface"] == "llm_chat":
|
||||
run_llm_chat_test(config, mm_encoder_attn_backend, image_assets)
|
||||
elif config["interface"] == "llm_generate":
|
||||
run_llm_generate_test(config, mm_encoder_attn_backend, image_assets)
|
||||
else:
|
||||
raise ValueError(f"Unknown interface: {config['interface']}")
|
||||
@ -9,7 +9,7 @@ from mistral_common.audio import Audio
|
||||
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
|
||||
from mistral_common.protocol.instruct.messages import UserMessage
|
||||
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
from ....conftest import AudioTestAssets
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
@ -1,150 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import librosa
|
||||
import pytest
|
||||
from transformers import AutoModelForSpeechSeq2Seq
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ....conftest import VllmRunner
|
||||
from ....conftest import HfRunner, PromptAudioInput, VllmRunner
|
||||
from ....utils import create_new_process_for_each_test, multi_gpu_test
|
||||
from ...registry import HF_EXAMPLE_MODELS
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
PROMPTS = [
|
||||
{
|
||||
"prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
|
||||
"multi_modal_data": {
|
||||
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
|
||||
},
|
||||
},
|
||||
{ # Test explicit encoder/decoder prompt
|
||||
"encoder_prompt": {
|
||||
"prompt": "",
|
||||
"multi_modal_data": {
|
||||
"audio": AudioAsset("winning_call").audio_and_sample_rate,
|
||||
},
|
||||
},
|
||||
"decoder_prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
|
||||
},
|
||||
]
|
||||
VLLM_PROMPT = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
|
||||
HF_PROMPT = ""
|
||||
# Whisper expects 16kHz audio
|
||||
WHISPER_SAMPLE_RATE = 16000
|
||||
|
||||
EXPECTED = {
|
||||
"openai/whisper-tiny": [
|
||||
" He has birth words I spoke in the original corner of that. And a"
|
||||
" little piece of black coat poetry. Mary had a little sandwich,"
|
||||
" sweet, with white and snow. And everyone had it very went the last"
|
||||
" would sure to go.",
|
||||
" >> And the old one, fit John the way to Edgar Martinez. >> One more"
|
||||
" to line down the field line for our base camp. Here comes joy. Here"
|
||||
" is June and the third base. They're going to wave him in. The throw"
|
||||
" to the plate will be late. The Mariners are going to play for the"
|
||||
" American League Championship. I don't believe it. It just continues"
|
||||
" by all five.",
|
||||
],
|
||||
"openai/whisper-small": [
|
||||
" The first words I spoke in the original pornograph. A little piece"
|
||||
" of practical poetry. Mary had a little lamb, its fleece was quite a"
|
||||
" slow, and everywhere that Mary went the lamb was sure to go.",
|
||||
" And the old one pitch on the way to Edgar Martinez one month. Here"
|
||||
" comes joy. Here is Junior to third base. They're gonna wave him"
|
||||
" in. The throw to the plate will be late. The Mariners are going to"
|
||||
" play for the American League Championship. I don't believe it. It"
|
||||
" just continues. My, oh my.",
|
||||
],
|
||||
"openai/whisper-medium": [
|
||||
" The first words I spoke in the original phonograph, a little piece"
|
||||
" of practical poetry. Mary had a little lamb, its fleece was quite as"
|
||||
" slow, and everywhere that Mary went the lamb was sure to go.",
|
||||
" And the 0-1 pitch on the way to Edgar Martinez swung on the line"
|
||||
" down the left field line for Obeyshev. Here comes Joy. Here is"
|
||||
" Jorgen at third base. They're going to wave him in. The throw to the"
|
||||
" plate will be late. The Mariners are going to play for the American"
|
||||
" League Championship. I don't believe it. It just continues. My, oh"
|
||||
" my.",
|
||||
],
|
||||
"openai/whisper-large-v3": [
|
||||
" The first words I spoke in the original phonograph, a little piece"
|
||||
" of practical poetry. Mary had a little lamb, its feet were quite as"
|
||||
" slow, and everywhere that Mary went, the lamb was sure to go.",
|
||||
" And the 0-1 pitch on the way to Edgar Martinez. Swung on the line."
|
||||
" Now the left field line for a base hit. Here comes Joy. Here is"
|
||||
" Junior to third base. They're going to wave him in. The throw to the"
|
||||
" plate will be late. The Mariners are going to play for the American"
|
||||
" League Championship. I don't believe it. It just continues. My, oh,"
|
||||
" my.",
|
||||
],
|
||||
"openai/whisper-large-v3-turbo": [
|
||||
" The first words I spoke in the original phonograph, a little piece"
|
||||
" of practical poetry. Mary had a little lamb, its streets were quite"
|
||||
" as slow, and everywhere that Mary went the lamb was sure to go.",
|
||||
" And the 0-1 pitch on the way to Edgar Martinez. Swung on the line"
|
||||
" down the left field line for a base hit. Here comes Joy. Here is"
|
||||
" Junior to third base. They're going to wave him in. The throw to the"
|
||||
" plate will be late. The Mariners are going to play for the American"
|
||||
" League Championship. I don't believe it. It just continues. My, oh,"
|
||||
" my.",
|
||||
],
|
||||
}
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def use_spawn_for_whisper(monkeypatch):
|
||||
"""Whisper has issues with forked workers, use spawn instead."""
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
|
||||
|
||||
def run_test(
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
inputs: Sequence[tuple[list[str], list[str], PromptAudioInput]],
|
||||
model: str,
|
||||
*,
|
||||
max_model_len: int,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: str | None = None,
|
||||
dtype: str = "half",
|
||||
enforce_eager: bool = True,
|
||||
) -> None:
|
||||
prompt_list = PROMPTS * 10
|
||||
expected_list = EXPECTED[model] * 10
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the audio fixtures for the test are from AudioAsset.
|
||||
For huggingface runner, we provide the audio as input.
|
||||
For vllm runner, we provide MultiModalDataDict objects
|
||||
and corresponding MultiModalConfig as input.
|
||||
"""
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_model_len=448,
|
||||
max_model_len=max_model_len,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
# TODO (NickLucche) figure out output differences with non-eager and re-enable
|
||||
enforce_eager=True,
|
||||
limit_mm_per_prompt={"audio": 2},
|
||||
enforce_eager=enforce_eager,
|
||||
disable_custom_all_reduce=True,
|
||||
) as vllm_model:
|
||||
llm = vllm_model.llm
|
||||
vllm_outputs_per_case = [
|
||||
vllm_model.generate_greedy_logprobs(
|
||||
vllm_prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
audios=audios,
|
||||
)
|
||||
for vllm_prompts, _, audios in inputs
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
top_p=1.0,
|
||||
max_tokens=200,
|
||||
with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSpeechSeq2Seq) as hf_model:
|
||||
hf_outputs_per_case = [
|
||||
hf_model.generate_greedy_logprobs_limit(
|
||||
hf_prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
audios=audios,
|
||||
)
|
||||
for _, hf_prompts, audios in inputs
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
outputs = llm.generate(prompt_list, sampling_params)
|
||||
|
||||
for output, expected in zip(outputs, expected_list):
|
||||
print(output.outputs[0].text)
|
||||
assert output.outputs[0].text == expected
|
||||
@pytest.fixture
|
||||
def input_audios() -> list[tuple[list[str], list[str], list[tuple[Any, int]]]]:
|
||||
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
|
||||
inputs = []
|
||||
for asset in audio_assets:
|
||||
audio, orig_sr = asset.audio_and_sample_rate
|
||||
# Resample to Whisper's expected sample rate (16kHz)
|
||||
if orig_sr != WHISPER_SAMPLE_RATE:
|
||||
audio = librosa.resample(
|
||||
audio, orig_sr=orig_sr, target_sr=WHISPER_SAMPLE_RATE
|
||||
)
|
||||
# vLLM prompts, HF prompts, audio inputs
|
||||
inputs.append(([VLLM_PROMPT], [HF_PROMPT], [(audio, WHISPER_SAMPLE_RATE)]))
|
||||
return inputs
|
||||
|
||||
|
||||
def check_model_available(model: str) -> None:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@create_new_process_for_each_test()
|
||||
def test_models(vllm_runner, model, dtype) -> None:
|
||||
run_test(
|
||||
vllm_runner,
|
||||
model,
|
||||
tensor_parallel_size=1,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.cpu_model
|
||||
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_models_cpu(vllm_runner, model, dtype) -> None:
|
||||
# @create_new_process_for_each_test() does not work for some runners
|
||||
# TODO: to fix cpu privilege issues in run-cpu-test-arm.sh
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
num_logprobs: int,
|
||||
input_audios,
|
||||
enforce_eager: bool,
|
||||
) -> None:
|
||||
check_model_available(model)
|
||||
if current_platform.is_cpu() and not enforce_eager:
|
||||
pytest.skip("Skipping test for CPU with non-eager mode")
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
input_audios,
|
||||
model,
|
||||
tensor_parallel_size=1,
|
||||
dtype=dtype,
|
||||
max_model_len=448,
|
||||
max_tokens=200,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
enforce_eager=enforce_eager,
|
||||
)
|
||||
|
||||
|
||||
@ -152,15 +148,31 @@ def test_models_cpu(vllm_runner, model, dtype) -> None:
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [200])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_models_distributed(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
model,
|
||||
distributed_executor_backend,
|
||||
model: str,
|
||||
distributed_executor_backend: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
input_audios,
|
||||
) -> None:
|
||||
check_model_available(model)
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
input_audios,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_model_len=448,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=False,
|
||||
)
|
||||
|
||||
125
tests/models/multimodal/processing/test_audioflamingo3.py
Normal file
125
tests/models/multimodal/processing/test_audioflamingo3.py
Normal file
@ -0,0 +1,125 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2025 The vLLM team.
|
||||
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
|
||||
# reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from tests.models.registry import HF_EXAMPLE_MODELS
|
||||
|
||||
|
||||
class MockAudioFlamingo3Config(PretrainedConfig):
|
||||
model_type = "audioflamingo3"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.audio_config = PretrainedConfig()
|
||||
self.text_config = PretrainedConfig()
|
||||
|
||||
|
||||
class MockAudioFlamingo3Processor:
|
||||
def __init__(self):
|
||||
self.audio_token = "<sound>"
|
||||
self.audio_token_id = 12345
|
||||
self.feature_extractor = MockFeatureExtractor()
|
||||
|
||||
def __call__(self, text=None, audios=None, **kwargs):
|
||||
return {"input_ids": [1, 2, 3], "input_features": [np.zeros((3000, 80))]}
|
||||
|
||||
|
||||
class MockFeatureExtractor:
|
||||
def __init__(self):
|
||||
self.sampling_rate = 16000
|
||||
self.chunk_length = 30
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ctx():
|
||||
config = MockAudioFlamingo3Config()
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.get_hf_config.return_value = config
|
||||
ctx.get_hf_processor.return_value = MockAudioFlamingo3Processor()
|
||||
ctx.model_config.hf_config = config
|
||||
return ctx
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def check_transformers_version():
|
||||
# Check if the model is supported by the current transformers version
|
||||
model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
|
||||
def test_audio_chunk_counting(mock_ctx):
|
||||
from vllm.model_executor.models.audioflamingo3 import (
|
||||
AudioFlamingo3DummyInputsBuilder,
|
||||
AudioFlamingo3MultiModalProcessor,
|
||||
AudioFlamingo3ProcessingInfo,
|
||||
)
|
||||
|
||||
info = AudioFlamingo3ProcessingInfo(mock_ctx)
|
||||
processor = AudioFlamingo3MultiModalProcessor(
|
||||
info, AudioFlamingo3DummyInputsBuilder(info)
|
||||
)
|
||||
|
||||
sr = 16000
|
||||
audio_1 = np.zeros(30 * sr)
|
||||
audio_2 = np.zeros(45 * sr)
|
||||
|
||||
mm_data = {"audio": [audio_1, audio_2]}
|
||||
prompt = "<|user|>Listen.<|end|>"
|
||||
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
|
||||
def mock_base_call(self, prompt, mm_data, mm_kwargs, tok_kwargs):
|
||||
return {"input_ids": [1, 2, 3], "input_features": torch.randn(1, 80, 3000)}
|
||||
|
||||
with pytest.MonkeyPatch.context() as mp:
|
||||
mp.setattr(BaseMultiModalProcessor, "_call_hf_processor", mock_base_call)
|
||||
|
||||
processed = processor._call_hf_processor(prompt, mm_data, {}, {})
|
||||
|
||||
chunk_counts = processed["chunk_counts"]
|
||||
|
||||
assert chunk_counts[0].item() == 1
|
||||
assert chunk_counts[1].item() == 2
|
||||
assert len(chunk_counts) == 2
|
||||
|
||||
|
||||
def test_dummy_data_generation(mock_ctx):
|
||||
from vllm.model_executor.models.audioflamingo3 import (
|
||||
AudioFlamingo3DummyInputsBuilder,
|
||||
AudioFlamingo3ProcessingInfo,
|
||||
)
|
||||
|
||||
info = AudioFlamingo3ProcessingInfo(mock_ctx)
|
||||
builder = AudioFlamingo3DummyInputsBuilder(info)
|
||||
|
||||
mm_counts = {"audio": 2}
|
||||
dummy_data = builder.get_dummy_mm_data(100, mm_counts, None)
|
||||
|
||||
assert "audio" in dummy_data
|
||||
assert len(dummy_data["audio"]) == 2
|
||||
|
||||
expected_len = 600 * 16000
|
||||
assert len(dummy_data["audio"][0]) == expected_len
|
||||
@ -22,11 +22,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
|
||||
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
|
||||
from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
|
||||
from vllm.tokenizers import (
|
||||
MistralTokenizer,
|
||||
TokenizerLike,
|
||||
cached_tokenizer_from_config,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
from ....multimodal.utils import random_audio, random_image, random_video
|
||||
from ...registry import (
|
||||
|
||||
42
tests/models/multimodal/processing/test_gemma3.py
Normal file
42
tests/models/multimodal/processing/test_gemma3.py
Normal file
@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from ....conftest import ImageTestAssets
|
||||
from ...utils import build_model_context
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["google/gemma-3-4b-it"])
|
||||
def test_get_image_size_with_most_features(
|
||||
image_assets: ImageTestAssets, model_id: str
|
||||
):
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
mm_processor_kwargs={"do_pan_and_scan": True},
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
|
||||
hf_processor_mm_kwargs: dict[str, object] = {}
|
||||
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
max_image_size = processor.info.get_image_size_with_most_features()
|
||||
max_tokens = processor.info.get_num_image_tokens(
|
||||
image_width=max_image_size.width,
|
||||
image_height=max_image_size.height,
|
||||
processor=hf_processor,
|
||||
)
|
||||
|
||||
prompt = "<start_of_image>"
|
||||
image_seq_length = hf_processor.image_seq_length
|
||||
|
||||
for asset in image_assets:
|
||||
mm_data = {"image": [asset.pil_image]}
|
||||
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
|
||||
mm_kwargs_data = processed_inputs["mm_kwargs"].get_data()
|
||||
num_patches_tensor = mm_kwargs_data["num_patches"]
|
||||
tokens = int(num_patches_tensor.item()) * image_seq_length
|
||||
assert tokens <= max_tokens
|
||||
@ -53,3 +53,38 @@ def test_processor_override(
|
||||
assert img_tok_count == expected_toks_per_img * num_imgs
|
||||
assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
|
||||
assert pixel_shape[1] == expected_pixels_shape[1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
|
||||
@pytest.mark.parametrize("max_pixels", [1280 * 28 * 28, 1283 * 28 * 28])
|
||||
def test_get_image_size_with_most_features(
|
||||
image_assets: ImageTestAssets,
|
||||
model_id: str,
|
||||
max_pixels: int,
|
||||
):
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
mm_processor_kwargs={"max_pixels": max_pixels},
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
|
||||
hf_processor_mm_kwargs: dict[str, object] = {}
|
||||
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
merge_size = processor.info.get_hf_config().vision_config.spatial_merge_size
|
||||
|
||||
max_image_size = processor.info.get_image_size_with_most_features()
|
||||
max_tokens = processor.info.get_num_image_tokens(
|
||||
image_width=max_image_size.width,
|
||||
image_height=max_image_size.height,
|
||||
image_processor=hf_processor.image_processor,
|
||||
)
|
||||
|
||||
prompt = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
for asset in image_assets:
|
||||
mm_data = {"image": [asset.pil_image]}
|
||||
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
|
||||
grid_thw = processed_inputs["mm_kwargs"].get_data()["image_grid_thw"].tolist()
|
||||
t, h, w = grid_thw[0]
|
||||
tokens = (t * h * w) // (merge_size**2)
|
||||
assert tokens < max_tokens
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any, TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
|
||||
@ -35,6 +36,7 @@ from vllm.tokenizers import cached_tokenizer_from_config
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
from ....utils import create_new_process_for_each_test
|
||||
from ...registry import HF_EXAMPLE_MODELS
|
||||
from ...utils import dummy_hf_overrides
|
||||
from .test_common import get_model_ids_to_test, get_text_token_prompts
|
||||
@ -136,6 +138,7 @@ def create_batched_mm_kwargs(
|
||||
)
|
||||
|
||||
|
||||
# TODO(Isotr0py): Don't initalize model during test
|
||||
@contextmanager
|
||||
def initialize_dummy_model(
|
||||
model_cls: type[nn.Module],
|
||||
@ -150,16 +153,21 @@ def initialize_dummy_model(
|
||||
backend="nccl",
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=1)
|
||||
|
||||
current_device = torch.get_default_device()
|
||||
vllm_config = VllmConfig(model_config=model_config)
|
||||
with set_current_vllm_config(vllm_config=vllm_config):
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
torch.set_default_device(current_platform.device_type)
|
||||
model = model_cls(vllm_config=vllm_config)
|
||||
torch.set_default_device(current_device)
|
||||
yield model
|
||||
|
||||
del model
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("model_id", get_model_ids_to_test())
|
||||
def test_model_tensor_schema(model_id: str):
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
|
||||
@ -173,10 +173,7 @@ class _HfExamplesInfo:
|
||||
|
||||
_TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
# [Decoder-only]
|
||||
"AfmoeForCausalLM": _HfExamplesInfo(
|
||||
"arcee-ai/Trinity-Nano",
|
||||
is_available_online=False,
|
||||
),
|
||||
"AfmoeForCausalLM": _HfExamplesInfo("arcee-ai/Trinity-Nano-Preview"),
|
||||
"ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-Instruct-2509"),
|
||||
"AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True),
|
||||
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True),
|
||||
@ -359,7 +356,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
),
|
||||
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
|
||||
"MistralLarge3ForCausalLM": _HfExamplesInfo(
|
||||
"mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4", is_available_online=False
|
||||
"mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4"
|
||||
),
|
||||
"MixtralForCausalLM": _HfExamplesInfo(
|
||||
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
@ -576,12 +573,17 @@ _AUTOMATIC_CONVERTED_MODELS = {
|
||||
"Qwen3ForSequenceClassification": _HfExamplesInfo(
|
||||
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
|
||||
),
|
||||
"Qwen3ForTokenClassification": _HfExamplesInfo("bd2lcco/Qwen3-0.6B-finetuned"),
|
||||
}
|
||||
|
||||
_MULTIMODAL_EXAMPLE_MODELS = {
|
||||
# [Decoder-only]
|
||||
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
|
||||
"AudioFlamingo3ForConditionalGeneration": _HfExamplesInfo(
|
||||
"nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0.dev"
|
||||
),
|
||||
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/aya-vision-8b"),
|
||||
"BagelForConditionalGeneration": _HfExamplesInfo("ByteDance-Seed/BAGEL-7B-MoT"),
|
||||
"BeeForConditionalGeneration": _HfExamplesInfo(
|
||||
"Open-Bee/Bee-8B-RL",
|
||||
trust_remote_code=True,
|
||||
@ -638,7 +640,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
),
|
||||
"HunYuanVLForConditionalGeneration": _HfExamplesInfo(
|
||||
"tencent/HunyuanOCR",
|
||||
is_available_online=False,
|
||||
hf_overrides={"num_experts": 0},
|
||||
),
|
||||
"Idefics3ForConditionalGeneration": _HfExamplesInfo(
|
||||
"HuggingFaceM4/Idefics3-8B-Llama3",
|
||||
@ -677,8 +679,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/discussions/31",
|
||||
),
|
||||
"LightOnOCRForConditionalGeneration": _HfExamplesInfo(
|
||||
"lightonai/LightOnOCR-1B",
|
||||
is_available_online=False,
|
||||
"lightonai/LightOnOCR-1B-1025"
|
||||
),
|
||||
"Llama4ForConditionalGeneration": _HfExamplesInfo(
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
@ -782,8 +783,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"ministral-3": "mistralai/Ministral-3-3B-Instruct-2512",
|
||||
},
|
||||
tokenizer_mode="mistral",
|
||||
# TODO: revert once Mistral-Large-3 and Ministral-3 are publicly available.
|
||||
is_available_online=False,
|
||||
),
|
||||
"QwenVLForConditionalGeneration": _HfExamplesInfo(
|
||||
"Qwen/Qwen-VL",
|
||||
@ -846,7 +845,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
is_available_online=False,
|
||||
),
|
||||
# [Encoder-decoder]
|
||||
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"),
|
||||
"WhisperForConditionalGeneration": _HfExamplesInfo(
|
||||
"openai/whisper-large-v3-turbo",
|
||||
extras={"v3": "openai/whisper-large-v3"},
|
||||
),
|
||||
# [Cross-encoder]
|
||||
"JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"),
|
||||
}
|
||||
@ -889,6 +891,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
"EagleMistralLarge3ForCausalLM": _HfExamplesInfo(
|
||||
"mistralai/Mistral-Large-3-675B-Instruct-2512",
|
||||
speculative_model="mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle",
|
||||
# TODO: revert once figuring out OOM in CI
|
||||
is_available_online=False,
|
||||
),
|
||||
"LlamaForCausalLMEagle3": _HfExamplesInfo(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user