Merge remote-tracking branch 'refs/remotes/origin/main' into one-pod-per-node-lb

This commit is contained in:
Nick Hill 2025-07-23 11:04:38 +01:00
commit 1bd5f2f1ad
97 changed files with 3490 additions and 754 deletions

View File

@ -225,7 +225,7 @@ steps:
##### 1 GPU test #####
- label: Regression Test # 5min
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- vllm/
- tests/test_regression
@ -277,7 +277,7 @@ steps:
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
- label: Examples Test # 25min
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
working_dir: "/vllm-workspace/examples"
source_file_dependencies:
- vllm/entrypoints
@ -311,7 +311,7 @@ steps:
- label: Platform Tests (CUDA)
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- vllm/
- tests/cuda
@ -330,7 +330,7 @@ steps:
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
- label: LoRA Test %N # 15min each
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/lora
- tests/lora
@ -382,7 +382,7 @@ steps:
- pytest -v -s kernels/core
- label: Kernels Attention Test %N
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- csrc/attention/
- vllm/attention
@ -393,7 +393,7 @@ steps:
parallelism: 2
- label: Kernels Quantization Test %N
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- csrc/quantization/
- vllm/model_executor/layers/quantization
@ -412,7 +412,7 @@ steps:
- pytest -v -s kernels/moe
- label: Kernels Mamba Test
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- csrc/mamba/
- tests/kernels/mamba
@ -420,7 +420,7 @@ steps:
- pytest -v -s kernels/mamba
- label: Tensorizer Test # 11min
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
soft_fail: true
source_file_dependencies:
- vllm/model_executor/model_loader
@ -434,7 +434,6 @@ steps:
- label: Model Executor Test
mirror_hardwares: [amdexperimental, amdproduction]
soft_fail: true
source_file_dependencies:
- vllm/model_executor
- tests/model_executor
@ -491,7 +490,7 @@ steps:
- pytest -s entrypoints/openai/correctness/
- label: Encoder Decoder tests # 5min
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
source_file_dependencies:
- vllm/
- tests/encoder_decoder
@ -499,7 +498,7 @@ steps:
- pytest -v -s encoder_decoder
- label: OpenAI-Compatible Tool Use # 20 min
mirror_hardwares: [amdexperimental]
mirror_hardwares: [amdexperimental, amdproduction]
fast_check: false
source_file_dependencies:
- vllm/
@ -611,7 +610,7 @@ steps:
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'
- label: Quantized Models Test
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/model_executor/layers/quantization
- tests/models/quantization

View File

@ -296,7 +296,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp"
"csrc/attention/mla/cutlass_mla_entry.cu")
"csrc/attention/mla/cutlass_mla_entry.cu"
"csrc/quantization/fp8/per_token_group_quant.cu")
set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
@ -577,7 +578,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu")
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
@ -595,6 +596,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
"if you intend on running FP8 quantized MoE models on Blackwell.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()
# moe_data.cu is used by all CUTLASS MoE kernels.
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)

View File

@ -126,11 +126,12 @@ run_benchmark() {
# get a basic qps by using request-rate inf
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_inf.txt"
prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 ))
python benchmarks/benchmark_serving.py \
adjusted_input_len=$(( INPUT_LEN - prefix_len ))
python3 benchmarks/benchmark_serving.py \
--backend vllm \
--model $MODEL \
--dataset-name random \
--random-input-len $INPUT_LEN \
--random-input-len $adjusted_input_len \
--random-output-len $OUTPUT_LEN \
--ignore-eos \
--disable-tqdm \
@ -159,11 +160,11 @@ run_benchmark() {
curl -X POST http://0.0.0.0:8004/reset_prefix_cache
sleep 5
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_${request_rate}.txt"
python benchmarks/benchmark_serving.py \
python3 benchmarks/benchmark_serving.py \
--backend vllm \
--model $MODEL \
--dataset-name random \
--random-input-len $INPUT_LEN \
--random-input-len $adjusted_input_len \
--random-output-len $OUTPUT_LEN \
--ignore-eos \
--disable-tqdm \

View File

@ -30,7 +30,7 @@ import os
import random
import time
import warnings
from collections.abc import AsyncGenerator, Iterable
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Literal, Optional
@ -73,6 +73,7 @@ from benchmark_dataset import (
VisionArenaDataset,
)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.benchmarks.serve import get_request
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@ -107,101 +108,6 @@ class BenchmarkMetrics:
percentiles_e2el_ms: list[tuple[float, float]]
def _get_current_request_rate(
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
ramp_up_start_rps: Optional[int],
ramp_up_end_rps: Optional[int],
request_index: int,
total_requests: int,
request_rate: float,
) -> float:
if (
ramp_up_strategy
and ramp_up_start_rps is not None
and ramp_up_end_rps is not None
):
progress = request_index / max(total_requests - 1, 1)
if ramp_up_strategy == "linear":
increase = (ramp_up_end_rps - ramp_up_start_rps) * progress
return ramp_up_start_rps + increase
elif ramp_up_strategy == "exponential":
ratio = ramp_up_end_rps / ramp_up_start_rps
return ramp_up_start_rps * (ratio**progress)
else:
raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}")
return request_rate
async def get_request(
input_requests: list[SampleRequest],
request_rate: float,
burstiness: float = 1.0,
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
ramp_up_start_rps: Optional[int] = None,
ramp_up_end_rps: Optional[int] = None,
) -> AsyncGenerator[tuple[SampleRequest, float], None]:
"""
Asynchronously generates requests at a specified rate
with OPTIONAL burstiness and OPTIONAL ramp-up strategy.
Args:
input_requests:
A list of input requests, each represented as a SampleRequest.
request_rate:
The rate at which requests are generated (requests/s).
burstiness (optional):
The burstiness factor of the request generation.
Only takes effect when request_rate is not inf.
Default value is 1, which follows a Poisson process.
Otherwise, the request intervals follow a gamma distribution.
A lower burstiness value (0 < burstiness < 1) results
in more bursty requests, while a higher burstiness value
(burstiness > 1) results in a more uniform arrival of requests.
ramp_up_strategy (optional):
The ramp-up strategy. Can be "linear" or "exponential".
If None, uses constant request rate (specified by request_rate).
ramp_up_start_rps (optional):
The starting request rate for ramp-up.
ramp_up_end_rps (optional):
The ending request rate for ramp-up.
"""
assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}."
)
# Convert to list to get length for ramp-up calculations
if isinstance(input_requests, Iterable) and not isinstance(input_requests, list):
input_requests = list(input_requests)
total_requests = len(input_requests)
request_index = 0
for request in input_requests:
current_request_rate = _get_current_request_rate(
ramp_up_strategy,
ramp_up_start_rps,
ramp_up_end_rps,
request_index,
total_requests,
request_rate,
)
yield request, current_request_rate
request_index += 1
if current_request_rate == float("inf"):
# If the request rate is infinity, then we don't need to wait.
continue
theta = 1.0 / (current_request_rate * burstiness)
# Sample the request interval from the gamma distribution.
# If burstiness is 1, it follows exponential distribution.
interval = np.random.gamma(shape=burstiness, scale=theta)
# The next request will be sent after the interval.
await asyncio.sleep(interval)
def calculate_metrics(
input_requests: list[SampleRequest],
outputs: list[RequestFuncOutput],

View File

@ -80,11 +80,6 @@ def bench_run(
a, score, topk, renormalize=False
)
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@ -116,10 +111,6 @@ def bench_run(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
@ -134,10 +125,6 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
)
@ -149,10 +136,6 @@ def bench_run(
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
@ -167,10 +150,6 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
)
@ -215,10 +194,6 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
)
@ -256,10 +231,6 @@ def bench_run(
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
"ab_strides1": ab_strides1,
"ab_strides2": ab_strides2,
"c_strides1": c_strides1,
"c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@ -318,10 +289,6 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
per_act_token,
@ -330,7 +297,7 @@ def bench_run(
results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,

View File

@ -33,15 +33,13 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
sorted_ids_triton = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device="cuda"
)
sorted_ids_triton.fill_(topk_ids.numel()) # fill with sentinel value
expert_ids_triton = torch.zeros(
expert_ids_triton = torch.empty(
(max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
)
num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")
sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
sorted_ids_vllm.fill_(topk_ids.numel())
expert_ids_vllm = torch.zeros_like(expert_ids_triton)
expert_ids_vllm = torch.empty_like(expert_ids_triton)
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)
# 2. run implementations
@ -102,7 +100,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")

View File

@ -15,15 +15,16 @@ namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const int64_t input_stride,
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
const float x = (float)input[blockIdx.x * input_stride + idx];
variance += x * x;
}
@ -37,7 +38,7 @@ __global__ void rms_norm_kernel(
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float x = (float)input[blockIdx.x * input_stride + idx];
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
@ -50,7 +51,8 @@ __global__ void rms_norm_kernel(
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
const int64_t input_stride,
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
@ -59,6 +61,7 @@ fused_add_rms_norm_kernel(
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
const int vec_hidden_size = hidden_size / width;
const int64_t vec_input_stride = input_stride / width;
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
@ -73,7 +76,8 @@ fused_add_rms_norm_kernel(
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = input_v[id];
int64_t strided_id = blockIdx.x * vec_input_stride + idx;
_f16Vec<scalar_t, width> temp = input_v[strided_id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
@ -90,10 +94,11 @@ fused_add_rms_norm_kernel(
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
int64_t strided_id = blockIdx.x * vec_input_stride + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
input_v[id] = temp;
input_v[strided_id] = temp;
}
}
@ -103,7 +108,8 @@ fused_add_rms_norm_kernel(
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
const int64_t input_stride,
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
@ -111,7 +117,7 @@ fused_add_rms_norm_kernel(
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx];
scalar_t z = input[blockIdx.x * input_stride + idx];
z += residual[blockIdx.x * hidden_size + idx];
float x = (float)z;
variance += x * x;
@ -129,7 +135,7 @@ fused_add_rms_norm_kernel(
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * hidden_size + idx] =
input[blockIdx.x * input_stride + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
}
@ -141,11 +147,12 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(input.stride(-1) == 1);
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
int64_t input_stride = input.stride(-2);
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
@ -153,26 +160,29 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), epsilon, \
num_tokens, hidden_size); \
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>( \
input.data_ptr<scalar_t>(), input_stride, \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
epsilon, num_tokens, hidden_size); \
});
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
int64_t input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
@ -194,9 +204,16 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
constexpr int vector_width = 8;
constexpr int req_alignment_bytes =
vector_width * 2; // vector_width * sizeof(bfloat16 or float16) (float32
// falls back to non-vectorized version anyway)
bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 &&
res_ptr % req_alignment_bytes == 0 &&
wt_ptr % req_alignment_bytes == 0;
bool offsets_are_multiple_of_vector_width =
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);

View File

@ -23,8 +23,9 @@ namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, typename fp8_type>
__global__ void rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
fp8_type* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const int input_stride,
const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
const float epsilon, const int num_tokens, const int hidden_size) {
@ -32,7 +33,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
const float x = (float)input[blockIdx.x * input_stride + idx];
variance += x * x;
}
@ -49,7 +50,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
float const scale_inv = 1.0f / *scale;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float x = (float)input[blockIdx.x * input_stride + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
@ -63,8 +64,9 @@ __global__ void rms_norm_static_fp8_quant_kernel(
template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
const int input_stride,
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
@ -74,6 +76,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
const int vec_hidden_size = hidden_size / width;
const int vec_input_stride = input_stride / width;
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
@ -87,8 +90,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int stride_id = blockIdx.x * vec_input_stride + idx;
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = input_v[id];
_f16Vec<scalar_t, width> temp = input_v[stride_id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
@ -125,8 +129,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
const int input_stride,
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
@ -135,7 +140,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx];
scalar_t z = input[blockIdx.x * input_stride + idx];
z += residual[blockIdx.x * hidden_size + idx];
float x = (float)z;
variance += x * x;
@ -169,7 +174,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, // [1]
double epsilon) {
TORCH_CHECK(out.is_contiguous());
int hidden_size = input.size(-1);
int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
@ -183,8 +190,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
epsilon, num_tokens, hidden_size);
input_stride, weight.data_ptr<scalar_t>(),
scale.data_ptr<float>(), epsilon, num_tokens,
hidden_size);
});
});
}
@ -198,7 +206,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
width, fp8_t> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
input_stride, residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
epsilon, num_tokens, hidden_size); \
}); \
@ -210,7 +218,10 @@ void fused_add_rms_norm_static_fp8_quant(
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, // [1]
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(residual.is_contiguous());
int hidden_size = input.size(-1);
int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
@ -234,7 +245,7 @@ void fused_add_rms_norm_static_fp8_quant(
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);

View File

@ -1,6 +1,7 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cub/cub.cuh>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
@ -19,9 +20,14 @@ __global__ void moe_align_block_size_kernel(
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
size_t numel, int32_t* __restrict__ cumsum) {
size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded) {
extern __shared__ int32_t shared_counts[];
// Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
sorted_token_ids[it] = numel;
}
const int warp_id = threadIdx.x / WARP_SIZE;
const int my_expert_start = warp_id * experts_per_warp;
@ -45,18 +51,27 @@ __global__ void moe_align_block_size_kernel(
__syncthreads();
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
int expert_count = 0;
int warp_idx = (i - 1) / experts_per_warp;
int expert_offset = (i - 1) % experts_per_warp;
expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
// Compute prefix sum over token counts per expert
using BlockScan = cub::BlockScan<int32_t, 1024>;
__shared__ typename BlockScan::TempStorage temp_storage;
cumsum[i] =
cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
}
*total_tokens_post_pad = cumsum[num_experts];
int expert_count = 0;
int expert_id = threadIdx.x;
if (expert_id < num_experts) {
int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp;
expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
expert_count = CEILDIV(expert_count, block_size) * block_size;
}
int cumsum_val;
BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val);
if (expert_id <= num_experts) {
cumsum[expert_id] = cumsum_val;
}
if (expert_id == num_experts) {
*total_tokens_post_pad = cumsum_val;
}
__syncthreads();
@ -67,6 +82,13 @@ __global__ void moe_align_block_size_kernel(
expert_ids[i / block_size] = threadIdx.x;
}
}
// Fill remaining expert_ids with 0
const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
expert_ids[i] = 0;
}
}
template <typename scalar_t>
@ -105,7 +127,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
int32_t block_size, size_t numel) {
int32_t block_size, size_t numel, int32_t max_num_tokens_padded) {
// Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
sorted_token_ids[it] = numel;
}
const size_t tid = threadIdx.x;
const size_t stride = blockDim.x;
@ -153,6 +180,13 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
}
}
// Fill remaining expert_ids with 0
const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
expert_ids[i] = 0;
}
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i];
int32_t rank_post_pad =
@ -179,13 +213,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int threads = 1024;
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
// BlockScan uses 1024 threads and assigns one thread per expert.
TORCH_CHECK(padded_num_experts < 1024,
"padded_num_experts must be less than 1024");
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `cumsum` tensors
auto options_int =
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
torch::Tensor cumsum_buffer =
torch::zeros({num_experts + 1}, options_int);
torch::empty({num_experts + 1}, options_int);
bool small_batch_expert_mode =
(topk_ids.numel() < 1024) && (num_experts <= 64);
@ -203,7 +241,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel());
topk_ids.numel(), sorted_token_ids.size(0));
} else {
auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
@ -217,7 +255,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts,
padded_num_experts, experts_per_warp, block_size,
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(),
sorted_token_ids.size(0));
const int block_threads = std::min(256, (int)threads);
const int num_blocks =

View File

@ -160,30 +160,6 @@ __global__ void shuffleInputRowsKernel(const T* input,
}
}
// Fallback row-shuffle kernel for tensors whose row width is not 128-bit
// aligned: one thread block copies one destination row, with each thread
// striding over the columns doing plain (unvectorized) element copies.
//
// input:       [num_src_rows, num_cols] source matrix
// dst2src_map: [num_dst_rows] source-row index for every destination row
//              (rows may be duplicated and/or permuted)
// output:      [num_dst_rows, num_cols] destination matrix
template <typename T>
__global__ void shuffleInputRowsKernelSlow(const T* input,
                                           const int32_t* dst2src_map,
                                           T* output, int64_t num_src_rows,
                                           int64_t num_dst_rows,
                                           int64_t num_cols) {
  int64_t const dest_row_idx = blockIdx.x;
  if (dest_row_idx < num_dst_rows) {
    // Read the mapping only after the bounds check: the previous version
    // indexed dst2src_map unconditionally, which reads out of bounds
    // whenever the grid is launched with more blocks than destination rows.
    int64_t const source_row_idx = dst2src_map[dest_row_idx];

    // Duplicate and permute rows.
    auto const* source_row_ptr = input + source_row_idx * num_cols;
    auto* dest_row_ptr = output + dest_row_idx * num_cols;

    int64_t const start_offset = threadIdx.x;
    int64_t const stride = blockDim.x;

    // int64_t index avoids overflow for very wide rows (num_cols is int64_t).
    for (int64_t elem_index = start_offset; elem_index < num_cols;
         elem_index += stride) {
      dest_row_ptr[elem_index] = source_row_ptr[elem_index];
    }
  }
}
void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor) {
@ -197,24 +173,17 @@ void shuffle_rows(const torch::Tensor& input_tensor,
int64_t const num_src_rows = input_tensor.size(0);
int64_t const num_cols = input_tensor.size(1);
if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) {
// use slow kernel if num_cols can't be aligned to 128 bits
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
} else {
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
}
TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
"num_cols must be divisible by 128 / "
"sizeof(input_tensor.scalar_type()) / 8");
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
}
#else

View File

@ -287,6 +287,11 @@ void scaled_fp4_experts_quant(
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
void per_token_group_quant_fp8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s,
int64_t group_size, double eps, double fp8_min,
double fp8_max, bool scale_ue8m0);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,

View File

@ -18,7 +18,6 @@ using ProblemShape =
cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;
using ElementAccumulator = float;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using LayoutA = cutlass::layout::RowMajor;
@ -33,7 +32,7 @@ using LayoutD_Transpose =
using LayoutC = LayoutD;
using LayoutC_Transpose = LayoutD_Transpose;
template <typename ElementAB_, typename ElementC_,
template <typename ElementAB_, typename ElementC_, typename ArchTag_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule, bool swap_ab_ = false>
@ -43,6 +42,7 @@ struct cutlass_3x_group_gemm {
using ElementC = void;
using ElementD = ElementC_;
using ElementAccumulator = float;
using ArchTag = ArchTag_;
using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;
@ -77,7 +77,7 @@ struct cutlass_3x_group_gemm {
LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
Stages, KernelSchedule>::CollectiveOp>;
using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;
struct GemmKernel : public KernelType {};
@ -156,9 +156,14 @@ void cutlass_group_gemm_caller(
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides.data_ptr())};
int device_id = a_tensors.device().index();
static const cutlass::KernelHardwareInfo hw_info{
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
device_id)};
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
epilogue_args};
epilogue_args, hw_info};
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
GemmOp gemm_op;

View File

@ -0,0 +1,140 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"
using namespace cute;
namespace {
// Default SM100 (Blackwell) FP8 grouped-GEMM tile configuration, used by
// run_cutlass_moe_mm_sm100 when neither the small-M nor the large-N special
// case applies. Only float_e4m3_t inputs are supported (enforced below).
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  // 1-SM ptr-array TMA warp-specialized kernel/epilogue schedules.
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
  // CTA tile (M, N, K) and cluster shape.
  using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm100;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule>;
};
// SM100 FP8 grouped-GEMM configuration specialized for small batches.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M64 {
  // M in [1,64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
  // Narrow N tile (16) for the small-M regime.
  using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm100;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule,
                            // Final `true` enables the swap_ab_ template
                            // parameter of cutlass_3x_group_gemm.
                            true>;
};
// SM100 FP8 grouped-GEMM configuration specialized for very wide outputs,
// using the 2-SM schedules and a 2x1x1 cluster.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_N8192 {
  // N in [8192, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
  using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
  using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm100;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule>;
};
// Validate the inputs and dispatch the SM100 FP8 grouped (MoE) GEMM to the
// tile configuration best matching the problem shape:
//   m <= 64    -> Cutlass3xGemmM64    (small-batch, swapped A/B)
//   n >= 8192  -> Cutlass3xGemmN8192  (wide-output, 2-SM schedule)
//   otherwise  -> Cutlass3xGemmDefault
// Raises (via TORCH_CHECK) on empty inputs or non-FP8 A/B tensors.
template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm100(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
  TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
  TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");

  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
              "A tensors must be of type float8_e4m3fn.");
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
              "B tensors must be of type float8_e4m3fn.");

  using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmN8192 = typename sm100_fp8_config_N8192<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmM64 = typename sm100_fp8_config_M64<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;

  // Problem extents used for config selection: total rows of A and the
  // output's second dimension.
  uint32_t const m = a_tensors.size(0);
  uint32_t const n = out_tensors.size(1);

  if (m <= 64) {
    cutlass_group_gemm_caller<Cutlass3xGemmM64>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else if (n >= 8192) {
    cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else {
    cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  }
}
} // namespace
// Pick the 16-bit output element type for the SM100 grouped GEMM based on
// the output tensor's dtype and forward all other arguments unchanged.
// Inputs are always float_e4m3_t; bf16 outputs use cutlass::bfloat16_t,
// everything else uses cutlass::half_t.
void dispatch_moe_mm_sm100(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  bool const out_is_bf16 = out_tensors.dtype() == torch::kBFloat16;
  if (out_is_bf16) {
    run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
    return;
  }
  // Non-bf16 outputs fall back to fp16.
  run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::half_t>(
      out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
      problem_sizes, a_strides, b_strides, c_strides, per_act_token,
      per_out_ch);
}
// Public SM100 MoE GEMM entry point. A thin shim: all work (dtype selection
// and kernel launch) happens in dispatch_moe_mm_sm100.
void cutlass_moe_mm_sm100(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales,
                        b_scales, expert_offsets, problem_sizes, a_strides,
                        b_strides, c_strides, per_act_token, per_out_ch);
}

View File

@ -21,10 +21,11 @@ struct sm90_fp8_config_default {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
@ -38,10 +39,12 @@ struct sm90_fp8_config_M4 {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, true>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule,
true>;
};
template <typename InType, typename OutType,
@ -55,10 +58,12 @@ struct sm90_fp8_config_M64 {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_256>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, true>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule,
true>;
};
template <typename InType, typename OutType,
@ -72,10 +77,11 @@ struct sm90_fp8_config_K8192 {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
@ -89,10 +95,11 @@ struct sm90_fp8_config_N8192 {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType>
@ -112,9 +119,6 @@ void run_cutlass_moe_mm_sm90(
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<

View File

@ -190,4 +190,4 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k);
}
}

View File

@ -41,6 +41,16 @@ void cutlass_moe_mm_sm90(
#endif
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
void cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
@ -130,10 +140,10 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
// and at least SM90 (Hopper)
#if defined CUDA_VERSION
if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 100) {
if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
} else if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
}
#endif
@ -141,11 +151,14 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
}
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
// CUTLASS grouped FP8 kernels need at least CUDA 12.3
// and SM90 (Hopper)
// CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
// or CUDA 12.8 and SM100 (Blackwell)
#if defined CUDA_VERSION
if (cuda_device_capability == 90) {
if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
}
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12030;
}
#endif
@ -234,16 +247,26 @@ void cutlass_moe_mm(
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
if (version_num >= 100) {
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
if (version_num >= 90) {
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
". Required capability: 90");
". Required capability: 90 or 100");
}
void get_cutlass_moe_mm_data(

View File

@ -88,6 +88,8 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int const block_size = 256;
int const num_tokens = input.numel() / input.size(-1);
int const num_elems = input.numel();
@ -111,6 +113,8 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int const block_size = 256;
int const num_tokens = input.numel() / input.size(-1);
int const num_elems = input.numel();

View File

@ -0,0 +1,213 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <cmath>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <torch/all.h>
#include "../vectorization.cuh"
#include "../vectorization_utils.cuh"
#include "../../dispatch_utils.h"
// Butterfly (xor-shuffle) max-reduction over a 16-lane quantization group:
// after the four exchanges every participating lane holds the group maximum.
// The `tid` parameter is accepted but unused in this implementation.
// NOTE(review): the participation mask is hard-coded to 0xffff (lanes 0-15),
// yet groups occupying the upper half of a warp shuffle among lanes 16-31,
// which are outside this mask — confirm this is well-defined for the launch
// configurations used by the caller.
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
  unsigned mask = 0xffff;
  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
  return val;
}
// Quantize contiguous groups of `group_size` elements to an 8-bit type.
// Each group is processed by 16 cooperating threads (`threads_per_group`),
// and each block handles `groups_per_block` groups. Per group the kernel:
//   1. stages the group's input into shared memory while computing the
//      absolute maximum,
//   2. reduces the absmax across the 16 lanes and derives the scale
//      y_s = absmax / max_8bit (optionally rounded up to a power of two
//      when SCALE_UE8M0 is set),
//   3. writes the scale, then re-reads the staged values from shared memory,
//      divides by y_s, clamps to [min_8bit, max_8bit] and stores as DST_DTYPE.
//
// IS_COLUMN_MAJOR selects a transposed, packed layout for the scale output
// (scale_num_rows/scale_stride describe that layout); otherwise scales are
// written linearly, one per group.
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
          bool SCALE_UE8M0 = false, typename scale_packed_t = float>
__global__ void per_token_group_quant_8bit_kernel(
    const T* __restrict__ input, void* __restrict__ output_q,
    scale_packed_t* __restrict__ output_s, const int group_size,
    const int num_groups, const int groups_per_block, const float eps,
    const float min_8bit, const float max_8bit, const int scale_num_rows = 0,
    const int scale_stride = 0) {
  const int threads_per_group = 16;
  const int64_t local_group_id = threadIdx.x / threads_per_group;
  const int lane_id = threadIdx.x % threads_per_group;
  const int64_t block_group_id = blockIdx.x * groups_per_block;
  const int64_t global_group_id = block_group_id + local_group_id;
  const int64_t block_group_offset = global_group_id * group_size;
  // Seed the absmax with eps so the scale never collapses to zero.
  float local_absmax = eps;
  using scale_element_t = float;
  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
  const T* group_input = input + block_group_offset;
  DST_DTYPE* group_output =
      static_cast<DST_DTYPE*>(output_q) + block_group_offset;
  scale_element_t* scale_output;
  if constexpr (IS_COLUMN_MAJOR) {
    // Map the linear group id into the packed column-major scale layout:
    // several scale_element_t values are packed per scale_packed_t slot.
    const int num_elems_per_pack =
        static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
    const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
    const int row_idx = global_group_id / scale_num_rows_element;
    const int col_idx_raw = global_group_id % scale_num_rows_element;
    const int col_idx = col_idx_raw / num_elems_per_pack;
    const int pack_idx = col_idx_raw % num_elems_per_pack;
    scale_output = reinterpret_cast<scale_element_t*>(output_s) +
                   (col_idx * scale_stride * num_elems_per_pack +
                    row_idx * num_elems_per_pack + pack_idx);
  } else {
    // Row-major: one scale per group, written linearly.
    scale_output = output_s + global_group_id;
  }
  // shared memory to cache each group's data to avoid double DRAM reads.
  extern __shared__ __align__(16) char smem_raw[];
  T* smem = reinterpret_cast<T*>(smem_raw);
  T* smem_group = smem + local_group_id * group_size;
  // 128-bit vectorized accesses where alignment permits.
  constexpr int vec_size = 16 / sizeof(T);
  using vec_t = vllm::vec_n_t<T, vec_size>;
  // copy global -> shared & compute absmax
  auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
    float abs_v = fabsf(static_cast<float>(src));
    local_absmax = fmaxf(local_absmax, abs_v);
    dst = src;
  };
  vllm::vectorize_with_alignment<vec_size>(
      group_input,        // in
      smem_group,         // out (shared)
      group_size,         // elements per group
      lane_id,            // thread id
      threads_per_group,  // stride in group
      scalar_op_cache);   // scalar handler
  // Reduce per-lane absmax to a group-wide absmax.
  local_absmax = GroupReduceMax(local_absmax, lane_id);
  float y_s = local_absmax / max_8bit;
  if constexpr (SCALE_UE8M0) {
    // Round the scale up to the nearest power of two (UE8M0 encoding).
    y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
  }
  scale_element_t y_s_quant = y_s;
  // Only one lane per group writes the scale.
  if (lane_id == 0) {
    *scale_output = y_s_quant;
  }
  // Ensure the shared-memory staging (and y_s) is complete before quantizing.
  __syncthreads();
  // quantize shared -> global 8-bit
  auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
    float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
    dst = DST_DTYPE(q);
  };
  vllm::vectorize_with_alignment<vec_size>(
      smem_group,         // in (shared)
      group_output,       // out (global quant tensor)
      group_size,         // elements
      lane_id,            // tid
      threads_per_group,  // stride
      scalar_op_quant);   // scalar handler
}
// Host launcher for per-group 8-bit quantization.
//
// input:    contiguous tensor whose element count is a multiple of group_size
// output_q: contiguous quantized output (currently only Float8_e4m3fn)
// output_s: 2-D per-group scale tensor; a column-major layout (detected from
//           its strides) selects the packed transposed scale path
// eps / min_8bit / max_8bit: scale floor and clamp range for the target type
// scale_ue8m0: round scales up to powers of two when true
void per_token_group_quant_8bit(const torch::Tensor& input,
                                torch::Tensor& output_q,
                                torch::Tensor& output_s, int64_t group_size,
                                double eps, double min_8bit, double max_8bit,
                                bool scale_ue8m0 = false) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(output_q.is_contiguous());
  const int num_groups = input.numel() / group_size;
  TORCH_CHECK(input.numel() % group_size == 0);
  TORCH_CHECK(output_s.dim() == 2);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  constexpr int THREADS_PER_GROUP = 16;
  // Pack as many whole groups per block as divisibility allows (max 16),
  // so num_blocks = num_groups / groups_per_block is always exact.
  int groups_per_block = 1;
  if (num_groups % 16 == 0) {
    groups_per_block = 16;
  } else if (num_groups % 8 == 0) {
    groups_per_block = 8;
  } else if (num_groups % 4 == 0) {
    groups_per_block = 4;
  } else if (num_groups % 2 == 0) {
    groups_per_block = 2;
  }
  auto dst_type = output_q.scalar_type();
  const int num_blocks = num_groups / groups_per_block;
  const int num_threads = groups_per_block * THREADS_PER_GROUP;
  // stride(0) < stride(1) means output_s is laid out column-major; the
  // kernel then uses the packed transposed scale indexing.
  const bool is_column_major = output_s.stride(0) < output_s.stride(1);
  const int scale_num_rows = output_s.size(1);
  const int scale_stride = output_s.stride(1);
// Launch helper: selects the (column-major x ue8m0) template instantiation
// and sizes dynamic shared memory to hold one input group per group-slot.
#define LAUNCH_KERNEL(T, DST_DTYPE)                                        \
  do {                                                                     \
    dim3 grid(num_blocks);                                                 \
    dim3 block(num_threads);                                               \
    size_t smem_bytes =                                                    \
        static_cast<size_t>(groups_per_block) * group_size * sizeof(T);    \
    if (is_column_major) {                                                 \
      if (scale_ue8m0) {                                                   \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true>        \
            <<<grid, block, smem_bytes, stream>>>(                         \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),    \
                static_cast<float*>(output_s.data_ptr()), group_size,      \
                num_groups, groups_per_block, (float)eps, (float)min_8bit, \
                (float)max_8bit, scale_num_rows, scale_stride);            \
      } else {                                                             \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false>       \
            <<<grid, block, smem_bytes, stream>>>(                         \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),    \
                static_cast<float*>(output_s.data_ptr()), group_size,      \
                num_groups, groups_per_block, (float)eps, (float)min_8bit, \
                (float)max_8bit, scale_num_rows, scale_stride);            \
      }                                                                    \
    } else {                                                               \
      if (scale_ue8m0) {                                                   \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, true>       \
            <<<grid, block, smem_bytes, stream>>>(                         \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),    \
                static_cast<float*>(output_s.data_ptr()), group_size,      \
                num_groups, groups_per_block, (float)eps, (float)min_8bit, \
                (float)max_8bit);                                          \
      } else {                                                             \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, false>      \
            <<<grid, block, smem_bytes, stream>>>(                         \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),    \
                static_cast<float*>(output_s.data_ptr()), group_size,      \
                num_groups, groups_per_block, (float)eps, (float)min_8bit, \
                (float)max_8bit);                                          \
      }                                                                    \
    }                                                                      \
  } while (0)
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "per_token_group_quant_8bit", ([&] {
        // Only fp8-e4m3 destinations are handled today; other dst types
        // silently launch nothing.
        if (dst_type == at::ScalarType::Float8_e4m3fn) {
          LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
        }
      }));
#undef LAUNCH_KERNEL
}
// FP8 entry point registered with torch: a thin shim over the generic
// 8-bit group-quantization routine, forwarding the fp8 value range and
// the UE8M0 scale flag straight through.
void per_token_group_quant_fp8(const torch::Tensor& input,
                               torch::Tensor& output_q, torch::Tensor& output_s,
                               int64_t group_size, double eps, double fp8_min,
                               double fp8_max, bool scale_ue8m0) {
  per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
                             fp8_min, fp8_max, scale_ue8m0);
}

View File

@ -615,6 +615,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
#ifndef USE_ROCM
// Compute per-token-group FP8 quantized tensor and scaling factor.
ops.def(
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
"output_s, "
"int group_size, float eps, float fp8_min, float fp8_max, bool "
"scale_ue8m0) -> ()");
ops.impl("per_token_group_fp8_quant", torch::kCUDA,
&per_token_group_quant_fp8);
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops.def(
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "

View File

@ -1,10 +1,10 @@
# Tool Calling
vLLM currently supports named function calling, as well as the `auto`, `required` (as of `vllm>=0.8.3`) and `none` options for the `tool_choice` field in the chat completion API.
vLLM currently supports named function calling, as well as the `auto`, `required` (as of `vllm>=0.8.3`), and `none` options for the `tool_choice` field in the chat completion API.
## Quickstart
Start the server with tool calling enabled. This example uses Meta's Llama 3.1 8B model, so we need to use the llama3 tool calling chat template from the vLLM examples directory:
Start the server with tool calling enabled. This example uses Meta's Llama 3.1 8B model, so we need to use the `llama3_json` tool calling chat template from the vLLM examples directory:
```bash
vllm serve meta-llama/Llama-3.1-8B-Instruct \
@ -13,7 +13,7 @@ vllm serve meta-llama/Llama-3.1-8B-Instruct \
--chat-template examples/tool_chat_template_llama3.1_json.jinja
```
Next, make a request to the model that should result in it using the available tools:
Next, make a request that triggers the model to use the available tools:
??? code
@ -73,7 +73,7 @@ This example demonstrates:
You can also specify a particular function using named function calling by setting `tool_choice={"type": "function", "function": {"name": "get_weather"}}`. Note that this will use the guided decoding backend - so the first time this is used, there will be several seconds of latency (or more) as the FSM is compiled for the first time before it is cached for subsequent requests.
Remember that it's the callers responsibility to:
Remember that it's the caller's responsibility to:
1. Define appropriate tools in the request
2. Include relevant context in the chat messages
@ -84,7 +84,7 @@ For more advanced usage, including parallel tool calls and different model-speci
## Named Function Calling
vLLM supports named function calling in the chat completion API by default. It does so using Outlines through guided decoding, so this is
enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a
enabled by default and will work with any supported model. You are guaranteed a validly-parsable function call - not a
high-quality one.
vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.
@ -95,7 +95,7 @@ specify the `name` of one of the tools in the `tool_choice` parameter of the cha
## Required Function Calling
vLLM supports the `tool_choice='required'` option in the chat completion API. Similar to the named function calling, it also uses guided decoding, so this is enabled by default and will work with any supported model. The required guided decoding features (JSON schema with `anyOf`) are currently only supported in the V0 engine with the guided decoding backend `outlines`. However, support for alternative decoding backends are on the [roadmap](../usage/v1_guide.md#features) for the V1 engine.
vLLM supports the `tool_choice='required'` option in the chat completion API. Similar to the named function calling, it also uses guided decoding, so this is enabled by default and will work with any supported model. The guided decoding features for `tool_choice='required'` (such as JSON schema with `anyOf`) are currently only supported in the V0 engine with the guided decoding backend `outlines`. However, support for alternative decoding backends is on the [roadmap](../usage/v1_guide.md#features) for the V1 engine.
When tool_choice='required' is set, the model is guaranteed to generate one or more tool calls based on the specified tool list in the `tools` parameter. The number of tool calls depends on the user's query. The output format strictly follows the schema defined in the `tools` parameter.
@ -109,16 +109,16 @@ However, when `tool_choice='none'` is specified, vLLM includes tool definitions
To enable this feature, you should set the following flags:
* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it
* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. It tells vLLM that you want to enable the model to generate its own tool calls when it
deems appropriate.
* `--tool-call-parser` -- select the tool parser to use (listed below). Additional tool parsers
will continue to be added in the future, and also can register your own tool parsers in the `--tool-parser-plugin`.
will continue to be added in the future. You can also register your own tool parsers in the `--tool-parser-plugin`.
* `--tool-parser-plugin` -- **optional** tool parser plugin used to register user-defined tool parsers into vLLM; the registered tool parser name can be specified in `--tool-call-parser`.
* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages
* `--chat-template` -- **optional** for auto tool choice. It's the path to the chat template which handles `tool`-role messages and `assistant`-role messages
that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their
`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat
template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates)
from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json)
from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json).
If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template!
@ -130,7 +130,7 @@ All Nous Research Hermes-series models newer than Hermes 2 Pro should be support
* `NousResearch/Hermes-2-Theta-*`
* `NousResearch/Hermes-3-*`
_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge
_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality and capabilities due to the merge
step in their creation_.
Flags: `--tool-call-parser hermes`
@ -146,13 +146,13 @@ Known issues:
1. Mistral 7B struggles to generate parallel tool calls correctly.
2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
much shorter than what vLLM generates. Since an exception is thrown when this condition
is not met, the following additional chat templates are provided:
much shorter than what vLLM generates. Since an exception is thrown when this condition
is not met, the following additional chat templates are provided:
* <gh-file:examples/tool_chat_template_mistral.jinja> - this is the "official" Mistral chat template, but tweaked so that
it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits)
* <gh-file:examples/tool_chat_template_mistral_parallel.jinja> - this is a "better" version that adds a tool-use system prompt
when tools are provided, that results in much better reliability when working with parallel tool calling.
* <gh-file:examples/tool_chat_template_mistral.jinja> - this is the "official" Mistral chat template, but tweaked so that
it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits)
* <gh-file:examples/tool_chat_template_mistral_parallel.jinja> - this is a "better" version that adds a tool-use system prompt
when tools are provided, that results in much better reliability when working with parallel tool calling.
Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
@ -166,17 +166,17 @@ All Llama 3.1, 3.2 and 4 models should be supported.
* `meta-llama/Llama-3.2-*`
* `meta-llama/Llama-4-*`
The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. As for llama 4 models, it is recommended to use the `llama4_pythonic` tool parser.
The tool calling that is supported is the [JSON-based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. As for Llama 4 models, it is recommended to use the `llama4_pythonic` tool parser.
Other tool calling formats, like the built-in Python tool calling or custom tool calling, are not supported.
Known issues:
1. Parallel tool calls are not supported for llama 3, but it is supported in llama 4 models.
2. The model can generate parameters with a wrong format, such as generating
1. Parallel tool calls are not supported for Llama 3, but it is supported in Llama 4 models.
2. The model can generate parameters in an incorrect format, such as generating
an array serialized as string instead of an array.
VLLM provides two JSON based chat templates for Llama 3.1 and 3.2:
VLLM provides two JSON-based chat templates for Llama 3.1 and 3.2:
* <gh-file:examples/tool_chat_template_llama3.1_json.jinja> - this is the "official" chat template for the Llama 3.1
models, but tweaked so that it works better with vLLM.
@ -185,7 +185,8 @@ images.
Recommended flags: `--tool-call-parser llama3_json --chat-template {see_above}`
VLLM also provides a pythonic and JSON based chat template for Llama 4, but pythonic tool calling is recommended:
VLLM also provides a pythonic and JSON-based chat template for Llama 4, but pythonic tool calling is recommended:
* <gh-file:examples/tool_chat_template_llama4_pythonic.jinja> - this is based on the [official chat template](https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/) for the Llama 4 models.
For Llama 4 model, use `--tool-call-parser llama4_pythonic --chat-template examples/tool_chat_template_llama4_pythonic.jinja`.
@ -196,21 +197,21 @@ Supported models:
* `ibm-granite/granite-3.0-8b-instruct`
Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja`
Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja`
<gh-file:examples/tool_chat_template_granite.jinja>: this is a modified chat template from the original on Huggingface. Parallel function calls are supported.
<gh-file:examples/tool_chat_template_granite.jinja>: this is a modified chat template from the original on Hugging Face. Parallel function calls are supported.
* `ibm-granite/granite-3.1-8b-instruct`
Recommended flags: `--tool-call-parser granite`
Recommended flags: `--tool-call-parser granite`
The chat template from Huggingface can be used directly. Parallel function calls are supported.
The chat template from Hugging Face can be used directly. Parallel function calls are supported.
* `ibm-granite/granite-20b-functioncalling`
Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja`
Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja`
<gh-file:examples/tool_chat_template_granite_20b_fc.jinja>: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported.
<gh-file:examples/tool_chat_template_granite_20b_fc.jinja>: this is a modified chat template from the original on Hugging Face, which is not vLLM-compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported.
### InternLM Models (`internlm`)
@ -246,10 +247,12 @@ The xLAM tool parser is designed to support models that generate tool calls in v
Parallel function calls are supported, and the parser can effectively separate text content from tool calls.
Supported models:
* Salesforce Llama-xLAM models: `Salesforce/Llama-xLAM-2-8B-fc-r`, `Salesforce/Llama-xLAM-2-70B-fc-r`
* Qwen-xLAM models: `Salesforce/xLAM-1B-fc-r`, `Salesforce/xLAM-3B-fc-r`, `Salesforce/Qwen-xLAM-32B-fc-r`
Flags:
* For Llama-based xLAM models: `--tool-call-parser xlam --chat-template examples/tool_chat_template_xlam_llama.jinja`
* For Qwen-based xLAM models: `--tool-call-parser xlam --chat-template examples/tool_chat_template_xlam_qwen.jinja`
@ -292,9 +295,10 @@ Flags: `--tool-call-parser kimi_k2`
Supported models:
* `tencent/Hunyuan-A13B-Instruct` (chat template already included huggingface model file.)
* `tencent/Hunyuan-A13B-Instruct` (The chat template is already included in the Hugging Face model files.)
Flags:
* For non-reasoning: `--tool-call-parser hunyuan_a13b`
* For reasoning: `--tool-call-parser hunyuan_a13b --reasoning-parser hunyuan_a13b --enable_reasoning`
@ -325,9 +329,9 @@ Example supported models:
Flags: `--tool-call-parser pythonic --chat-template {see_above}`
!!! warning
Llama's smaller models frequently fail to emit tool calls in the correct format. Your mileage may vary.
Llama's smaller models frequently fail to emit tool calls in the correct format. Results may vary depending on the model.
## How to write a tool parser plugin
## How to Write a Tool Parser Plugin
A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in <gh-file:vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py>.

View File

@ -168,17 +168,18 @@ Note, it is recommended to manually reserve 1 CPU for vLLM front-end process whe
### How to do performance tuning for vLLM CPU?
- First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`.
First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`.
- Inference batch size is a important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM:
- `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as:
- Offline Inference: `4096 * world_size`
- Online Serving: `2048 * world_size`
- `--max-num-seqs`, defines the limit of sequence numbers in a single batch, has more impacts on the output token performance.
- Offline Inference: `256 * world_size`
- Online Serving: `128 * world_size`
Inference batch size is an important parameter for performance. A larger batch usually provides higher throughput, while a smaller batch provides lower latency. Tuning the max batch size, starting from the default value, to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM:
- vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more detials of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP togther if there are enough CPU sockets and memory nodes.
- `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as:
- Offline Inference: `4096 * world_size`
- Online Serving: `2048 * world_size`
- `--max-num-seqs`, defines the limit of sequence numbers in a single batch, has more impacts on the output token performance.
- Offline Inference: `256 * world_size`
- Online Serving: `128 * world_size`
vLLM CPU supports tensor parallelism (TP) and pipeline parallelism (PP) to leverage multiple CPU sockets and memory nodes. For more details on tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommended to use TP and PP together if there are enough CPU sockets and memory nodes.
### Which quantization configs does vLLM CPU support?

View File

@ -18,7 +18,7 @@ These models are what we list in [supported-text-models][supported-text-models]
### Transformers
vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! Vision-language models currently accept only image inputs, and require setting `--disable_mm_preprocessor_cache` when running. Support for video inputs and caching of multi-modal preprocessors will be added in future releases.
vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! Vision-language models currently accept only image inputs. Support for video inputs will be added in future releases.
To check if the modeling backend is Transformers, you can simply do this:
@ -324,6 +324,7 @@ th {
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ |
| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ |

View File

@ -37,7 +37,6 @@ def initialize_llm():
max_num_seqs=4,
max_model_len=2048,
block_size=2048,
use_v2_block_manager=True,
device="neuron",
tensor_parallel_size=32,
)

View File

@ -659,7 +659,8 @@ setup(
"bench": ["pandas", "datasets"],
"tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
"runai":
["runai-model-streamer >= 0.13.3", "runai-model-streamer-s3", "boto3"],
"audio": ["librosa", "soundfile",
"mistral_common[audio]"], # Required for audio processing
"video": [] # Kept for backwards compatibility

View File

@ -177,7 +177,7 @@ TEXT_GENERATION_MODELS = {
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
# Tests TransformersForCausalLM
"ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
"hmellor/Ilama-3.2-1B": PPTestSettings.fast(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(),
# Uses Llama
@ -249,7 +249,7 @@ TEST_MODELS = [
# [LANGUAGE GENERATION]
"microsoft/Phi-3.5-MoE-instruct",
"meta-llama/Llama-3.2-1B-Instruct",
"ArthurZ/Ilama-3.2-1B",
"hmellor/Ilama-3.2-1B",
"ibm/PowerLM-3b",
"deepseek-ai/DeepSeek-V2-Lite-Chat",
# [LANGUAGE EMBEDDING]

View File

@ -26,6 +26,7 @@ CUDA_DEVICES = [
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("strided_input", [False, True])
@torch.inference_mode()
def test_rms_norm(
num_tokens: int,
@ -34,13 +35,17 @@ def test_rms_norm(
dtype: torch.dtype,
seed: int,
device: str,
strided_input: bool,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
last_dim = 2 * hidden_size if strided_input else hidden_size
x = torch.randn(num_tokens, last_dim, dtype=dtype)
x = x[..., :hidden_size]
assert x.is_contiguous() != strided_input
x *= scale
residual = torch.randn_like(x) * scale if add_residual else None
@ -72,6 +77,7 @@ def test_rms_norm(
@pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("strided_input", [False, True])
def test_fused_rms_norm_quant(
num_tokens: int,
hidden_size: int,
@ -80,13 +86,18 @@ def test_fused_rms_norm_quant(
quant_scale: float,
seed: int,
device: str,
strided_input: bool,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
last_dim = 2 * hidden_size if strided_input else hidden_size
x_base = torch.randn(num_tokens, last_dim, dtype=dtype)
x = x_base[..., :hidden_size]
assert x.is_contiguous() != strided_input
x *= scale
if add_residual:
residual = torch.randn_like(x) * scale
@ -106,9 +117,11 @@ def test_fused_rms_norm_quant(
# Unfused kernel is in-place so it goes second
# Also use a separate clone of x to avoid modifying the input
x_unfused = x.clone()
x_unfused_base = x_base.clone()
x_unfused = x_unfused_base[..., :hidden_size]
assert x_unfused.is_contiguous() != strided_input
torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused,
torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused.contiguous(),
quant_scale_t)
torch.cuda.synchronize()
@ -116,7 +129,6 @@ def test_fused_rms_norm_quant(
residual,
atol=1e-2,
rtol=1e-2)
opcheck(
torch.ops._C.fused_add_rms_norm_static_fp8_quant,
(out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6))
@ -131,7 +143,7 @@ def test_fused_rms_norm_quant(
opcheck(torch.ops._C.rms_norm_static_fp8_quant,
(out_quant_fused, x, weight, quant_scale_t, 1e-6))
torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32),
out_quant.to(dtype=torch.float32),
torch.testing.assert_close(out_quant.to(dtype=torch.float32),
out_quant_fused.to(dtype=torch.float32),
atol=1e-3,
rtol=1e-3)

View File

@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel(
gate_states[..., local_rank * N:(local_rank + 1) * N],
)
ref_output = mixer_single_gpu(hidden_states, gate_states)
torch.allclose(output,
ref_output[..., local_rank * N:(local_rank + 1) * N],
atol=1e-3,
rtol=1e-3)
torch.testing.assert_close(output,
ref_output[...,
local_rank * N:(local_rank + 1) * N],
atol=5e-3,
rtol=1e-3)

View File

@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
# this tests the kernels on a single example (no batching)
# TODO: the bfloat16 case requires higher thresholds. To be investigated
if itype == torch.bfloat16:
atol, rtol = 5e-2, 5e-2
else:
atol, rtol = 8e-3, 5e-3
# set seed
batch_size = 1 # batch_size
# ssd_minimal_discrete requires chunk_size divide seqlen
@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
return_final_states=True)
# just test the last in sequence
torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3)
torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
# just test the last head
# NOTE, in the kernel we always cast states to fp32
torch.allclose(final_state[:, -1],
final_state_min[:, -1].to(torch.float32),
atol=1e-3,
rtol=1e-3)
torch.testing.assert_close(final_state[:, -1],
final_state_min[:, -1].to(torch.float32),
atol=atol,
rtol=rtol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
# TODO: the irregular chunk size cases have some issues and require higher
# tolerance. This is to be investigated
if chunk_size not in {8, 256}:
atol, rtol = 5e-1, 5e-1
else:
atol, rtol = 5e-3, 5e-3
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken: dict = {} # map: eg -> pointer to last taken sample
@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# just test one dim and dstate
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0]
torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
# update states
states = new_states

View File

@ -207,10 +207,6 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
'ab_strides1': moe_tensors.ab_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides1': moe_tensors.c_strides1,
'c_strides2': moe_tensors.c_strides2,
'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
}
@ -444,11 +440,6 @@ def test_run_cutlass_moe_fp8(
expert_map[start:end] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
torch.float8_e4m3fn,
@ -457,9 +448,8 @@ def test_run_cutlass_moe_fp8(
func = lambda output: run_cutlass_moe_fp8(
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
workspace13, workspace2, None, mt.a.dtype, per_act_token,
per_out_channel, False)
a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
per_act_token, per_out_channel, False)
workspace13.random_()
output_random_workspace = torch.empty(output_shape,

View File

@ -93,11 +93,11 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
a1_gscale=a1_gs,
w1_fp4=w1_q,
w1_blockscale=w1_blockscale,
w1_alphas=(1 / w1_gs),
g1_alphas=(1 / w1_gs),
a2_gscale=a2_gs,
w2_fp4=w2_q,
w2_blockscale=w2_blockscale,
w2_alphas=(1 / w2_gs),
g2_alphas=(1 / w2_gs),
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,

View File

@ -75,7 +75,6 @@ def pplx_cutlass_moe(
assert torch.cuda.current_device() == pgi.local_rank
num_tokens, hidden_dim = a.shape
intermediate_dim = w2.shape[2]
num_experts = w1.shape[0]
block_size = hidden_dim # TODO support more cases
device = pgi.device
@ -124,31 +123,10 @@ def pplx_cutlass_moe(
num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers)
ab_strides1 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
ab_strides2 = torch.full((num_local_experts, ),
intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides1 = torch.full((num_local_experts, ),
2 * intermediate_dim,
device="cuda",
dtype=torch.int64)
c_strides2 = torch.full((num_local_experts, ),
hidden_dim,
device="cuda",
dtype=torch.int64)
experts = CutlassExpertsFp8(num_local_experts,
out_dtype,
per_act_token,
per_out_ch,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
num_dispatchers=num_dispatchers,
use_batched_format=True)

View File

@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch
import pytest
import torch
from vllm.model_executor.layers.quantization.utils import fp8_utils
@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
@pytest.mark.parametrize("column_major", [False, True])
@pytest.mark.parametrize("scale_ue8m0", [False, True])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_per_token_group_quant_fp8(shape, column_major: bool,
                                   scale_ue8m0: bool, group_size: int):
    """Compare the CUDA path of per_token_group_quant_fp8 against the
    reference path taken when the platform does not report CUDA."""
    torch.manual_seed(42)
    num_tokens, hidden_dim = shape
    # Scale the random input up so quantization sees a wider value range.
    x = torch.randn(
        (num_tokens, hidden_dim), device="cuda", dtype=torch.bfloat16) * 8

    quant_kwargs = dict(column_major_scales=column_major,
                        use_ue8m0=scale_ue8m0)

    # CUDA kernel path.
    out_q, scale = fp8_utils.per_token_group_quant_fp8(
        x, group_size, **quant_kwargs)

    # Reference path: patch the platform check so the non-CUDA branch runs.
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
            x, group_size, **quant_kwargs)

    assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
    assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)

View File

@ -9,7 +9,7 @@ from vllm.platforms import current_platform
from ..utils import create_new_process_for_each_test, multi_gpu_test
MODEL_PATH = "ArthurZ/ilama-3.2-1B"
MODEL_PATH = "hmellor/Ilama-3.2-1B"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501

View File

@ -5,7 +5,8 @@ import os
import pytest
from vllm.model_executor.layers.pooler import CLSPool, MeanPool, PoolingType
from vllm.model_executor.layers.pooler import (CLSPool, DispatchPooler,
MeanPool, PoolingType)
from vllm.model_executor.models.bert import BertEmbeddingModel
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
from vllm.platforms import current_platform
@ -49,7 +50,8 @@ def test_model_loading_with_params(vllm_runner):
def check_model(model):
assert isinstance(model, BertEmbeddingModel)
assert isinstance(model.pooler.pooling, CLSPool)
assert isinstance(pooler := model.pooler, DispatchPooler)
assert isinstance(pooler.poolers_by_task["embed"].pooling, CLSPool)
vllm_model.apply_model(check_model)
@ -87,7 +89,9 @@ def test_roberta_model_loading_with_params(vllm_runner):
def check_model(model):
assert isinstance(model, RobertaEmbeddingModel)
assert isinstance(model.pooler.pooling, MeanPool)
assert isinstance(pooler := model.pooler, DispatchPooler)
assert isinstance(pooler.poolers_by_task["embed"].pooling,
MeanPool)
vllm_model.apply_model(check_model)
@ -114,7 +118,8 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner):
def check_model(model):
assert isinstance(model, RobertaEmbeddingModel)
assert not hasattr(model, "lm_head")
assert isinstance(model.pooler.pooling, CLSPool)
assert isinstance(pooler := model.pooler, DispatchPooler)
assert isinstance(pooler.poolers_by_task["embed"].pooling, CLSPool)
vllm_model.apply_model(check_model)

View File

@ -15,13 +15,13 @@ def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None:
load_format="dummy",
) as llm:
if model == "google/gemma-3-4b-it":
normalizers = llm.model.collective_rpc(
normalizers = llm.llm.collective_rpc(
lambda self: self.model_runner.model.language_model.model.
normalizer.cpu().item())
config = llm.model.llm_engine.model_config.hf_config.text_config
config = llm.llm.llm_engine.model_config.hf_config.text_config
else:
normalizers = llm.model.collective_rpc(
normalizers = llm.llm.collective_rpc(
lambda self: self.model_runner.model.model.normalizer.cpu(
).item())
config = llm.model.llm_engine.model_config.hf_config
config = llm.llm.llm_engine.model_config.hf_config
assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3)

View File

@ -186,8 +186,6 @@ VLM_TEST_SETTINGS = {
image_size_factors=[(0.25, 0.5, 1.0)],
vllm_runner_kwargs={
"model_impl": "transformers",
"disable_mm_preprocessor_cache": True,
"enable_prefix_caching": False,
},
marks=[pytest.mark.core_model],
),
@ -205,8 +203,6 @@ VLM_TEST_SETTINGS = {
# image_size_factors=[(0.25, 0.5, 1.0)],
# vllm_runner_kwargs={
# "model_impl": "transformers",
# "disable_mm_preprocessor_cache": True,
# "enable_prefix_caching": False,
# },
# marks=[pytest.mark.core_model],
# ),
@ -223,8 +219,6 @@ VLM_TEST_SETTINGS = {
image_size_factors=[(0.25, 0.2, 0.15)],
vllm_runner_kwargs={
"model_impl": "transformers",
"disable_mm_preprocessor_cache": True,
"enable_prefix_caching": False,
},
marks=[large_gpu_mark(min_gb=32)],
),
@ -239,8 +233,6 @@ VLM_TEST_SETTINGS = {
image_size_factors=[(0.25, 0.5, 1.0)],
vllm_runner_kwargs={
"model_impl": "auto",
"disable_mm_preprocessor_cache": True,
"enable_prefix_caching": False,
},
auto_cls=AutoModelForImageTextToText,
marks=[pytest.mark.core_model],

View File

@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf: disable
@pytest.mark.parametrize("model_id",
                         ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
def test_multimodal_processor(model_id):
    """The processor should yield the same prompt whether the request passes
    the prompt as a string or as pre-tokenized ids."""
    model_config = ModelConfig(
        model=model_id,
        model_impl="transformers",
    )

    mm_processor = MULTIMODAL_REGISTRY.create_processor(model_config)

    image_pil = ImageAsset('cherry_blossom').pil_image
    mm_data = {"image": image_pil}

    str_prompt = "<|im_start|>user <image>\nWhat is the content of this image?<|im_end|><|im_start|>assistant\n"  # noqa: E501
    ids_prompt = [
        151644, 872, 220, 151646, 198, 3838, 374, 279, 2213, 315, 419, 2168,
        30, 151645, 151644, 77091, 198
    ]

    str_processed_inputs = mm_processor.apply(
        prompt=str_prompt,
        mm_data=mm_data,
        hf_processor_mm_kwargs={},
    )
    ids_processed_inputs = mm_processor.apply(
        prompt=ids_prompt,
        mm_data=mm_data,
        hf_processor_mm_kwargs={},
    )

    assert str_processed_inputs["prompt"] == ids_processed_inputs["prompt"]

View File

@ -135,6 +135,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B",
trust_remote_code=True),
"ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base",
is_available_online=False),
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
trust_remote_code=True),
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
@ -165,9 +167,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501
trust_remote_code=True),
"Ernie4_5_ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT",
trust_remote_code=True),
min_transformers_version="4.54"),
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",
trust_remote_code=True),
min_transformers_version="4.54"),
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
"Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B"), # noqa: E501
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
@ -441,6 +443,12 @@ _MULTIMODAL_EXAMPLE_MODELS = {
hf_overrides={"architectures": ["TarsierForConditionalGeneration"]}), # noqa: E501
"Tarsier2ForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier2-Recap-7b", # noqa: E501
hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]}), # noqa: E501
"VoxtralForConditionalGeneration": _HfExamplesInfo(
"mistralai/Voxtral-Mini-3B-2507",
min_transformers_version="4.54",
# disable this temporarily until we support HF format
is_available_online=False,
),
# [Encoder-decoder]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
@ -448,13 +456,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
tokenizer="Isotr0py/Florence-2-tokenizer", # noqa: E501
trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"VoxtralForConditionalGeneration": _HfExamplesInfo(
"mistralai/Voxtral-Mini-3B-2507",
tokenizer_mode="mistral",
min_transformers_version="4.54"
),
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
# [Cross-encoder]
"JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501
}
@ -498,7 +500,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
}
_TRANSFORMERS_MODELS = {
"TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"),
}

View File

@ -56,7 +56,7 @@ def check_implementation(
"model,model_impl",
[
("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE
("hmellor/Ilama-3.2-1B", "auto"), # CUSTOM CODE
]) # trust_remote_code=True by default
def test_models(
hf_runner: type[HfRunner],

View File

@ -9,7 +9,6 @@ def test_mistral():
tensor_parallel_size=2,
max_num_seqs=4,
max_model_len=128,
use_v2_block_manager=True,
override_neuron_config={
"sequence_parallel_enabled": False,
"skip_warmup": True

View File

@ -14,7 +14,6 @@ def test_llama_single_lora():
tensor_parallel_size=2,
max_num_seqs=4,
max_model_len=512,
use_v2_block_manager=True,
override_neuron_config={
"sequence_parallel_enabled": False,
"skip_warmup": True,
@ -57,7 +56,6 @@ def test_llama_multiple_lora():
tensor_parallel_size=2,
max_num_seqs=4,
max_model_len=512,
use_v2_block_manager=True,
override_neuron_config={
"sequence_parallel_enabled":
False,

View File

@ -0,0 +1,618 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Generator
from typing import Optional
import pytest
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
Qwen3CoderToolParser)
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
MODEL = "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8"
@pytest.fixture(scope="module")
def qwen3_tokenizer():
    """Load the Qwen3-Coder tokenizer once for the whole test module."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer
@pytest.fixture
def qwen3_tool_parser(qwen3_tokenizer):
    # Function-scoped so every test gets a fresh parser instance
    # (the parser accumulates per-request streaming state).
    return Qwen3CoderToolParser(qwen3_tokenizer)
@pytest.fixture
def sample_tools():
    """Two tool definitions used across the tests.

    ``get_current_weather`` has required string params plus an enum;
    ``calculate_area`` mixes string, object and integer params so the
    parser's schema-driven type conversion gets exercised.
    """
    return [
        ChatCompletionToolsParam(
            type="function",
            function={
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "The city name"
                        },
                        "state": {
                            "type": "string",
                            "description": "The state code"
                        },
                        "unit": {
                            "type": "string",
                            "enum": ["fahrenheit", "celsius"]
                        }
                    },
                    "required": ["city", "state"]
                }
            }),
        ChatCompletionToolsParam(
            type="function",
            function={
                "name": "calculate_area",
                "description": "Calculate area of a shape",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "shape": {
                            "type": "string"
                        },
                        "dimensions": {
                            "type": "object"
                        },
                        "precision": {
                            "type": "integer"
                        }
                    }
                }
            })
    ]
def assert_tool_calls(actual_tool_calls: list[ToolCall],
                      expected_tool_calls: list[ToolCall]):
    """Compare extracted tool calls against the expected ones.

    Checks count, type, function name and JSON-equivalent arguments.
    The Qwen3 parser doesn't generate IDs during extraction, so call
    IDs are deliberately not compared.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for idx, expected in enumerate(expected_tool_calls):
        actual = actual_tool_calls[idx]
        assert actual.type == "function"
        assert actual.function.name == expected.function.name
        # Compare arguments as parsed JSON so that formatting
        # differences (whitespace, key order) don't cause failures.
        actual_args = json.loads(actual.function.arguments)
        expected_args = json.loads(expected.function.arguments)
        assert actual_args == expected_args
def stream_delta_message_generator(
        qwen3_tool_parser: Qwen3CoderToolParser,
        qwen3_tokenizer: AnyTokenizer,
        model_output: str,
        request: Optional[ChatCompletionRequest] = None
) -> Generator[DeltaMessage, None, None]:
    """Replay ``model_output`` one token at a time through the parser.

    Mimics what the serving layer does during generation: each token is
    incrementally detokenized and the resulting text delta is fed to
    ``extract_tool_calls_streaming``; every non-empty ``DeltaMessage``
    the parser emits is yielded to the caller.
    """
    all_token_ids = qwen3_tokenizer.encode(model_output,
                                           add_special_tokens=False)

    previous_text = ""
    previous_tokens = None
    # Offsets carried between detokenize_incrementally() calls so that
    # text spanning multiple tokens is emitted exactly once.
    prefix_offset = 0
    read_offset = 0
    for i, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:i]
        current_token_ids = all_token_ids[:i + 1]

        (new_tokens, delta_text, new_prefix_offset,
         new_read_offset) = detokenize_incrementally(
             tokenizer=qwen3_tokenizer,
             all_input_ids=current_token_ids,
             prev_tokens=previous_tokens,
             prefix_offset=prefix_offset,
             read_offset=read_offset,
             skip_special_tokens=False,
             spaces_between_special_tokens=True,
         )

        current_text = previous_text + delta_text

        delta_message = qwen3_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=request,
        )
        # The parser may return None while buffering partial markup.
        if delta_message:
            yield delta_message

        previous_text = current_text
        previous_tokens = (previous_tokens +
                           new_tokens if previous_tokens else new_tokens)
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset
def test_extract_tool_calls_no_tools(qwen3_tool_parser):
    """Plain text without tool-call markup must pass through untouched."""
    text = "This is a test response without any tool calls"
    result = qwen3_tool_parser.extract_tool_calls(
        text, request=None)  # type: ignore[arg-type]
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == text
# Each case: raw model output, the tool calls that must be extracted,
# and any surrounding plain-text content that must be preserved.
@pytest.mark.parametrize(
    ids=[
        "single_tool",
        "single_tool_with_content",
        "single_tool_multiline_param",
        "parallel_tools",
        "tool_with_typed_params",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        ('''<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''', [
            ToolCall(
                function=FunctionCall(name="get_current_weather",
                                      arguments=json.dumps({
                                          "city": "Dallas",
                                          "state": "TX",
                                          "unit": "fahrenheit"
                                      })))
        ], None),
        ('''Sure! Let me check the weather for you.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''', [
            ToolCall(
                function=FunctionCall(name="get_current_weather",
                                      arguments=json.dumps({
                                          "city": "Dallas",
                                          "state": "TX",
                                          "unit": "fahrenheit"
                                      })))
        ], "Sure! Let me check the weather for you."),
        ('''<tool_call>
<function=calculate_area>
<parameter=shape>
rectangle
</parameter>
<parameter=dimensions>
{"width": 10,
"height": 20}
</parameter>
<parameter=precision>
2
</parameter>
</function>
</tool_call>''', [
            ToolCall(function=FunctionCall(name="calculate_area",
                                           arguments=json.dumps({
                                               "shape": "rectangle",
                                               "dimensions": {
                                                   "width": 10,
                                                   "height": 20
                                               },
                                               "precision": 2
                                           })))
        ], None),
        ('''<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>
<tool_call>
<function=get_current_weather>
<parameter=city>
Orlando
</parameter>
<parameter=state>
FL
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''', [
            ToolCall(
                function=FunctionCall(name="get_current_weather",
                                      arguments=json.dumps({
                                          "city": "Dallas",
                                          "state": "TX",
                                          "unit": "fahrenheit"
                                      }))),
            ToolCall(
                function=FunctionCall(name="get_current_weather",
                                      arguments=json.dumps({
                                          "city": "Orlando",
                                          "state": "FL",
                                          "unit": "fahrenheit"
                                      })))
        ], None),
        ('''Let me calculate that area for you.<tool_call>
<function=calculate_area>
<parameter=shape>
circle
</parameter>
<parameter=dimensions>
{"radius": 15.5}
</parameter>
<parameter=precision>
3
</parameter>
</function>
</tool_call>''', [
            ToolCall(function=FunctionCall(name="calculate_area",
                                           arguments=json.dumps({
                                               "shape": "circle",
                                               "dimensions": {
                                                   "radius": 15.5
                                               },
                                               "precision": 3
                                           })))
        ], "Let me calculate that area for you."),
    ],
)
def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output,
                            expected_tool_calls, expected_content):
    """Non-streaming extraction: parse the full model output in one call
    and verify both the tool calls and the surrounding content."""
    request = ChatCompletionRequest(model=MODEL,
                                    messages=[],
                                    tools=sample_tools)
    extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
        model_output, request=request)
    assert extracted_tool_calls.tools_called

    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content
def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser, sample_tools):
    """Test fallback parsing when XML tags are missing.

    The model emitted a bare ``<function=...>`` section without the
    surrounding ``<tool_call>`` wrapper; the parser should still
    recover the tool call.
    """
    model_output = '''<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>'''

    request = ChatCompletionRequest(model=MODEL,
                                    messages=[],
                                    tools=sample_tools)
    extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
        model_output, request=request)

    assert extracted_tool_calls.tools_called
    assert len(extracted_tool_calls.tool_calls) == 1
    assert (extracted_tool_calls.tool_calls[0].function.name ==
            "get_current_weather")
def test_extract_tool_calls_type_conversion(qwen3_tool_parser):
    """Test parameter type conversion based on tool schema.

    Parameter values arrive as raw text inside the XML-ish markup; the
    parser must coerce each one to the type declared in the tool's JSON
    schema before serializing the arguments.
    """
    tools = [
        ChatCompletionToolsParam(type="function",
                                 function={
                                     "name": "test_types",
                                     "parameters": {
                                         "type": "object",
                                         "properties": {
                                             "int_param": {
                                                 "type": "integer"
                                             },
                                             "float_param": {
                                                 "type": "float"
                                             },
                                             "bool_param": {
                                                 "type": "boolean"
                                             },
                                             "str_param": {
                                                 "type": "string"
                                             },
                                             "obj_param": {
                                                 "type": "object"
                                             }
                                         }
                                     }
                                 })
    ]

    model_output = '''<tool_call>
<function=test_types>
<parameter=int_param>
42
</parameter>
<parameter=float_param>
3.14
</parameter>
<parameter=bool_param>
true
</parameter>
<parameter=str_param>
hello world
</parameter>
<parameter=obj_param>
{"key": "value"}
</parameter>
</function>
</tool_call>'''

    request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
    extracted_tool_calls = qwen3_tool_parser.extract_tool_calls(
        model_output, request=request)

    args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
    # Each raw text value must come back as its schema-declared type.
    assert args["int_param"] == 42
    assert args["float_param"] == 3.14
    assert args["bool_param"] is True
    assert args["str_param"] == "hello world"
    assert args["obj_param"] == {"key": "value"}
# Streaming cases: unlike the non-streaming test, content for a
# tool-only output is the empty string (deltas concatenate to ""),
# not None.
@pytest.mark.parametrize(
    ids=[
        "no_tools",
        "single_tool",
        "single_tool_with_content",
        "parallel_tools",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        ("This is a test without tools", [], "This is a test without tools"),
        ('''<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''', [
            ToolCall(
                function=FunctionCall(name="get_current_weather",
                                      arguments=json.dumps({
                                          "city": "Dallas",
                                          "state": "TX",
                                          "unit": "fahrenheit"
                                      })))
        ], ""),
        ('''Sure! Let me check the weather for you.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>''', [
            ToolCall(
                function=FunctionCall(name="get_current_weather",
                                      arguments=json.dumps({
                                          "city": "Dallas",
                                          "state": "TX",
                                          "unit": "fahrenheit"
                                      })))
        ], "Sure! Let me check the weather for you."),
        ('''<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>
<tool_call>
<function=get_current_weather>
<parameter=city>
Orlando
</parameter>
<parameter=state>
FL
</parameter>
<parameter=unit>
celsius
</parameter>
</function>
</tool_call>''', [
            ToolCall(
                function=FunctionCall(name="get_current_weather",
                                      arguments=json.dumps({
                                          "city": "Dallas",
                                          "state": "TX",
                                          "unit": "fahrenheit"
                                      }))),
            ToolCall(
                function=FunctionCall(name="get_current_weather",
                                      arguments=json.dumps({
                                          "city": "Orlando",
                                          "state": "FL",
                                          "unit": "celsius"
                                      })))
        ], ""),
    ],
)
def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer,
                                      sample_tools, model_output,
                                      expected_tool_calls, expected_content):
    """Test incremental streaming behavior.

    Feeds the output token by token, accumulates the streamed deltas
    per tool index, then checks the reassembled content and tool calls
    match the expectations.
    """
    request = ChatCompletionRequest(model=MODEL,
                                    messages=[],
                                    tools=sample_tools)

    other_content = ''
    tool_states = {}  # Track state per tool index

    for delta_message in stream_delta_message_generator(
            qwen3_tool_parser, qwen3_tokenizer, model_output, request):
        # role should never be streamed from tool parser
        assert not delta_message.role

        if delta_message.content:
            other_content += delta_message.content

        if delta_message.tool_calls:
            for tool_call in delta_message.tool_calls:
                idx = tool_call.index

                # Initialize state for new tool
                if idx not in tool_states:
                    tool_states[idx] = {
                        "id": None,
                        "name": None,
                        "arguments": "",
                        "type": None
                    }

                # First chunk should have id, name, and type
                if tool_call.id:
                    tool_states[idx]["id"] = tool_call.id

                if tool_call.type:
                    assert tool_call.type == "function"
                    tool_states[idx]["type"] = tool_call.type

                if tool_call.function:
                    if tool_call.function.name:
                        # Should only be set once
                        assert tool_states[idx]["name"] is None
                        tool_states[idx]["name"] = tool_call.function.name

                    if tool_call.function.arguments is not None:
                        # Accumulate arguments incrementally
                        tool_states[idx][
                            "arguments"] += tool_call.function.arguments

    # Verify final content
    assert other_content == expected_content

    # Verify we got all expected tool calls
    assert len(tool_states) == len(expected_tool_calls)

    # Verify each tool call
    for idx, expected_tool in enumerate(expected_tool_calls):
        state = tool_states[idx]
        assert state["id"] is not None
        assert state["type"] == "function"
        assert state["name"] == expected_tool.function.name

        # Parse accumulated arguments
        arguments_str = state["arguments"]
        assert arguments_str is not None
        actual_args = json.loads(arguments_str)
        expected_args = json.loads(expected_tool.function.arguments)
        assert actual_args == expected_args
def test_extract_tool_calls_streaming_incremental(qwen3_tool_parser,
                                                  qwen3_tokenizer,
                                                  sample_tools):
    """Test that streaming is truly incremental.

    Verifies the delta sequence itself: leading content chunks first,
    then a single header chunk (id/name/type with empty arguments),
    then multiple argument chunks that concatenate into valid JSON.
    """
    model_output = '''I'll check the weather.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>'''

    request = ChatCompletionRequest(model=MODEL,
                                    messages=[],
                                    tools=sample_tools)

    chunks = []
    for delta_message in stream_delta_message_generator(
            qwen3_tool_parser, qwen3_tokenizer, model_output, request):
        chunks.append(delta_message)

    # Should have multiple chunks
    assert len(chunks) > 3

    # First chunk(s) should be content
    assert chunks[0].content is not None
    assert chunks[0].tool_calls is None or chunks[0].tool_calls == []

    # Should have a chunk with tool header (id, name, type)
    header_found = False
    for chunk in chunks:
        if chunk.tool_calls and chunk.tool_calls[0].id:
            header_found = True
            assert (chunk.tool_calls[0].function.name == "get_current_weather")
            assert chunk.tool_calls[0].type == "function"
            # Empty initially
            assert chunk.tool_calls[0].function.arguments == ""
            break
    assert header_found

    # Should have chunks with incremental arguments
    arg_chunks = []
    for chunk in chunks:
        if chunk.tool_calls and chunk.tool_calls[0].function.arguments:
            arg_chunks.append(chunk.tool_calls[0].function.arguments)

    # Arguments should be streamed incrementally
    assert len(arg_chunks) > 1

    # Concatenated arguments should form valid JSON
    full_args = "".join(arg_chunks)
    parsed_args = json.loads(full_args)
    assert parsed_args["city"] == "Dallas"
    assert parsed_args["state"] == "TX"

View File

@ -184,6 +184,111 @@ def test_free_kv_cache_block_queue_operations():
assert str(e.value) == "No free blocks available"
def test_free_kv_cache_block_queue_append_n():
    """append_n() must link a batch of blocks at the tail of the free list,
    preserving the order given, and keep the doubly-linked pointers and
    num_free_blocks consistent."""
    # Start from an empty FreeKVCacheBlockQueue; blocks are appended below.
    queue = FreeKVCacheBlockQueue([])
    blocks = [KVCacheBlock(block_id=i) for i in range(6)]
    # Append 0 blocks: the list stays empty
    # fake_head->fake_tail
    queue.append_n([])
    assert queue.num_free_blocks == 0
    assert (queue.fake_free_list_head.next_free_block
            is queue.fake_free_list_tail)
    assert (queue.fake_free_list_tail.prev_free_block
            is queue.fake_free_list_head)
    # Append 1 block
    # fake_head->b0->fake_tail
    queue.append_n(blocks[0:1])
    assert queue.num_free_blocks == 1
    assert queue.fake_free_list_head.next_free_block is blocks[0]
    assert blocks[0].prev_free_block is queue.fake_free_list_head
    assert blocks[0].next_free_block is queue.fake_free_list_tail
    assert queue.fake_free_list_tail.prev_free_block is blocks[0]
    # Append 2 blocks
    # fake_head->b0->b4->b5->fake_tail
    queue.append_n(blocks[4:6])
    assert queue.num_free_blocks == 3
    assert queue.fake_free_list_head.next_free_block is blocks[0]
    assert blocks[0].prev_free_block is queue.fake_free_list_head
    assert blocks[0].next_free_block is blocks[4]
    assert blocks[4].prev_free_block is blocks[0]
    assert blocks[4].next_free_block is blocks[5]
    assert blocks[5].prev_free_block is blocks[4]
    assert blocks[5].next_free_block is queue.fake_free_list_tail
    assert queue.fake_free_list_tail.prev_free_block is blocks[5]
    # Append 3 blocks
    # fake_head->b0->b4->b5->b1->b2->b3->fake_tail
    queue.append_n(blocks[1:4])
    assert queue.num_free_blocks == 6
    assert queue.fake_free_list_head.next_free_block is blocks[0]
    assert blocks[0].prev_free_block is queue.fake_free_list_head
    assert blocks[0].next_free_block is blocks[4]
    assert blocks[4].prev_free_block is blocks[0]
    assert blocks[4].next_free_block is blocks[5]
    assert blocks[5].prev_free_block is blocks[4]
    assert blocks[5].next_free_block is blocks[1]
    assert blocks[1].prev_free_block is blocks[5]
    assert blocks[1].next_free_block is blocks[2]
    assert blocks[2].prev_free_block is blocks[1]
    assert blocks[2].next_free_block is blocks[3]
    assert blocks[3].prev_free_block is blocks[2]
    assert blocks[3].next_free_block is queue.fake_free_list_tail
    assert queue.fake_free_list_tail.prev_free_block is blocks[3]
def test_free_kv_cache_block_queue_popleft_n():
    """popleft_n(n) must pop the first n blocks in list order and fully
    unlink them (both free-list pointers reset to None)."""
    blocks = [KVCacheBlock(block_id=i) for i in range(6)]
    # Create a FreeKVCacheBlockQueue holding these blocks in scrambled order
    queue = FreeKVCacheBlockQueue(
        [blocks[1], blocks[3], blocks[5], blocks[4], blocks[0], blocks[2]])
    # Sanity-check the initial linkage:
    # fake_head->b1->b3->b5->b4->b0->b2->fake_tail
    assert queue.num_free_blocks == 6
    assert queue.fake_free_list_head.next_free_block is blocks[1]
    assert blocks[1].prev_free_block is queue.fake_free_list_head
    assert blocks[1].next_free_block is blocks[3]
    assert blocks[3].prev_free_block is blocks[1]
    assert blocks[3].next_free_block is blocks[5]
    assert blocks[5].prev_free_block is blocks[3]
    assert blocks[5].next_free_block is blocks[4]
    assert blocks[4].prev_free_block is blocks[5]
    assert blocks[4].next_free_block is blocks[0]
    assert blocks[0].prev_free_block is blocks[4]
    assert blocks[0].next_free_block is blocks[2]
    assert blocks[2].prev_free_block is blocks[0]
    assert blocks[2].next_free_block is queue.fake_free_list_tail
    assert queue.fake_free_list_tail.prev_free_block is blocks[2]
    # Pop 0 blocks: queue unchanged
    # fake_head->b1->b3->b5->b4->b0->b2->fake_tail
    assert len(queue.popleft_n(0)) == 0
    # Pop 1 block
    # fake_head->b3->b5->b4->b0->b2->fake_tail
    result_blocks = queue.popleft_n(1)
    assert len(result_blocks) == 1
    assert result_blocks[0] is blocks[1]
    for block in result_blocks:
        assert block.prev_free_block is None
        assert block.next_free_block is None
    # Pop 2 blocks
    # fake_head->b4->b0->b2->fake_tail
    result_blocks = queue.popleft_n(2)
    assert len(result_blocks) == 2
    assert result_blocks[0] is blocks[3]
    assert result_blocks[1] is blocks[5]
    for block in result_blocks:
        assert block.prev_free_block is None
        assert block.next_free_block is None
    # Pop 3 blocks, draining the queue
    # fake_head->fake_tail
    result_blocks = queue.popleft_n(3)
    assert len(result_blocks) == 3
    assert result_blocks[0] is blocks[4]
    assert result_blocks[1] is blocks[0]
    assert result_blocks[2] is blocks[2]
    for block in result_blocks:
        assert block.prev_free_block is None
        assert block.next_free_block is None
def test_free_kv_cache_block_queue_get_all_free_blocks():
# Create a list of KVCacheBlock objects
blocks = [KVCacheBlock(block_id=i) for i in range(5)]

View File

@ -1097,6 +1097,73 @@ def test_prefix_cache_stats_disabled():
assert manager.prefix_cache_stats is None
def test_maybe_evict_cached_block():
    """_maybe_evict_cached_block removes a block from the prefix cache.

    A hash entry must only disappear once the *last* block carrying that
    hash is evicted: block0 and block3 deliberately share block_hash0.
    """
    pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
    block_hash0 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=10,
                                                            token_ids=(100, )),
                                       group_id=1000)
    block_hash1 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=20,
                                                            token_ids=(200, )),
                                       group_id=2000)
    block_hash2 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=30,
                                                            token_ids=(300, )),
                                       group_id=3000)
    block_hashes = [
        block_hash0,
        block_hash1,
        block_hash2,
        # block3 had the exact same block_hash as the first block
        block_hash0,
    ]
    assert len(pool.blocks) == len(block_hashes)
    # Manually add all blocks to cached_blocks
    for block, block_hash in zip(pool.blocks, block_hashes):
        block.block_hash = block_hash
        pool.cached_block_hash_to_block[block_hash][block.block_id] = block

    block0, block1, block2, block3 = pool.blocks
    assert pool.cached_block_hash_to_block == {
        block_hash0: {
            block0.block_id: block0,
            block3.block_id: block3
        },
        block_hash1: {
            block1.block_id: block1
        },
        block_hash2: {
            block2.block_id: block2
        }
    }
    # Evict block1: its hash entry had a single block, so it vanishes
    pool._maybe_evict_cached_block(block1)
    assert pool.cached_block_hash_to_block == {
        block_hash0: {
            block0.block_id: block0,
            block3.block_id: block3
        },
        block_hash2: {
            block2.block_id: block2
        }
    }
    # Evict block0: block_hash0 entry should NOT be removed, as block3
    # also use the same hash
    pool._maybe_evict_cached_block(block0)
    assert pool.cached_block_hash_to_block == {
        block_hash0: {
            block3.block_id: block3
        },
        block_hash2: {
            block2.block_id: block2
        }
    }
    # Evict block2
    pool._maybe_evict_cached_block(block2)
    assert pool.cached_block_hash_to_block == {block_hash0: {3: block3}}
    # Evict block3: last holder of block_hash0, cache ends up empty
    pool._maybe_evict_cached_block(block3)
    assert pool.cached_block_hash_to_block == {}
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
def test_kv_cache_events(blocks_to_cache: int):
block_size = 16

View File

@ -12,6 +12,7 @@ from tests.v1.sample.utils import (
assert_incr_detok_str_matches_non_incr_detok_str,
compute_correct_cumulative_logprob, get_test_batch)
from vllm import SamplingParams
from vllm.config import LogprobsMode
from ...conftest import HfRunner, VllmRunner
@ -426,3 +427,45 @@ def test_zero_logprobs(vllm_model, example_prompts,
# prompt token
assert prompt_logprobs is not None
assert len(prompt_token_ids) == len(prompt_logprobs)
@pytest.mark.parametrize(
    "logprobs_mode",
    ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
def test_logprobs_mode(logprobs_mode: LogprobsMode,
                       monkeypatch: pytest.MonkeyPatch):
    """Test with LLM engine with different logprobs_mode.

    For logprobs, we should have non-positive values.
    For logits, we should expect at least one positive values.
    """
    from vllm import LLM
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        llm = LLM(
            "facebook/opt-125m",
            max_logprobs=5,
            enable_prefix_caching=False,
            # 2 other llms alive during whole session
            gpu_memory_utilization=0.05,
            max_model_len=16,
            logprobs_mode=logprobs_mode)
        sampling_params = SamplingParams(logprobs=1)
        results = llm.generate(["Hello world"],
                               sampling_params=sampling_params)

        num_tokens_with_logprobs = 0
        num_positive_values = 0
        for output in results[0].outputs:
            for logprobs_dict in output.logprobs:
                for logprob_entry in logprobs_dict.values():
                    if "logprobs" in logprobs_mode:
                        # Log-probabilities can never be positive.
                        assert logprob_entry.logprob <= 0
                    if logprob_entry.logprob > 0:
                        num_positive_values += 1
                    num_tokens_with_logprobs += 1

        # Every sampled token should carry at least one logprob entry.
        assert num_tokens_with_logprobs >= len(results[0].outputs)
        if "logits" in logprobs_mode:
            # Logits are unnormalized, so some positive value must appear.
            assert num_positive_values > 0
        del llm

View File

@ -3,15 +3,19 @@
import random
import numpy as np
import pytest
import torch
from vllm.attention import Attention
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig, set_current_vllm_config)
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes
from vllm.utils import GiB_bytes, update_environment_variables
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
get_kv_cache_config)
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
@ -686,3 +690,147 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
    '''
    The GPU model runner creates different views into the
    KVCacheTensors for the attention and mamba layers
    (via _reshape_kv_cache_tensors function). This test verifies
    that the views are compatible: writing a mamba block
    will not corrupt an attention block and vice-versa
    '''

    current_platform.seed_everything(42)

    # Minimal single-process distributed setup (TP=1) so that the
    # attention / mamba layers can be constructed.
    update_environment_variables({
        'RANK': "0",
        'LOCAL_RANK': "0",
        'WORLD_SIZE': "1",
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12345',
    })
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=1)
    torch.set_default_dtype(torch.float16)

    scheduler_config = SchedulerConfig(
        max_num_seqs=10,
        max_num_batched_tokens=512,
        max_model_len=512,
    )
    # granite-4.0-tiny is a hybrid model: attention + mamba layers.
    model_config = ModelConfig(
        model="ibm-granite/granite-4.0-tiny-preview",
        dtype="float16",
    )
    cache_config = CacheConfig(
        block_size=BLOCK_SIZE,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
    )
    parallel_config = ParallelConfig()
    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        scheduler_config=scheduler_config,
        parallel_config=parallel_config,
    )

    # Two attention layers and four mamba (mixer) layers.
    layer_0 = "model.layers.0.self_attn.attn"
    layer_1 = "model.layers.1.self_attn.attn"
    layer_2 = "model.layers.2.mixer"
    layer_3 = "model.layers.3.mixer"
    layer_4 = "model.layers.4.mixer"
    layer_5 = "model.layers.5.mixer"

    with set_current_vllm_config(vllm_config):
        hf_config = vllm_config.model_config.hf_config
        fwd_context = {}
        for key in [layer_0, layer_1]:
            fwd_context[key] = Attention(
                num_heads=model_config.get_num_attention_heads(
                    parallel_config),
                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
                head_size=model_config.get_head_size(),
                scale=1.0,
                prefix=key,
            )
        for key in [layer_2, layer_3, layer_4, layer_5]:
            fwd_context[key] = MambaMixer2(
                hidden_size = hf_config.hidden_size,
                ssm_state_size = hf_config.mamba_d_state,
                conv_kernel_size = hf_config.mamba_d_conv,
                intermediate_size = hf_config.mamba_expand *\
                                    hf_config.hidden_size,
                use_conv_bias = hf_config.mamba_conv_bias,
                use_bias = hf_config.mamba_proj_bias,
                n_groups=hf_config.mamba_n_groups,
                num_heads=hf_config.mamba_n_heads,
                head_dim=hf_config.mamba_d_head,
                rms_norm_eps=hf_config.rms_norm_eps,
                activation=hf_config.hidden_act,
                prefix=key,
            )
        # suppress var not used error
        assert fwd_context is not None
    # Layers registered above are reachable via the static forward context.
    vllm_ctx = vllm_config.compilation_config.static_forward_context

    with monkeypatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")

        runner = GPUModelRunner(vllm_config, DEVICE)
        kv_cache_spec = runner.get_kv_cache_spec()
        available_memory = 5 * GiB_bytes
        kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
                                              available_memory)
        runner.initialize_kv_cache(kv_cache_config)

        # random partition of blocks
        # blocks0 will be assigned to attention layers
        # blocks1 will be assigned to mamba layers
        num_blocks = kv_cache_config.num_blocks
        ind = np.arange(num_blocks)
        np.random.shuffle(ind)
        blocks0, blocks1 = ind[:(num_blocks // 2)], ind[(num_blocks // 2):]

        attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
        conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
        ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape

        # assert we are using FlashInfer (its layout is block-major,
        # so dim 0 of the attention cache is the block index)
        assert attn_shape[0] == num_blocks

        # Distinct sentinel values per cache kind so cross-corruption
        # between the views is detectable.
        attn_blocks_constant = torch.full((len(blocks0), *attn_shape[1:]),
                                          device=DEVICE,
                                          fill_value=3.33)
        conv_blocks_constant = torch.full((len(blocks1), *conv_shape[1:]),
                                          device=DEVICE,
                                          fill_value=6.66)
        ssm_blocks_constant = torch.full((len(blocks1), *ssm_shape[1:]),
                                         device=DEVICE,
                                         fill_value=9.99)

        # fill all attention blocks with constant
        for layer in [layer_0, layer_1]:
            vllm_ctx[layer].kv_cache[0][
                blocks0, :] = attn_blocks_constant.detach().clone()

        # fill all mamba blocks with constant
        for layer in [layer_2, layer_3, layer_4, layer_5]:
            vllm_ctx[layer].kv_cache[0][0][
                blocks1, :] = conv_blocks_constant.detach().clone()
            vllm_ctx[layer].kv_cache[0][1][
                blocks1, :] = ssm_blocks_constant.detach().clone()

        # verify attention and mamba contents are correct
        for layer in [layer_0, layer_1]:
            assert torch.equal(vllm_ctx[layer].kv_cache[0][blocks0, :],
                               attn_blocks_constant)
        for layer in [layer_2, layer_3, layer_4, layer_5]:
            assert torch.equal(vllm_ctx[layer].kv_cache[0][0][blocks1, :],
                               conv_blocks_constant)
            assert torch.equal(vllm_ctx[layer].kv_cache[0][1][blocks1, :],
                               ssm_blocks_constant)

View File

@ -179,12 +179,12 @@ async def get_request(
delay_ts = [delay * normalize_factor for delay in delay_ts]
start_ts = time.time()
request_index = 0
for request_index, request in enumerate(input_requests):
current_ts = time.time()
sleep_interval_s = start_ts + delay_ts[request_index] - current_ts
if sleep_interval_s > 0:
await asyncio.sleep(sleep_interval_s)
if delay_ts[request_index] > 0:
current_ts = time.time()
sleep_interval_s = start_ts + delay_ts[request_index] - current_ts
if sleep_interval_s > 0:
await asyncio.sleep(sleep_interval_s)
yield request, request_rates[request_index]

View File

@ -159,6 +159,9 @@ if flashinfer_comm is not None:
6: MiB // 2, # 512KB
8: MiB // 2, # 512KB
}
# opt for a more conservative default value
# when world size is not in _FI_MAX_SIZES
_DEFAULT_FI_MAX_SIZE = MiB // 2
def call_trtllm_fused_allreduce_norm(
allreduce_in: torch.Tensor,
@ -173,12 +176,16 @@ if flashinfer_comm is not None:
max_token_num: int,
norm_out: Optional[torch.Tensor] = None,
) -> None:
use_flashinfer = allreduce_in.shape[0] * allreduce_in.shape[
1] * allreduce_in.element_size() <= min(
_FI_MAX_SIZES[world_size],
max_token_num * allreduce_in.shape[0] *
allreduce_in.element_size(),
)
num_tokens, hidden_size = allreduce_in.shape
element_size = allreduce_in.element_size()
current_tensor_size = num_tokens * hidden_size * element_size
max_fusion_size = max_token_num * hidden_size * element_size
use_flashinfer = current_tensor_size <= min(
_FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
max_fusion_size,
)
if use_flashinfer:
assert (_FI_WORKSPACE_TENSOR is not None
), "Flashinfer must be enabled when using flashinfer"

View File

@ -219,6 +219,8 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
"processed_logits"]
@config
@ -316,6 +318,13 @@ class ModelConfig:
"""Maximum number of log probabilities to return when `logprobs` is
specified in `SamplingParams`. The default value comes the default for the
OpenAI Chat Completions API."""
logprobs_mode: LogprobsMode = "raw_logprobs"
"""Indicates the content returned in the logprobs and prompt_logprobs.
Supported mode:
1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
Raw means the values before applying logit processors, like bad words.
Processed means the values after applying such processors.
"""
disable_sliding_window: bool = False
"""Whether to disable sliding window. If True, we will disable the sliding
window functionality of the model, capping to sliding window size. If the
@ -2116,6 +2125,15 @@ class ParallelConfig:
raise ValueError(
"num_redundant_experts must be non-negative, but got "
f"{self.num_redundant_experts}.")
if not self.enable_expert_parallel:
raise ValueError(
"enable_expert_parallel must be True to use EPLB.")
if self.tensor_parallel_size * self.data_parallel_size <= 1:
raise ValueError(
"EPLB requires tensor_parallel_size or data_parallel_size "
f"to be greater than 1, but got "
f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
)
else:
if self.num_redundant_experts != 0:
raise ValueError(
@ -2136,10 +2154,11 @@ class ParallelConfig:
elif (current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size):
if not ray_found:
raise ValueError("Unable to load Ray which is "
raise ValueError("Unable to load Ray: "
f"{ray_utils.ray_import_err}. Ray is "
"required for multi-node inference, "
"please install Ray with `pip install "
"ray`.") from ray_utils.ray_import_err
"ray`.")
backend = "ray"
elif self.data_parallel_backend == "ray":
logger.info("Using ray distributed inference because "

View File

@ -26,13 +26,13 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DetailedTraceModules, Device, DeviceConfig,
DistributedExecutorBackend, GuidedDecodingBackend,
GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, PromptAdapterConfig,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
get_field)
KVTransferConfig, LoadConfig, LoadFormat,
LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
ModelImpl, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerMode,
VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@ -315,7 +315,6 @@ class EngineArgs:
CacheConfig.prefix_caching_hash_algo
disable_sliding_window: bool = ModelConfig.disable_sliding_window
disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
use_v2_block_manager: bool = True
swap_space: float = CacheConfig.swap_space
cpu_offload_gb: float = CacheConfig.cpu_offload_gb
gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
@ -327,6 +326,7 @@ class EngineArgs:
SchedulerConfig.long_prefill_token_threshold
max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
max_logprobs: int = ModelConfig.max_logprobs
logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
disable_log_stats: bool = False
revision: Optional[str] = ModelConfig.revision
code_revision: Optional[str] = ModelConfig.code_revision
@ -366,7 +366,6 @@ class EngineArgs:
max_prompt_adapter_token: int = \
PromptAdapterConfig.max_prompt_adapter_token
device: Device = DeviceConfig.device
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
@ -494,6 +493,8 @@ class EngineArgs:
**model_kwargs["max_seq_len_to_capture"])
model_group.add_argument("--max-logprobs",
**model_kwargs["max_logprobs"])
model_group.add_argument("--logprobs-mode",
**model_kwargs["logprobs_mode"])
model_group.add_argument("--disable-sliding-window",
**model_kwargs["disable_sliding_window"])
model_group.add_argument("--disable-cascade-attn",
@ -755,16 +756,6 @@ class EngineArgs:
"--max-prompt-adapter-token",
**prompt_adapter_kwargs["max_prompt_adapter_token"])
# Device arguments
device_kwargs = get_kwargs(DeviceConfig)
device_group = parser.add_argument_group(
title="DeviceConfig",
description=DeviceConfig.__doc__,
)
device_group.add_argument("--device",
**device_kwargs["device"],
deprecated=True)
# Speculative arguments
speculative_group = parser.add_argument_group(
title="SpeculativeConfig",
@ -866,15 +857,6 @@ class EngineArgs:
**vllm_kwargs["additional_config"])
# Other arguments
parser.add_argument('--use-v2-block-manager',
action='store_true',
default=True,
deprecated=True,
help='[DEPRECATED] block manager v1 has been '
'removed and SelfAttnBlockSpaceManager (i.e. '
'block manager v2) is now the default. '
'Setting this flag to True or False'
' has no effect on vLLM behavior.')
parser.add_argument('--disable-log-stats',
action='store_true',
help='Disable logging statistics.')
@ -923,6 +905,7 @@ class EngineArgs:
enforce_eager=self.enforce_eager,
max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
logprobs_mode=self.logprobs_mode,
disable_sliding_window=self.disable_sliding_window,
disable_cascade_attn=self.disable_cascade_attn,
skip_tokenizer_init=self.skip_tokenizer_init,
@ -1387,22 +1370,8 @@ class EngineArgs:
# No Fp8 KV cache so far.
if self.kv_cache_dtype != "auto":
fp8_attention = self.kv_cache_dtype.startswith("fp8")
will_use_fa = (
current_platform.is_cuda()
and not envs.is_set("VLLM_ATTENTION_BACKEND")
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
supported = False
if (current_platform.is_rocm()
or (current_platform.is_cuda()
and current_platform.is_device_capability(100))
or current_platform.is_tpu()):
supported = True
elif fp8_attention and will_use_fa:
from vllm.attention.utils.fa_utils import (
flash_attn_supports_fp8)
supported = flash_attn_supports_fp8()
supported = current_platform.is_kv_cache_dtype_supported(
self.kv_cache_dtype)
if not supported:
_raise_or_fallback(feature_name="--kv-cache-dtype",
recommend_to_remove=False)

View File

@ -438,6 +438,7 @@ class _AsyncLLMEngine(LLMEngine):
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""
Async version of
@ -468,6 +469,7 @@ class _AsyncLLMEngine(LLMEngine):
prompt,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
tokenization_kwargs=tokenization_kwargs,
)
if isinstance(params, SamplingParams) and \
@ -862,6 +864,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
if not self.is_running:
if self.start_engine_loop:
@ -889,6 +892,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
tokenization_kwargs=tokenization_kwargs,
)
return stream.generator()
@ -996,6 +1000,7 @@ class AsyncLLMEngine(EngineClient):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.
@ -1070,6 +1075,7 @@ class AsyncLLMEngine(EngineClient):
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
):
yield LLMEngine.validate_output(output, PoolingRequestOutput)
except asyncio.CancelledError:

View File

@ -965,6 +965,7 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...
@ -981,6 +982,7 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...
@ -997,6 +999,7 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...
@ -1014,6 +1017,7 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...
@ -1031,6 +1035,7 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...
@ -1046,6 +1051,7 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
...
@ -1066,6 +1072,7 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
"""Apply pooling to the hidden states corresponding to the input
prompts.
@ -1131,9 +1138,11 @@ class LLM:
for pooling_param in pooling_params:
pooling_param.verify(pooling_task, model_config)
tokenization_kwargs = dict[str, Any]()
_validate_truncation_size(model_config.max_model_len,
truncate_prompt_tokens, tokenization_kwargs)
if tokenization_kwargs is None:
tokenization_kwargs = dict[str, Any]()
_validate_truncation_size(model_config.max_model_len,
truncate_prompt_tokens,
tokenization_kwargs)
self._validate_and_add_requests(
prompts=parsed_prompts,

View File

@ -841,7 +841,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
return data
# if "tool_choice" is specified -- validation
if "tool_choice" in data:
if "tool_choice" in data and data["tool_choice"] is not None:
# ensure that if "tool choice" is specified, tools are present
if "tools" not in data or data["tools"] is None:
@ -853,7 +853,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
if data["tool_choice"] not in [
"auto", "required"
] and not isinstance(data["tool_choice"], dict):
raise NotImplementedError(
raise ValueError(
f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\
'Only named tools, "none", "auto" or "required" '\
'are supported.'

View File

@ -17,6 +17,7 @@ from .minimax_tool_parser import MinimaxToolParser
from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
from .qwen3coder_tool_parser import Qwen3CoderToolParser
from .xlam_tool_parser import xLAMToolParser
__all__ = [
@ -38,4 +39,5 @@ __all__ = [
"KimiK2ToolParser",
"HunyuanA13BToolParser",
"Glm4MoeModelToolParser",
"Qwen3CoderToolParser",
]

View File

@ -0,0 +1,669 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import uuid
from collections.abc import Sequence
from typing import Any, Optional, Union
import regex as re
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module(["qwen3_coder"])
class Qwen3CoderToolParser(ToolParser):
    """Tool-call parser for Qwen3-Coder models.

    The model emits tool calls as an XML-like dialect::

        <tool_call>
          <function=NAME>
            <parameter=PARAM>
            VALUE
            </parameter>
            ...
          </function>
        </tool_call>

    This parser converts that markup into the OpenAI-style
    ``ToolCall``/``DeltaToolCall`` structures, both for complete
    responses (``extract_tool_calls``) and incrementally while
    streaming (``extract_tool_calls_streaming``). Parameter values are
    coerced to JSON types using the schema from ``request.tools`` when
    one is available.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Initialize parser state and compile the tag-matching regexes.

        Raises:
            ValueError: if no tokenizer was supplied.
            RuntimeError: if the tool-call sentinel tokens are not in the
                tokenizer vocabulary.
        """
        super().__init__(tokenizer)
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.streamed_args_for_tool: list[str] = []

        # Sentinel tokens for streaming mode
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        self.function_end_token: str = "</function>"
        self.parameter_prefix: str = "<parameter="
        self.parameter_end_token: str = "</parameter>"
        self.is_tool_call_started: bool = False
        self.failed_count: int = 0

        # Streaming state variables
        self.current_tool_index: int = 0
        self.header_sent: bool = False
        self.current_tool_string_id: Optional[str] = None
        self.current_function_name: Optional[str] = None
        self.current_param_name: Optional[str] = None
        self.current_param_value: str = ""
        self.param_count: int = 0
        self.in_param: bool = False
        self.in_function: bool = False
        self.accumulated_text: str = ""
        self.json_started: bool = False
        self.json_closed: bool = False

        # Enhanced streaming state - reset for each new message
        self._reset_streaming_state()

        # Regex patterns. Each streaming-capable pattern has a second
        # alternative anchored at end-of-string so an *unterminated*
        # (still-being-generated) tag is also matched.
        self.tool_call_complete_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>", re.DOTALL)
        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL)
        self.tool_call_function_regex = re.compile(
            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")

        # Token ids of the sentinels, used to detect tool-call
        # boundaries from delta_token_ids while streaming.
        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        if (self.tool_call_start_token_id is None
                or self.tool_call_end_token_id is None):
            raise RuntimeError(
                "Qwen3 XML Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!")
        logger.debug("vLLM Successfully import tool parser %s !",
                     self.__class__.__name__)

    def _generate_tool_call_id(self) -> str:
        """Generate a unique tool call ID."""
        return f"call_{uuid.uuid4().hex[:24]}"

    def _reset_streaming_state(self):
        """Reset all streaming state.

        Called from ``__init__`` and again at the start of every new
        message (when ``previous_text`` is empty) so state from a prior
        request cannot leak into the next one.
        """
        self.current_tool_index = 0
        self.is_tool_call_started = False
        self.header_sent = False
        self.current_tool_string_id = None
        self.current_function_name = None
        self.current_param_name = None
        self.current_param_value = ""
        self.param_count = 0
        self.in_param = False
        self.in_function = False
        self.accumulated_text = ""
        self.json_started = False
        self.json_closed = False

    def _parse_xml_function_call(
        self, function_call_str: str,
        tools: Optional[list[ChatCompletionToolsParam]]
    ) -> Optional[ToolCall]:
        """Parse one ``NAME>...<parameter=...>...`` fragment into a ToolCall.

        ``function_call_str`` is the text between ``<function=`` and
        ``</function>`` (i.e. it starts with the function name followed
        by ``>``). Parameter values are type-coerced using the matching
        tool's JSON-schema ``properties`` from *tools*, when present.
        """

        def get_arguments_config(func_name: str) -> dict:
            # Look up the parameter schema for *func_name* in the
            # request's tool definitions; {} means "no schema known",
            # in which case values are passed through as strings.
            if tools is None:
                return {}
            for config in tools:
                if not hasattr(config, "type") or not (
                        hasattr(config, "function")
                        and hasattr(config.function, "name")):
                    continue
                if (config.type == "function"
                        and config.function.name == func_name):
                    if not hasattr(config.function, "parameters"):
                        return {}
                    params = config.function.parameters
                    if isinstance(params, dict) and "properties" in params:
                        return params["properties"]
                    elif isinstance(params, dict):
                        return params
                    else:
                        return {}
            logger.warning("Tool '%s' is not defined in the tools list.",
                           func_name)
            return {}

        def convert_param_value(param_value: str, param_name: str,
                                param_config: dict, func_name: str) -> Any:
            """Coerce the raw string *param_value* per the schema type.

            Falls back to returning the raw string whenever conversion
            fails (with a warning), so a malformed value never raises.
            """
            # Handle null value for any type
            if param_value.lower() == "null":
                return None

            converted_value: Any
            if param_name not in param_config:
                if param_config != {}:
                    logger.warning(
                        "Parsed parameter '%s' is not defined in the tool "
                        "parameters for tool '%s', directly returning the "
                        "string value.", param_name, func_name)
                return param_value

            if (isinstance(param_config[param_name], dict)
                    and "type" in param_config[param_name]):
                param_type = str(
                    param_config[param_name]["type"]).strip().lower()
            else:
                # No declared type: treat as string.
                param_type = "string"
            if param_type in [
                    "string", "str", "text", "varchar", "char", "enum"
            ]:
                return param_value
            elif (param_type.startswith("int") or param_type.startswith("uint")
                  or param_type.startswith("long")
                  or param_type.startswith("short")
                  or param_type.startswith("unsigned")):
                try:
                    converted_value = int(param_value)
                    return converted_value
                except ValueError:
                    logger.warning(
                        "Parsed value '%s' of parameter '%s' is not an "
                        "integer in tool '%s', degenerating to string.",
                        param_value, param_name, func_name)
                    return param_value
            elif (param_type.startswith("num")
                  or param_type.startswith("float")):
                try:
                    float_param_value = float(param_value)
                    # Collapse whole-number floats (e.g. "3.0") to int.
                    converted_value = (float_param_value if float_param_value -
                                       int(float_param_value) != 0 else
                                       int(float_param_value))
                    return converted_value
                except ValueError:
                    logger.warning(
                        "Parsed value '%s' of parameter '%s' is not a float "
                        "in tool '%s', degenerating to string.", param_value,
                        param_name, func_name)
                    return param_value
            elif param_type in ["boolean", "bool", "binary"]:
                param_value = param_value.lower()
                if param_value not in ["true", "false"]:
                    logger.warning(
                        "Parsed value '%s' of parameter '%s' is not a "
                        "boolean (`true` of `false`) in tool '%s', "
                        "degenerating to false.", param_value, param_name,
                        func_name)
                return param_value == "true"
            else:
                if param_type == "object" or param_type.startswith("dict"):
                    try:
                        converted_value = json.loads(param_value)
                        return converted_value
                    except json.JSONDecodeError:
                        logger.warning(
                            "Parsed value '%s' of parameter '%s' is not a "
                            "valid JSON object in tool '%s', will try other "
                            "methods to parse it.", param_value, param_name,
                            func_name)
                try:
                    # SECURITY NOTE(review): eval() here executes
                    # model-generated text as Python. Consider
                    # ast.literal_eval (literals only) instead —
                    # flagged, not changed, to preserve behavior.
                    converted_value = eval(param_value)
                    return converted_value
                except Exception:
                    logger.warning(
                        "Parsed value '%s' of parameter '%s' cannot be "
                        "converted via Python `eval()` in tool '%s', "
                        "degenerating to string.", param_value, param_name,
                        func_name)
                return param_value

        # Extract function name
        # NOTE(review): .index(">") raises ValueError if the fragment has
        # no ">" yet; callers wrap this in try/except — confirm intended.
        end_index = function_call_str.index(">")
        function_name = function_call_str[:end_index]
        param_config = get_arguments_config(function_name)
        parameters = function_call_str[end_index + 1:]
        param_dict = {}
        for match in self.tool_call_parameter_regex.findall(parameters):
            # Each match tuple holds (closed-tag text, unterminated text);
            # exactly one alternative is non-empty.
            match_text = match[0] if match[0] else match[1]
            idx = match_text.index(">")
            param_name = match_text[:idx]
            param_value = str(match_text[idx + 1:])
            # Remove prefix and trailing \n (template wraps values in
            # newlines around the tags).
            if param_value.startswith("\n"):
                param_value = param_value[1:]
            if param_value.endswith("\n"):
                param_value = param_value[:-1]
            param_dict[param_name] = convert_param_value(
                param_value, param_name, param_config, function_name)
        return ToolCall(
            type="function",
            function=FunctionCall(name=function_name,
                                  arguments=json.dumps(param_dict,
                                                       ensure_ascii=False)),
        )

    def _get_function_calls(self, model_output: str) -> list[str]:
        """Return raw ``<function=...`` fragments found in *model_output*."""
        # Find all tool calls
        matched_ranges = self.tool_call_regex.findall(model_output)
        raw_tool_calls = [
            match[0] if match[0] else match[1] for match in matched_ranges
        ]

        # Back-off strategy if no tool_call tags found: treat the whole
        # output as one candidate tool call.
        if len(raw_tool_calls) == 0:
            raw_tool_calls = [model_output]

        raw_function_calls = []
        for tool_call in raw_tool_calls:
            raw_function_calls.extend(
                self.tool_call_function_regex.findall(tool_call))

        function_calls = [
            match[0] if match[0] else match[1] for match in raw_function_calls
        ]
        return function_calls

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a *complete* (non-streaming) response.

        Returns the parsed tool calls plus any content that preceded the
        first tool call; on any parsing error the whole output is
        returned as plain content with ``tools_called=False``.
        """
        # Quick check to avoid unnecessary processing
        if self.tool_call_prefix not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            function_calls = self._get_function_calls(model_output)
            if len(function_calls) == 0:
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=model_output)

            tool_calls = [
                self._parse_xml_function_call(function_call_str, request.tools)
                for function_call_str in function_calls
            ]

            # Populate prev_tool_call_arr for serving layer to set
            # finish_reason
            self.prev_tool_call_arr.clear()  # Clear previous calls
            for tool_call in tool_calls:
                if tool_call:
                    self.prev_tool_call_arr.append({
                        "name":
                        tool_call.function.name,
                        "arguments":
                        tool_call.function.arguments,
                    })

            # Extract content before tool calls; fall back to the first
            # "<function=" when no "<tool_call>" tag was emitted.
            content_index = model_output.find(self.tool_call_start_token)
            content_index = (content_index if content_index >= 0 else
                             model_output.find(self.tool_call_prefix))
            content = model_output[:content_index]  # .rstrip()

            return ExtractedToolCallInformation(
                tools_called=(len(tool_calls) > 0),
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Incrementally convert streamed model text into tool-call deltas.

        Stateful across calls: per-tool progress is tracked in the
        ``current_tool_index`` / ``header_sent`` / ``json_*`` flags set
        up by ``_reset_streaming_state``. Emits, in order per tool: a
        header delta (id + function name), ``"{"``, one JSON fragment
        per completed parameter, then ``"}"``. Returns None when there
        is nothing to emit for this delta.
        """
        # If no delta text, return None unless it's an EOS token after tool
        # calls
        if not delta_text:
            # Check if this is an EOS token after all tool calls are complete
            # We check for tool calls in the text even if is_tool_call_started
            # is False because it might have been reset after processing all
            # tools
            if (delta_token_ids
                    and self.tool_call_end_token_id not in delta_token_ids):
                # Count complete tool calls
                complete_calls = len(
                    self.tool_call_complete_regex.findall(current_text))
                # If we have completed tool calls and populated
                # prev_tool_call_arr
                if (complete_calls > 0 and len(self.prev_tool_call_arr) > 0):
                    # Check if all tool calls are closed
                    open_calls = (
                        current_text.count(self.tool_call_start_token) -
                        current_text.count(self.tool_call_end_token))
                    if open_calls == 0:
                        # Return empty delta message to allow finish_reason
                        # processing
                        return DeltaMessage(content="")
                elif not self.is_tool_call_started and current_text:
                    # This is a regular content response that's now complete
                    return DeltaMessage(content="")
            return None

        # Check if this is the first call (reset state if needed)
        if not previous_text:
            self._reset_streaming_state()

        # Update accumulated text
        self.accumulated_text = current_text

        # Check if we need to advance to next tool
        if self.json_closed and not self.in_function:
            # Check if this tool call has ended
            tool_ends = current_text.count(self.tool_call_end_token)
            if tool_ends > self.current_tool_index:
                # This tool has ended, advance to next
                self.current_tool_index += 1
                self.header_sent = False
                self.param_count = 0
                self.json_started = False
                self.json_closed = False

                # Check if there are more tool calls
                tool_starts_count = current_text.count(
                    self.tool_call_start_token)
                if self.current_tool_index >= tool_starts_count:
                    # No more tool calls
                    self.is_tool_call_started = False

                # Continue processing next tool
                return None

        # Handle normal content before tool calls
        if not self.is_tool_call_started:
            # Check if tool call is starting
            if (self.tool_call_start_token_id in delta_token_ids
                    or self.tool_call_start_token in delta_text):
                self.is_tool_call_started = True
                # Return any content before the tool call
                if self.tool_call_start_token in delta_text:
                    content_before = delta_text[:delta_text.index(
                        self.tool_call_start_token)]
                    if content_before:
                        return DeltaMessage(content=content_before)
                return None
            else:
                # Check if we're between tool calls - skip whitespace
                if (current_text.rstrip().endswith(self.tool_call_end_token)
                        and delta_text.strip() == ""):
                    # We just ended a tool call, skip whitespace
                    return None
                # Normal content, no tool call
                return DeltaMessage(content=delta_text)

        # Check if we're between tool calls (waiting for next one)
        # Count tool calls we've seen vs processed
        tool_starts_count = current_text.count(self.tool_call_start_token)
        if self.current_tool_index >= tool_starts_count:
            # We're past all tool calls, shouldn't be here
            return None

        # We're in a tool call, find the current tool call portion
        # Need to find the correct tool call based on current_tool_index
        tool_starts: list[int] = []
        idx = 0
        while True:
            idx = current_text.find(self.tool_call_start_token, idx)
            if idx == -1:
                break
            tool_starts.append(idx)
            idx += len(self.tool_call_start_token)

        if self.current_tool_index >= len(tool_starts):
            # No more tool calls to process yet
            return None

        tool_start_idx = tool_starts[self.current_tool_index]
        # Find where this tool call ends (or current position if not ended yet)
        tool_end_idx = current_text.find(self.tool_call_end_token,
                                         tool_start_idx)
        if tool_end_idx == -1:
            tool_text = current_text[tool_start_idx:]
        else:
            tool_text = current_text[tool_start_idx:tool_end_idx +
                                     len(self.tool_call_end_token)]

        # Looking for function header
        if not self.header_sent:
            if self.tool_call_prefix in tool_text:
                func_start = (tool_text.find(self.tool_call_prefix) +
                              len(self.tool_call_prefix))
                func_end = tool_text.find(">", func_start)
                if func_end != -1:
                    # Found complete function name
                    self.current_function_name = tool_text[func_start:func_end]
                    self.current_tool_string_id = self._generate_tool_call_id()
                    self.header_sent = True
                    self.in_function = True

                    # IMPORTANT: Add to prev_tool_call_arr immediately when we
                    # detect a tool call. This ensures
                    # finish_reason="tool_calls" even if parsing isn't complete
                    # NOTE(review): dedup is by function name only, so two
                    # calls to the same tool share one entry — confirm.
                    already_added = any(
                        tool.get("name") == self.current_function_name
                        for tool in self.prev_tool_call_arr)
                    if not already_added:
                        self.prev_tool_call_arr.append({
                            "name": self.current_function_name,
                            "arguments":
                            "{}",  # Placeholder, will be updated later
                        })

                    # Send header with function info
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            id=self.current_tool_string_id,
                            function=DeltaFunctionCall(
                                name=self.current_function_name, arguments=""),
                            type="function",
                        )
                    ])
            return None

        # We've sent header, now handle function body
        if self.in_function:
            # Send opening brace if not sent yet
            if (not self.json_started
                    and self.parameter_prefix not in delta_text):
                self.json_started = True
                return DeltaMessage(tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_index,
                        function=DeltaFunctionCall(arguments="{"),
                    )
                ])

            # Make sure json_started is set if we're processing parameters
            if not self.json_started:
                self.json_started = True

            # Check for function end in accumulated text
            if not self.json_closed and self.function_end_token in tool_text:
                # Close JSON
                self.json_closed = True

                # Extract the complete tool call to update prev_tool_call_arr
                # with final arguments. Find the function content
                func_start = (tool_text.find(self.tool_call_prefix) +
                              len(self.tool_call_prefix))
                func_content_end = tool_text.find(self.function_end_token,
                                                  func_start)
                if func_content_end != -1:
                    func_content = tool_text[func_start:func_content_end]
                    # Parse to get the complete arguments
                    try:
                        parsed_tool = self._parse_xml_function_call(
                            func_content, request.tools if request else None)
                        if parsed_tool:
                            # Update existing entry in prev_tool_call_arr with
                            # complete arguments
                            for i, tool in enumerate(self.prev_tool_call_arr):
                                if (tool.get("name") ==
                                        parsed_tool.function.name):
                                    self.prev_tool_call_arr[i]["arguments"] = (
                                        parsed_tool.function.arguments)
                                    break
                    except Exception:
                        pass  # Ignore parsing errors during streaming

                result = DeltaMessage(tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_index,
                        function=DeltaFunctionCall(arguments="}"),
                    )
                ])

                # Reset state for next tool
                self.in_function = False
                self.json_closed = True

                return result

            # Look for parameters
            # Count how many complete parameters we have processed
            complete_params = tool_text.count(self.parameter_end_token)

            # Check if we should start a new parameter
            if not self.in_param and self.param_count < complete_params:
                # Find the unprocessed parameter
                # Count parameter starts
                param_starts = []
                idx = 0
                while True:
                    idx = tool_text.find(self.parameter_prefix, idx)
                    if idx == -1:
                        break
                    param_starts.append(idx)
                    idx += len(self.parameter_prefix)

                if len(param_starts) > self.param_count:
                    # Process the next parameter
                    param_idx = param_starts[self.param_count]
                    param_start = param_idx + len(self.parameter_prefix)
                    remaining = tool_text[param_start:]

                    if ">" in remaining:
                        # We have the complete parameter name
                        name_end = remaining.find(">")
                        self.current_param_name = remaining[:name_end]

                        # Find the parameter value
                        value_start = param_start + name_end + 1
                        value_text = tool_text[value_start:]
                        if value_text.startswith("\n"):
                            value_text = value_text[1:]

                        # Find where this parameter ends
                        param_end_idx = value_text.find(
                            self.parameter_end_token)
                        if param_end_idx != -1:
                            # Complete parameter found
                            param_value = value_text[:param_end_idx]
                            if param_value.endswith("\n"):
                                param_value = param_value[:-1]

                            # Build complete JSON fragment for this parameter.
                            # json.dumps(...)[1:-1] yields the escaped string
                            # body without its surrounding quotes.
                            # NOTE(review): the value is always emitted as a
                            # quoted JSON string here, even for numeric or
                            # boolean schema types — confirm intended.
                            if self.param_count == 0:
                                json_fragment = (
                                    '"' + self.current_param_name + '": "' +
                                    json.dumps(param_value)[1:-1] + '"')
                            else:
                                json_fragment = (
                                    ', "' + self.current_param_name + '": "' +
                                    json.dumps(param_value)[1:-1] + '"')

                            self.param_count += 1

                            return DeltaMessage(tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_index,
                                    function=DeltaFunctionCall(
                                        arguments=json_fragment),
                                )
                            ])

            # Continue parameter value
            # NOTE(review): self.in_param is never assigned True anywhere in
            # this file, so this branch looks unreachable dead state —
            # confirm before relying on or removing it.
            if self.in_param:
                if self.parameter_end_token in delta_text:
                    # End of parameter
                    end_idx = delta_text.find(self.parameter_end_token)
                    value_chunk = delta_text[:end_idx]

                    # Skip past > if at start
                    if not self.current_param_value and ">" in value_chunk:
                        gt_idx = value_chunk.find(">")
                        value_chunk = value_chunk[gt_idx + 1:]
                    if (not self.current_param_value
                            and value_chunk.startswith("\n")):
                        value_chunk = value_chunk[1:]

                    # Calculate incremental JSON: emit only the escaped
                    # suffix that was not streamed previously.
                    full_value = self.current_param_value + value_chunk
                    prev_escaped = (json.dumps(self.current_param_value)[1:-1]
                                    if self.current_param_value else "")
                    full_escaped = json.dumps(full_value)[1:-1]
                    delta_escaped = full_escaped[len(prev_escaped):]

                    self.in_param = False
                    self.current_param_value = ""

                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            function=DeltaFunctionCall(
                                arguments=delta_escaped + '"'),
                        )
                    ])
                else:
                    # Continue accumulating value
                    value_chunk = delta_text

                    # Handle first chunk after param name
                    if not self.current_param_value and ">" in value_chunk:
                        gt_idx = value_chunk.find(">")
                        value_chunk = value_chunk[gt_idx + 1:]
                    if (not self.current_param_value
                            and value_chunk.startswith("\n")):
                        value_chunk = value_chunk[1:]

                    if value_chunk:
                        # Stream the escaped delta
                        prev_escaped = (json.dumps(
                            self.current_param_value)[1:-1]
                                        if self.current_param_value else "")
                        self.current_param_value += value_chunk
                        full_escaped = json.dumps(
                            self.current_param_value)[1:-1]
                        delta_escaped = full_escaped[len(prev_escaped):]
                        if delta_escaped:
                            return DeltaMessage(tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_index,
                                    function=DeltaFunctionCall(
                                        arguments=delta_escaped),
                                )
                            ])

        return None

View File

@ -58,6 +58,9 @@ class RayDistributedExecutor(DistributedExecutorBase):
"VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES"
}
# These non-vLLM env vars are copied from the driver to workers
ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}
uses_ray: bool = True
def _init_executor(self) -> None:
@ -67,8 +70,8 @@ class RayDistributedExecutor(DistributedExecutorBase):
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
# For TPU, avoid compiling NVIDIA's NCCL
if current_platform.is_tpu():
# For TPU or XPU, avoid compiling NVIDIA's NCCL
if current_platform.is_tpu() or current_platform.is_xpu():
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
# If the env var is set, it uses the Ray's compiled DAG API
@ -326,7 +329,8 @@ class RayDistributedExecutor(DistributedExecutorBase):
# Environment variables to copy from driver to workers
env_vars_to_copy = get_env_vars_to_copy(
exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
additional_vars=set(current_platform.additional_env_vars),
additional_vars=set(current_platform.additional_env_vars).union(
self.ADDITIONAL_ENV_VARS),
destination="workers")
# Copy existing env vars to each worker's args

View File

@ -145,7 +145,9 @@ try:
except ImportError as e:
ray = None # type: ignore
ray_import_err = e
# only capture string to avoid variable references in the traceback that can
# prevent garbage collection in some cases
ray_import_err = str(e)
RayWorkerWrapper = None # type: ignore
@ -157,8 +159,8 @@ def ray_is_available() -> bool:
def assert_ray_available():
"""Raise an exception if Ray is not available."""
if ray is None:
raise ValueError("Failed to import Ray, please install Ray with "
"`pip install ray`.") from ray_import_err
raise ValueError(f"Failed to import Ray: {ray_import_err}."
"Please install Ray with `pip install ray`.")
def _verify_bundles(placement_group: "PlacementGroup",

View File

@ -464,10 +464,11 @@ class FusedMoEConfig:
)
else:
_quant_config = FusedMoEQuantConfig()
logger.warning_once("MoE DP setup unable to determine "
"quantization scheme or unsupported "
"quantization type. This model will "
"not run with DP enabled.")
if moe_parallel_config.dp_size > 1:
logger.warning_once("MoE DP setup unable to determine "
"quantization scheme or unsupported "
"quantization type. This model will "
"not run with DP enabled.")
else:
_quant_config = quant_config

View File

@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
_resize_cache,
extract_required_args)
from vllm.scalar_type import scalar_types
@ -34,10 +35,6 @@ def run_cutlass_moe_fp8(
w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
@ -156,11 +153,27 @@ def run_cutlass_moe_fp8(
problem_sizes1, problem_sizes2, a_map,
c_map, global_num_experts, N, K)
a1q = ops.shuffle_rows(a1q, a_map)
a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
if per_act_token else a1q_scale)
a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
expert_offsets = expert_offsets[:-1]
ab_strides1 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
c_strides1 = torch.full((w1.size(0), ),
2 * N,
device=device,
dtype=torch.int64)
ab_strides2 = torch.full((w1.size(0), ),
N,
device=device,
dtype=torch.int64)
c_strides2 = torch.full((w1.size(0), ),
K,
device=device,
dtype=torch.int64)
if use_batched_format:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
@ -197,8 +210,7 @@ def run_cutlass_moe_fp8(
else:
# We can't do this inplace because output may point to the same tensor
# as c3.
output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
non_blocking=True)
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
# TODO (bnell): split class batched vs. non-batched?
@ -211,10 +223,6 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
num_dispatchers: Optional[int] = None,
use_batched_format: bool = False,
@ -231,10 +239,6 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
self.out_dtype = out_dtype
self.ab_strides1 = ab_strides1
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
self.c_strides2 = c_strides2
self.use_batched_format = use_batched_format
@property
@ -314,8 +318,7 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
self.c_strides2, workspace13, workspace2, expert_num_tokens,
a2_scale, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
self.use_batched_format)
@ -329,10 +332,6 @@ def cutlass_moe_fp8(
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
per_act_token: Optional[bool] = None,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
@ -360,17 +359,6 @@ def cutlass_moe_fp8(
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
Shape: [num_experts]
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
Shape: [num_experts]
- c_strides1 (torch.Tensor): The output strides for the first gemm.
Shape: [num_experts]
- c_strides2 (torch.Tensor): The output strides for the second gemm.
Shape: [num_experts]
- per_act_token (Optional[bool]): Whether the scale is per-token or
per-tensor.
- activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@ -403,10 +391,6 @@ def cutlass_moe_fp8(
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
ab_strides1=ab_strides1,
ab_strides2=ab_strides2,
c_strides1=c_strides1,
c_strides2=c_strides2,
use_batched_format=False,
),
)

View File

@ -181,12 +181,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
g2_alphas,
]
_ = flashinfer_cutlass_fused_moe(
hidden_states,
topk_ids.to(torch.int),
topk_weights,
input=hidden_states,
token_selected_experts=topk_ids.to(torch.int),
token_final_scales=topk_weights,
# FlashInfer API requires weight to be long for nvfp4
w1.view(torch.long),
w2.view(torch.long),
fc1_expert_weights=w1.view(torch.long),
fc2_expert_weights=w2.view(torch.long),
output_dtype=out_dtype,
quant_scales=quant_scales,
input_sf=a1q_scale,

View File

@ -11,7 +11,7 @@ from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
extract_required_args, moe_kernel_quantize_input)
from vllm.utils.flashinfer import fp4_swizzle_blockscale
from vllm.utils.flashinfer import block_scale_interleave
def get_local_sizes(local_tokens):
@ -92,7 +92,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dim=0,
sizes=get_local_sizes(local_tokens))
a1_m, a1_n = a1q.shape
a1q_scale = fp4_swizzle_blockscale(a1q_scale, a1_m, a1_n * 2)
a1q_scale = block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights

View File

@ -111,6 +111,8 @@ def moe_align_block_size_triton(
dtype=torch.int32,
device=topk_ids.device)
tokens_per_thread = cdiv(numel, num_experts)
sorted_token_ids.fill_(numel)
expert_ids.zero_()
moe_align_block_size_stage1[grid](
topk_ids,
@ -205,11 +207,8 @@ def moe_align_block_size(
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
expert_ids = torch.zeros((max_num_m_blocks, ),
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),

View File

@ -259,6 +259,8 @@ class LinearBase(torch.nn.Module):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.quant_config = quant_config
self.prefix = prefix
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
@ -300,6 +302,12 @@ class ReplicatedLinear(LinearBase):
*,
return_bias: bool = True,
):
# If MergedReplicatedLinear, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = self.output_sizes
else:
self.output_partition_sizes = [output_size]
super().__init__(input_size,
output_size,
skip_bias_add,
@ -311,7 +319,8 @@ class ReplicatedLinear(LinearBase):
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size, [self.output_size],
self.input_size,
self.output_partition_sizes,
self.input_size,
self.output_size,
self.params_dtype,
@ -367,6 +376,73 @@ class ReplicatedLinear(LinearBase):
return s
class MergedReplicatedLinear(ReplicatedLinear):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
):
self.output_sizes = output_sizes
super().__init__(input_size,
sum(output_sizes),
bias,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias)
def weight_loader(self,
param: Union[Parameter, BasevLLMParameter],
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
assert loaded_shard_id is not None
assert loaded_shard_id < len(self.output_sizes)
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
assert self.quant_method is not None
assert isinstance(self.quant_method,
(Fp8LinearMethod, Fp8MoEMethod))
weight_block_size = self.quant_method.quant_config.weight_block_size
assert weight_block_size is not None
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n)
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n)
elif isinstance(param, PerTensorScaleParameter):
shard_offset = loaded_shard_id
shard_size = 1
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
param[shard_offset:shard_offset + shard_size] = loaded_weight
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.

View File

@ -332,6 +332,12 @@ class CompressedTensorsConfig(QuantizationConfig):
return (self._check_scheme_supported(90, error=False, match_exact=True)
and self._is_fp8_w8a8(weight_quant, input_quant))
def _is_fp8_w8a8_sm100(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
return (self._check_scheme_supported(
100, error=False, match_exact=True)
and self._is_fp8_w8a8(weight_quant, input_quant))
def _is_fp8_w8a16(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
# Confirm weights quantized.

View File

@ -83,7 +83,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4MoeMethod()
elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
@ -740,6 +741,8 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
self.topk_indices_dtype = None
self.fused_experts = None # type: ignore
self.disable_expert_map = False
self.is_fp8_w8a8_sm100 = self.quant_config._is_fp8_w8a8_sm100(
self.weight_quant, self.input_quant)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@ -859,21 +862,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
device = layer.w13_weight.device
# ab_strides1 and c_strides2 are the same
self.ab_strides1_c_strides2 = torch.full((layer.local_num_experts, ),
layer.hidden_size,
device=device,
dtype=torch.int64)
self.ab_strides2 = torch.full((layer.local_num_experts, ),
layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64)
self.c_strides1 = torch.full((layer.local_num_experts, ),
2 * layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
@ -896,10 +884,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
num_dispatchers=num_dispatchers,
use_batched_format=use_batched_format,
)
@ -946,12 +930,33 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
e_score_correction_bias=e_score_correction_bias)
per_act_token = (
self.input_quant.strategy == QuantizationStrategy.TOKEN)
per_channel_quant = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
# Triton fused_experts is faster in small batch sizes on SM100.
# Fall back to fused_experts in small batch sizes.
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
if self.fused_experts is None:
# If no modular kernel is provided, use cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
@ -968,10 +973,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)

View File

@ -257,9 +257,16 @@ class Fp8LinearMethod(LinearMethodBase):
f"{input_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# Required by column parallel or enabling merged weights
if (tp_size > 1 and output_size // output_size_per_partition
== tp_size) or len(output_partition_sizes) > 1:
for output_partition_size in output_partition_sizes:
is_tp_split = (tp_size > 1 and
output_size // output_size_per_partition == tp_size)
is_merged_gemm = len(output_partition_sizes) > 1
if is_tp_split or is_merged_gemm:
sizes_to_check = output_partition_sizes
if not is_tp_split and is_merged_gemm:
# In case of merged matrices, we allow the last
# matrix to not be a multiple of block size
sizes_to_check = output_partition_sizes[:-1]
for output_partition_size in sizes_to_check:
if output_partition_size % block_n != 0:
raise ValueError(
f"Weight output_partition_size = "

View File

@ -366,6 +366,7 @@ def per_token_group_quant_fp8(
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
out_q: Optional[torch.Tensor] = None,
use_ue8m0: bool = is_blackwell_deep_gemm_used(),
) -> tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
@ -397,8 +398,7 @@ def per_token_group_quant_fp8(
if x_q is None:
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
# Allocate the scale tensor in either row- or column-major format.
if column_major_scales:
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device,
@ -407,6 +407,15 @@ def per_token_group_quant_fp8(
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
# prefer CUDA kernel if available
if current_platform.is_cuda() and x.is_contiguous():
torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
fp8_min, fp8_max, use_ue8m0)
return x_q, x_s
# TRITON FALLBACK
M = x.numel() // group_size
N = group_size
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
@ -423,7 +432,7 @@ def per_token_group_quant_fp8(
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
use_ue8m0=is_blackwell_deep_gemm_used(),
use_ue8m0=use_ue8m0,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
@ -439,7 +448,7 @@ def per_token_group_quant_fp8(
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
use_ue8m0=is_blackwell_deep_gemm_used(),
use_ue8m0=use_ue8m0,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,

View File

@ -25,7 +25,8 @@ from vllm.model_executor.models.adapters import (as_embedding_model,
as_reward_model,
as_seq_cls_model)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.model_executor.models.registry import _TRANSFORMERS_MODELS
from vllm.model_executor.models.registry import (_PREVIOUSLY_SUPPORTED_MODELS,
_TRANSFORMERS_MODELS)
from vllm.utils import is_pin_memory_available
logger = init_logger(__name__)
@ -261,6 +262,14 @@ def get_model_architecture(
vllm_not_supported = False
break
if any(arch in _PREVIOUSLY_SUPPORTED_MODELS for arch in architectures):
previous_version = _PREVIOUSLY_SUPPORTED_MODELS[architectures[0]]
raise ValueError(
f"Model architecture {architectures[0]} was supported"
f" in vLLM until version {previous_version}, and is "
"not supported anymore. Please use an older version"
" of vLLM if you want to use this model architecture.")
if (model_config.model_impl == ModelImpl.TRANSFORMERS or
model_config.model_impl == ModelImpl.AUTO and vllm_not_supported):
architectures = resolve_transformers_arch(model_config, architectures)

View File

@ -482,14 +482,20 @@ def runai_safetensors_weights_iterator(
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
with SafetensorsStreamer() as streamer:
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors using Runai Model Streamer",
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
streamer.stream_file(st_file)
yield from streamer.get_tensors()
streamer.stream_files(hf_weights_files)
total_tensors = sum(
len(tensors_meta)
for tensors_meta in streamer.files_to_tensors_metadata.values())
tensor_iter = tqdm(
streamer.get_tensors(),
total=total_tensors,
desc="Loading safetensors using Runai Model Streamer",
bar_format=_BAR_FORMAT,
disable=not enable_tqdm(use_tqdm_on_load),
)
yield from tensor_iter
def fastsafetensors_weights_iterator(

View File

@ -0,0 +1,347 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2023-2025 vLLM Team
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Inference-only Arcee (AFM) model adds support for ReLU^2 feed-forward
# activation.
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer,
make_empty_intermediate_tensors_factory, make_layers)
class ArceeMLP(nn.Module):
    """Feed-forward layer for Arcee using ReLU^2 activation
    (no gating as in LLaMA).

    Arcee's MLP is a plain two-matrix feed-forward: one column-parallel
    projection up to the intermediate size, an elementwise ReLU^2, and a
    row-parallel projection back down. Unlike LLaMA there is no gate
    projection.
    """

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 hidden_act: str,
                 quant_config: Optional[Any] = None,
                 bias: bool = False,
                 prefix: str = "",
                 reduce_results: bool = True) -> None:
        super().__init__()
        # One fused up-projection; Arcee has no separate gate matrix.
        self.up_proj = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        # Row-parallel down-projection back to the model width.
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        # Only the ReLU^2 activation is supported for AFM checkpoints.
        if hidden_act != "relu2":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only 'relu2' is supported for AFM.")
        self.act_fn = ReLUSquaredActivation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, ReLU^2, then down-projection."""
        projected, _ = self.up_proj(x)
        activated = self.act_fn(projected)
        out, _ = self.down_proj(activated)
        return out
class ArceeDecoderLayer(nn.Module):
    """Transformer decoder block for Arcee, with self-attention and
    ReLU^2 MLP.

    Structure mirrors a LLaMA decoder layer (pre-norm attention followed by
    pre-norm MLP with fused residual adds), except the MLP is the ungated
    ReLU^2 :class:`ArceeMLP`.

    Args:
        config: HF ``LlamaConfig`` describing the model
            (Arcee reuses the LLaMA config schema).
        cache_config: vLLM KV-cache configuration, forwarded to attention.
        quant_config: Quantization configuration, forwarded to sub-layers.
        prefix: Name of this layer in the state dict, including parents.
    """

    def __init__(self,
                 config: LlamaConfig,
                 cache_config: Optional[Any] = None,
                 quant_config: Optional[Any] = None,
                 prefix: str = "") -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Rotary embedding parameters (reuse LLaMA defaults)
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Determine if attention bias is needed (some variants use bias terms)
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        bias_o_proj = attention_bias
        # qkv_bias, when present, overrides the QKV bias setting only;
        # the o_proj bias keeps the value derived above.
        if hasattr(config, "qkv_bias"):
            attention_bias = config.qkv_bias
        # Self-Attention (using LLaMA's attention structure)
        from vllm.model_executor.models.llama import (
            LlamaAttention)  # import here to avoid circular import
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=getattr(
                config, "attn_type",
                "decoder"),  # assume decoder (causal) unless specified
        )
        # MLP with ReLU^2 activation
        self.mlp = ArceeMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        # Layer normalization layers (RMSNorm as in LLaMA)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self, positions: torch.Tensor, hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one decoder block.

        Returns the post-MLP hidden states together with the running
        residual; the caller (ArceeModel) threads ``residual`` through the
        layer stack so the fused add+norm path can be used.
        """
        # Self-Attention block
        if residual is None:
            # First layer of the stack: no residual carried in yet.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Fused residual add + layernorm if supported
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states)
        # Feed-forward block
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
@support_torch_compile
class ArceeModel(nn.Module):
    """The transformer model backbone for Arcee (embedding layer + stacked
    decoder blocks + final norm).

    Pipeline-parallel aware: the embedding lives on the first PP rank (and,
    when embeddings are tied, on the last rank too), decoder layers are
    partitioned across ranks via ``make_layers``, and the final RMSNorm
    lives on the last rank only.
    """

    def __init__(self,
                 *,
                 vllm_config,
                 prefix: str = "",
                 layer_type: type[nn.Module] = ArceeDecoderLayer) -> None:
        super().__init__()
        config: LlamaConfig = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.quant_config = quant_config
        self.config = config
        self.vocab_size = config.vocab_size
        self.org_vocab_size = config.vocab_size
        # Word embeddings (parallelized if using pipeline parallel)
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer(
            )  # placeholder on non-embedding ranks
        # Build decoder layers across pipeline ranks
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        # Final RMSNorm on the last pipeline stage
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        # For optional capturing of intermediate hidden states
        # (not used by default)
        self.aux_hidden_state_layers: tuple[int, ...] = tuple()
        # Prepare factory for empty intermediate tensors
        # (for pipeline scheduling)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for *input_ids*."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
                                                        list[torch.Tensor]]]:
        """Run the backbone.

        Returns final hidden states on the last PP rank, an
        ``IntermediateTensors`` bundle on earlier ranks, or a
        ``(hidden_states, aux_hidden_states)`` tuple when
        ``aux_hidden_state_layers`` is non-empty.
        """
        # Embedding lookup (on first pipeline rank)
        if get_pp_group().is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None
        else:
            assert intermediate_tensors is not None, (
                "IntermediateTensors must be provided for non-first "
                "pipeline ranks")
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        aux_hidden_states: list[torch.Tensor] = []
        for idx, layer in enumerate(
                self.layers[self.start_layer:self.end_layer]):
            if idx in self.aux_hidden_state_layers:
                aux_hidden_states.append(
                    hidden_states +
                    residual)  # capture pre-layer hidden state if needed
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            # Send intermediate results to the next pipeline stage
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        # On last rank: apply final layer norm
        hidden_states, _ = self.norm(hidden_states, residual)
        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states
class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Arcee Model for causal language modeling, integrated with vLLM
    runtime.

    Wraps :class:`ArceeModel` with an LM head and logits processor on the
    last pipeline-parallel rank; supports LoRA and pipeline parallelism.
    """

    # Map fused module names to their sub-module components
    # (for quantization and LoRA)
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    def __init__(self, *, vllm_config, prefix: str = "") -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        # Initialize the inner Transformer model (ArceeModel)
        self.model = ArceeModel(vllm_config=vllm_config,
                                prefix=f"{prefix}.model")
        # On the last pipeline stage, set up the LM head and logits processor
        if get_pp_group().is_last_rank:
            # Determine vocabulary size (including any LoRA extra tokens
            # for padded LM head)
            self.unpadded_vocab_size = config.vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                quant_config=vllm_config.quant_config,
                bias=getattr(config, "lm_head_bias", False),
                prefix=f"{prefix}.lm_head",
            )
            if config.tie_word_embeddings:
                # Tie output weights with input embedding matrix
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)
            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
        else:
            # Placeholder for lm_head on non-last ranks
            self.lm_head = PPMissingLayer()
        # Provide a reference to the model's method for generating empty
        # tensors (used in pipeline parallel schedule)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the backbone; see :meth:`ArceeModel.forward`."""
        # Forward pass through the Arcee model backbone
        model_output = self.model(input_ids=input_ids,
                                  positions=positions,
                                  intermediate_tensors=intermediate_tensors,
                                  inputs_embeds=inputs_embeds)
        return model_output

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits (last PP rank only)."""
        # Compute final logits from hidden states (last pipeline rank only)
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the inner model."""
        return self.model.get_input_embeddings(input_ids)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights into the model (delegates to inner model and handles
        tied embeddings)."""
        # lm_head is skipped when embeddings are tied (it shares the
        # embedding matrix); gate_proj is skipped because Arcee's MLP has
        # no gate projection.
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
            skip_substrs=["gate_proj"])
        # AutoWeightLoader handles weight name remapping, including fusing
        # separate q_proj, k_proj, v_proj into qkv_proj
        return loader.load_weights(weights)

View File

@ -42,6 +42,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
MergedReplicatedLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -336,7 +337,7 @@ class DeepseekV2Attention(nn.Module):
kv_a, _ = latent_cache.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a.contiguous())
kv_a = self.kv_a_layernorm(kv_a)
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads,
self.qk_nope_head_dim + self.v_head_dim)
@ -407,14 +408,24 @@ class DeepseekV2MLAAttention(nn.Module):
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj")
self.fused_qkv_a_proj = MergedReplicatedLinear(
self.hidden_size,
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_qkv_a_proj")
else:
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
if self.q_lora_rank is not None:
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
self.num_heads *
self.qk_head_dim,
bias=False,
@ -427,13 +438,6 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj")
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
@ -495,15 +499,24 @@ class DeepseekV2MLAAttention(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
q_c = None
kv_lora = None
if self.q_lora_rank is not None:
q_c = self.q_a_proj(hidden_states)[0]
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_lora = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
else:
kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c)
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
@ -837,6 +850,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
@ -870,7 +885,16 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
name_mapped = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal
# weight loading if it's not enabled
# if go with fusion option, then update name
if ((param_name == "fused_qkv_a_proj")
and name_mapped not in params_dict):
continue
else:
name = name_mapped
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

View File

@ -624,13 +624,9 @@ class SupportsQuant:
instance.quant_config = quant_config
# apply model mappings to config for proper config-model matching
# NOTE: `TransformersForCausalLM` is not supported due to how this
# class defines `hf_to_vllm_mapper` as a post-init `@property`.
# After this is fixed, get `instance.hf_to_vllm_mapper` directly
if getattr(instance, "hf_to_vllm_mapper", None) is not None:
instance.quant_config.apply_vllm_mapper(
instance.hf_to_vllm_mapper)
if getattr(instance, "packed_modules_mapping", None) is not None:
if (hf_to_vllm_mapper := instance.hf_to_vllm_mapper) is not None:
instance.quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
if instance.packed_modules_mapping is not None:
instance.quant_config.packed_modules_mapping.update(
instance.packed_modules_mapping)

View File

@ -33,6 +33,7 @@ _TEXT_GENERATION_MODELS = {
# [Decoder-only]
"AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
"MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
@ -275,6 +276,8 @@ _SUBPROCESS_COMMAND = [
sys.executable, "-m", "vllm.model_executor.models.registry"
]
_PREVIOUSLY_SUPPORTED_MODELS = {"Phi3SmallForCausalLM": "0.9.2"}
@dataclass(frozen=True)
class _ModelInfo:

View File

@ -13,8 +13,7 @@ from transformers import LlavaConfig as HfLlavaConfig
from transformers import PretrainedConfig, SiglipVisionConfig
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
from transformers.models.llava import LlavaProcessor
from transformers.processing_utils import (ProcessingKwargs, Unpack,
_validate_images_text_input_order)
from transformers.processing_utils import ProcessingKwargs, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from vllm.config import VllmConfig
@ -94,9 +93,6 @@ class TarsierProcessor(LlavaProcessor):
raise ValueError(
"You have to specify at least one of `images` or `text`.")
# check if images and text inputs are reversed for BC
images, text = _validate_images_text_input_order(images, text)
output_kwargs = self._merge_kwargs(
TarsierProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,

View File

@ -315,16 +315,16 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
"""
if return_mm_hashes:
raise ValueError(
"TransformersForMultimodalLM doesn't support mm hashing yet! "
"Probably you didn't set `disable_mm_preprocessor_cache=True`")
if tokenization_kwargs is None:
tokenization_kwargs = {}
mm_items = self._to_mm_items(mm_data)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if not isinstance(prompt, str):
# the prompt is the tokenized ids which is not supported
# by the hf_processor, which is why we would need to decode the ids
# into string
prompt = hf_processor.decode(prompt)
(prompt_ids, processed_data,
mm_token_type_ids) = self._apply_hf_processor_text_mm(
@ -375,12 +375,14 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
num_image_patches),
)
mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
tokenization_kwargs)
return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=None,
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholders,
)
@ -412,7 +414,7 @@ class ConfigOverride:
setattr(self.config, key, value)
class TransformersModel(nn.Module):
class TransformersModel:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@ -452,9 +454,6 @@ class TransformersModel(nn.Module):
# method after v4.54.0 is released
self.text_config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"), config_override:
# FIXME(Isotr0py): We need to refactor this part in the future to
# avoid registering an extra model layer, otherwise we will need a
# weights mapper to rename weights.
self.model: PreTrainedModel = AutoModel.from_config(
config,
torch_dtype=model_config.dtype,
@ -618,9 +617,6 @@ class TransformersModel(nn.Module):
for child in module.children():
self.init_parameters(child)
def get_input_embeddings(self) -> nn.Module:
return self.model.get_input_embeddings()
def forward(
self,
input_ids: Optional[torch.Tensor],
@ -692,7 +688,9 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
self.config = config
self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
self.transformers_model = TransformersModel(vllm_config=vllm_config,
prefix=prefix)
self.model = self.transformers_model.model
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
@ -714,22 +712,7 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
# FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
# this makes thing complicated. We need to remove this mapper after refactor
# `TransformersModel` in the future.
# NOTE: `SupportsQuant` can be updated after property decorator is removed
@property
def hf_to_vllm_mapper(self):
prefix_mapper = {
name: "model." + name
for name, _ in self.model.model.named_children()
}
return WeightsMapper(
orig_to_new_substr={"model.": "model.model."},
orig_to_new_prefix=prefix_mapper,
)
self.transformers_model.make_empty_intermediate_tensors)
def forward(
self,
@ -738,8 +721,9 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
model_output = self.transformers_model.forward(input_ids, positions,
intermediate_tensors,
inputs_embeds)
return model_output
def compute_logits(
@ -753,12 +737,10 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
skip_prefixes = ["lm_head."
] if self.config.tie_word_embeddings else None
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights)
@MULTIMODAL_REGISTRY.register_processor(
@ -770,6 +752,29 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"]
# Backwards compatibility for prev released models. State dicts back then
# had different formats and cannot be loaded with `AutoModel` mapping as is
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.model": "model.language_model",
"text_model.model": "model.text_model",
"vision_tower": "model.vision_tower",
"vqmodel": "model.vqmodel",
"visual": "model.visual",
"vision_model": "model.vision_model",
"vision_embed_tokens": "model.vision_embed_tokens",
"image_newline": "model.image_newline",
"multi_modal_projector": "model.multi_modal_projector",
"text_model.lm_head": "lm_head",
"language_model.lm_head": "lm_head",
# Qwen models used "model" as the name for the language model.
# Therefore, we must map each of submodule explicitly to avoid
# conflicts with newer models that use "model.language_model".
"model.embed_tokens": "model.language_model.embed_tokens",
"model.layers": "model.language_model.layers",
"model.norm": "model.language_model.norm",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config
@ -778,7 +783,9 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
self.config = config
self.dtype = vllm_config.model_config.dtype
self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
self.transformers_model = TransformersModel(vllm_config=vllm_config,
prefix=prefix)
self.model = self.transformers_model.model
text_config = config.get_text_config()
if get_pp_group().is_last_rank:
@ -801,32 +808,7 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@property
def hf_to_vllm_mapper(self):
# Backwards compatibility for prev released models
# State dicts back then had different formats
# and cannot be loaded with `AutoModel` mapping
# as is
prefix_mapper = {
"language_model.model": "model.language_model",
"text_model.model": "model.text_model",
"vision_tower": "model.vision_tower",
"vqmodel": "model.vqmodel",
"vision_model": "model.vision_model",
"vision_embed_tokens": "model.vision_embed_tokens",
"image_newline": "model.image_newline",
"multi_modal_projector": "model.multi_modal_projector",
"text_model.lm_head": "lm_head",
"language_model.lm_head": "lm_head",
}
# Don't change the order for QwenVL
if 'Qwen2' in self.config.__class__.__name__:
prefix_mapper["model"] = "model.language_model"
prefix_mapper["visual"] = "model.visual"
return WeightsMapper(orig_to_new_prefix=prefix_mapper, )
self.transformers_model.make_empty_intermediate_tensors)
def forward(
self,
@ -846,8 +828,9 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
input_ids, multimodal_embeds)
input_ids = None
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
model_output = self.transformers_model.forward(input_ids, positions,
intermediate_tensors,
inputs_embeds)
return model_output
def compute_logits(
@ -896,7 +879,7 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
if isinstance(num_image_patches, list):
num_image_patches = torch.cat(num_image_patches)
vision_embeddings = self.model.model.get_image_features(
vision_embeddings = self.model.get_image_features(
pixel_values,
**{
k: v.flatten(0, 1)
@ -926,7 +909,7 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
input_ids: torch.Tensor,
multimodal_embeddings=None,
) -> torch.Tensor:
inputs_embeds = self.model.model.get_input_embeddings()(input_ids)
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if (multimodal_embeddings is not None
and len(multimodal_embeddings) != 0):
mask = (input_ids == self.config.image_token_id)

View File

@ -275,7 +275,7 @@ class MultiModalProfiler(Generic[_I]):
if total_mm_tokens > seq_len:
logger.warning_once(
"The sequence length (%d) is smaller than the pre-defined"
" wosrt-case total number of multimodal tokens (%d). "
" worst-case total number of multimodal tokens (%d). "
"This may cause certain multi-modal inputs to fail during "
"inference. To avoid this, you should increase "
"`max_model_len` or reduce `mm_counts`.",

View File

@ -182,9 +182,6 @@ class CudaPlatformBase(Platform):
compilation_config.use_cudagraph = False
if model_config is not None:
model_config.enforce_eager = True
# TODO (varun): Turning this ON gives incorrect results for the
# Deepseek-V2-lite model.
vllm_config.compilation_config.use_inductor = False
@classmethod
def get_current_memory_usage(cls,
@ -459,6 +456,19 @@ class CudaPlatformBase(Platform):
def device_count(cls) -> int:
return cuda_device_count_stateless()
    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
        """Return whether ``kv_cache_dtype`` is usable on this CUDA device.

        Supported when either:
        * the device reports compute capability 100, or
        * the dtype is an fp8 variant, FlashAttention will be selected
          (``VLLM_ATTENTION_BACKEND`` unset, or explicitly set to
          ``FLASH_ATTN_VLLM_V1``), and the installed FlashAttention build
          supports fp8.
        """
        fp8_attention = kv_cache_dtype.startswith("fp8")
        # An unset VLLM_ATTENTION_BACKEND is treated as "FlashAttention will
        # be used"; an explicit override only counts if it still selects
        # FLASH_ATTN_VLLM_V1.
        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
        supported = False
        if cls.is_device_capability(100):
            supported = True
        elif fp8_attention and will_use_fa:
            # Imported lazily: probes the FlashAttention installation, which
            # may not be present on every deployment.
            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
            supported = flash_attn_supports_fp8()
        return supported
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

View File

@ -543,6 +543,13 @@ class Platform:
"""
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
        """
        Returns if the kv_cache_dtype is supported by the current platform.

        The base implementation is conservative and reports no support;
        platform subclasses opt in by overriding this method.

        Args:
            kv_cache_dtype: Requested KV cache data type string.

        Returns:
            False unless overridden by a platform subclass.
        """
        return False
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED

View File

@ -454,3 +454,7 @@ class RocmPlatform(Platform):
@classmethod
def device_count(cls) -> int:
return cuda_device_count_stateless()
    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
        """Report that any requested KV cache dtype is accepted on ROCm.

        NOTE(review): unconditionally True — any dtype-specific restrictions
        are presumably enforced elsewhere in the ROCm path; confirm.
        """
        return True

View File

@ -190,6 +190,10 @@ class TpuPlatform(Platform):
and params.sampling_type == SamplingType.RANDOM_SEED):
raise ValueError("Torch XLA does not support per-request seed.")
    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
        """Report that any requested KV cache dtype is accepted on TPU.

        NOTE(review): unconditionally True — any dtype-specific restrictions
        are presumably enforced elsewhere in the TPU backend; confirm.
        """
        return True
try:
from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform

View File

@ -43,6 +43,8 @@ def get_env_vars_to_copy(exclude_vars: Optional[set[str]] = None,
exclude_vars: A set of vllm defined environment variables to exclude
from copying.
additional_vars: A set of additional environment variables to copy.
If a variable is in both exclude_vars and additional_vars, it will
be excluded.
destination: The destination of the environment variables.
Returns:
A set of environment variables to copy.
@ -52,10 +54,9 @@ def get_env_vars_to_copy(exclude_vars: Optional[set[str]] = None,
env_vars_to_copy = {
v
for v in envs.environment_variables
for v in set(envs.environment_variables).union(additional_vars)
if v not in exclude_vars and v not in RAY_NON_CARRY_OVER_ENV_VARS
}
env_vars_to_copy.update(additional_vars)
to_destination = " to " + destination if destination is not None else ""

View File

@ -10,7 +10,7 @@ MODELS_ON_S3 = [
"allenai/OLMoE-1B-7B-0924-Instruct",
"amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test",
"AMead10/Llama-3.2-1B-Instruct-AWQ",
"ArthurZ/Ilama-3.2-1B",
"hmellor/Ilama-3.2-1B",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-multilingual-gemma2",
"BAAI/bge-reranker-v2-m3",

View File

@ -37,6 +37,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
MiniMaxText01Config,
MiniMaxVL01Config, MllamaConfig,
MLPSpeculatorConfig, MPTConfig,
Nemotron_Nano_VL_Config,
NemotronConfig, NVLM_D_Config,
OvisConfig, RWConfig,
SkyworkR1VChatConfig, SolarConfig,
@ -80,6 +81,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"dbrx": DbrxConfig,
"deepseek_vl_v2": DeepseekVLV2Config,
"kimi_vl": KimiVLConfig,
"Llama_Nemotron_Nano_VL": Nemotron_Nano_VL_Config,
"mpt": MPTConfig,
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)

View File

@ -23,6 +23,7 @@ from vllm.transformers_utils.configs.moonvit import MoonViTConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.ovis import OvisConfig
from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig
@ -50,6 +51,7 @@ __all__ = [
"KimiVLConfig",
"NemotronConfig",
"NemotronHConfig",
"Nemotron_Nano_VL_Config",
"NVLM_D_Config",
"OvisConfig",
"SkyworkR1VChatConfig",

View File

@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# yapf: disable
# ruff: noqa: E501
# Adapted from
# https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1/blob/main/configuration.py
# --------------------------------------------------------
# Adapted from https://huggingface.co/OpenGVLab/InternVL2-Llama3-76B under MIT License
# LICENSE is in incl_licenses directory.
# --------------------------------------------------------
from transformers import LlamaConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.dynamic_module_utils import get_class_from_dynamic_module
class Nemotron_Nano_VL_Config(PretrainedConfig):
    """Composite HF config for Llama-Nemotron Nano VL checkpoints.

    Bundles a dynamically resolved vision-tower config with a Llama text
    config, plus the projector / image-preprocessing knobs the model needs.
    """

    model_type = 'Llama_Nemotron_Nano_VL'
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        llm_config=None,
        force_image_size=None,
        downsample_ratio=0.5,
        template=None,
        ps_version='v1',
        image_tag_type="internvl",
        projector_hidden_size=4096,
        vit_hidden_size=1280,
        **kwargs
    ):
        super().__init__(**kwargs)

        if vision_config is None:
            # No vision tower described: fall back to an empty config.
            self.vision_config = PretrainedConfig()
        else:
            # The vision tower's config class lives in the checkpoint repo;
            # resolve it through the HF "auto_map" dynamic-module machinery.
            assert "auto_map" in vision_config and "AutoConfig" in vision_config["auto_map"]
            class_ref = vision_config["auto_map"]["AutoConfig"].split("--")[::-1]
            vision_config_cls = get_class_from_dynamic_module(*class_ref)
            self.vision_config = vision_config_cls(**vision_config)

        # The language model is always Llama-based.
        self.text_config = (LlamaConfig()
                            if llm_config is None else LlamaConfig(**llm_config))

        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.template = template  # TODO move out of here and into the tokenizer
        self.ps_version = ps_version  # Pixel shuffle version
        self.image_tag_type = image_tag_type  # TODO: into the tokenizer too?
        self.projector_hidden_size = projector_hidden_size
        self.vit_hidden_size = vit_hidden_size

View File

@ -45,30 +45,36 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
return None
if not has_deep_gemm():
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_per_block_cast_impl: Callable[..., Any] | None = None
else:
_dg = importlib.import_module("deep_gemm") # type: ignore
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_per_block_cast_impl: Callable[..., Any] | None = None
_fp8_gemm_nt_impl = _resolve_symbol(
_dg,
"fp8_gemm_nt",
"gemm_fp8_fp8_bf16_nt",
)
def _lazy_init() -> None:
"""Import deep_gemm and resolve symbols on first use."""
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl, \
_per_block_cast_impl
# fast path
if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
or _grouped_masked_impl is not None
or _per_block_cast_impl is not None):
return
if not has_deep_gemm():
return
_dg = importlib.import_module("deep_gemm")
_fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
"gemm_fp8_fp8_bf16_nt")
_grouped_impl = _resolve_symbol(
_dg,
"m_grouped_fp8_gemm_nt_contiguous",
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
)
_dg, "m_grouped_fp8_gemm_nt_contiguous",
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous")
_grouped_masked_impl = _resolve_symbol(
_dg,
"fp8_m_grouped_gemm_nt_masked",
"m_grouped_gemm_fp8_fp8_bf16_nt_masked",
)
_dg, "fp8_m_grouped_gemm_nt_masked",
"m_grouped_gemm_fp8_fp8_bf16_nt_masked")
# Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
try:
_math_mod = importlib.import_module(
@ -80,24 +86,28 @@ else:
def fp8_gemm_nt(*args, **kwargs):
_lazy_init()
if _fp8_gemm_nt_impl is None:
return _missing(*args, **kwargs)
return _fp8_gemm_nt_impl(*args, **kwargs)
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
_lazy_init()
if _grouped_impl is None:
return _missing(*args, **kwargs)
return _grouped_impl(*args, **kwargs)
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
_lazy_init()
if _grouped_masked_impl is None:
return _missing(*args, **kwargs)
return _grouped_masked_impl(*args, **kwargs)
def per_block_cast_to_fp8(x, *args, **kwargs):
_lazy_init()
if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
return _per_block_cast_impl(x, use_ue8m0=True)
# TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils

View File

@ -69,8 +69,8 @@ flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
"cutlass_fused_moe")
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
fp4_swizzle_blockscale = _lazy_import_wrapper("flashinfer",
"fp4_swizzle_blockscale")
block_scale_interleave = _lazy_import_wrapper("flashinfer",
"block_scale_interleave")
# Special case for autotune since it returns a context manager
autotune = _lazy_import_wrapper(
@ -95,7 +95,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
required_functions = [
("flashinfer.fused_moe", "cutlass_fused_moe"),
("flashinfer", "fp4_quantize"),
("flashinfer", "fp4_swizzle_blockscale"),
("flashinfer", "block_scale_interleave"),
]
for module_name, attr_name in required_functions:
@ -110,7 +110,7 @@ __all__ = [
"flashinfer_trtllm_fp8_block_scale_moe",
"flashinfer_cutlass_fused_moe",
"fp4_quantize",
"fp4_swizzle_blockscale",
"block_scale_interleave",
"autotune",
"has_flashinfer_moe",
"has_flashinfer_cutlass_fused_moe",

View File

@ -214,21 +214,18 @@ class BlockPool:
raise ValueError(
f"Cannot get {num_blocks} free blocks from the pool")
ret: list[KVCacheBlock] = []
idx = 0
while idx < num_blocks:
# First allocate blocks.
curr_block = self.free_block_queue.popleft()
assert curr_block.ref_cnt == 0
# If the block is cached, evict it.
if self.enable_caching:
self._maybe_evict_cached_block(curr_block)
curr_block.incr_ref()
ret.append(curr_block)
idx += 1
ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)
# In order to only iterate the list once, we duplicated code a bit
if self.enable_caching:
for block in ret:
self._maybe_evict_cached_block(block)
assert block.ref_cnt == 0
block.ref_cnt += 1
else:
for block in ret:
assert block.ref_cnt == 0
block.ref_cnt += 1
return ret
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
@ -243,22 +240,27 @@ class BlockPool:
True if the block is evicted, False otherwise.
"""
block_hash = block.block_hash
if block_hash and block_hash in self.cached_block_hash_to_block:
block.reset_hash()
del self.cached_block_hash_to_block[block_hash][block.block_id]
if block_hash is None:
# The block doesn't have hash, eviction is not needed
return False
blocks_by_id = self.cached_block_hash_to_block.get(block_hash)
if blocks_by_id is None:
# block_hash not found in cached_block_hash_to_block,
# eviction is not needed
return False
block.reset_hash()
blocks_by_id.pop(block.block_id, None)
if len(blocks_by_id) == 0:
del self.cached_block_hash_to_block[block_hash]
if len(self.cached_block_hash_to_block[block_hash]) == 0:
del self.cached_block_hash_to_block[block_hash]
if self.enable_kv_cache_events:
# FIXME (Chen): Not sure whether we should return `hash_value`
# or `(hash_value, group_id)` here. But it's fine now because
# we disable hybrid kv cache manager when kv cache event is
# enabled, so there is only one group.
self.kv_event_queue.append(
BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
return True
return False
if self.enable_kv_cache_events:
# FIXME (Chen): Not sure whether we should return `hash_value`
# or `(hash_value, group_id)` here. But it's fine now because
# we disable hybrid kv cache manager when kv cache event is
# enabled, so there is only one group.
self.kv_event_queue.append(
BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
return True
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
"""Touch a block increases its reference count by 1, and may remove
@ -284,11 +286,14 @@ class BlockPool:
ordered_blocks: A list of blocks to free ordered by their eviction
priority.
"""
for block in ordered_blocks:
block.decr_ref()
# null_block should not be added to the free list.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.append(block)
# Materialize the iterable to allow multiple passes.
blocks_list = list(ordered_blocks)
for block in blocks_list:
block.ref_cnt -= 1
self.free_block_queue.append_n([
block for block in blocks_list
if block.ref_cnt == 0 and not block.is_null
])
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF

View File

@ -154,6 +154,8 @@ class KVCacheBlock:
# Whether the block is a null block that should never be cached.
is_null: bool = False
# TODO(Jialin): For performance, let callers handle ref_cnt bumps to
# avoid function calls.
def incr_ref(self):
self.ref_cnt += 1
@ -273,6 +275,39 @@ class FreeKVCacheBlockQueue:
self.num_free_blocks -= 1
return first_block
    def popleft_n(self, n: int) -> list[KVCacheBlock]:
        """Pop the first n free blocks and reduce num_free_blocks by n.

        Unlinks the first n nodes of the doubly linked free list in a single
        forward pass, clearing the popped nodes' link pointers as it goes,
        then re-attaches the list head to whatever node remains.

        Args:
            n: The number of blocks to pop.

        Returns:
            A list of n free blocks.
        """
        if n == 0:
            return []
        # Caller must guarantee enough free blocks are available.
        assert self.num_free_blocks >= n
        self.num_free_blocks -= n

        curr_block = self.fake_free_list_head.next_free_block
        # Pop n blocks from the head of the list
        ret = []
        for _ in range(n):
            assert curr_block is not None
            ret.append(curr_block)
            last_block = curr_block
            curr_block = curr_block.next_free_block
            # Reset prev_free_block and next_free_block of all popped blocks
            last_block.prev_free_block = None
            last_block.next_free_block = None

        if curr_block is not None:
            # The queue is not empty, connect the fake head to
            # the new first block.
            # NOTE(review): with a fake tail sentinel terminating the list,
            # curr_block should never be None here (draining the queue lands
            # on the sentinel, not None); the guard looks defensive — confirm.
            self.fake_free_list_head.next_free_block = curr_block
            curr_block.prev_free_block = self.fake_free_list_head
        return ret
def remove(self, block: KVCacheBlock) -> None:
"""Remove a block in the free list and reduce num_free_blocks by 1.
@ -315,6 +350,29 @@ class FreeKVCacheBlockQueue:
self.num_free_blocks += 1
    def append_n(self, blocks: list[KVCacheBlock]) -> None:
        """Put a list of blocks back into the free list.

        Splices all of ``blocks``, in order, between the current last node
        and the fake tail sentinel, updating num_free_blocks accordingly.

        Args:
            blocks: The blocks to append.
        """
        if len(blocks) == 0:
            return
        self.num_free_blocks += len(blocks)

        # The node currently just before the tail sentinel; the sentinel
        # always has a predecessor (at minimum the head sentinel).
        last_block = self.fake_free_list_tail.prev_free_block
        assert last_block is not None, (
            "prev_free_block of fake_free_list_tail should always exist")
        # Add inter-connections between consecutive blocks
        for block in blocks:
            block.prev_free_block = last_block
            last_block.next_free_block = block
            last_block = block

        # Connect the last block of <blocks> to the fake tail
        last_block.next_free_block = self.fake_free_list_tail
        self.fake_free_list_tail.prev_free_block = last_block
def get_all_free_blocks(self) -> list[KVCacheBlock]:
"""Get all free blocks in the free list. Mainly used for testing.
@ -348,9 +406,9 @@ def need_extra_keys(request: Request) -> bool:
# Multimodal requests need to include the MM hash.
# LoRA requests need to include the LoRA ID.
# Request with provided cache salt need to include the salt.
return bool(request.mm_positions) or (request.lora_request
is not None) or (request.cache_salt
is not None)
return bool(request.mm_hashes) or (request.lora_request
is not None) or (request.cache_salt
is not None)
def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,

View File

@ -437,6 +437,7 @@ class AsyncLLM(EngineClient):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""
Main function called by the API server to kick off a request
@ -465,6 +466,7 @@ class AsyncLLM(EngineClient):
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
)
# The output_handler task pushes items into the queue.

View File

@ -335,14 +335,19 @@ class LogitBiasLogitsProcessor(LogitsProcessor):
if not batch_update:
return
needs_update: bool = False
# Process added requests.
needs_update = bool(batch_update.added)
for index, params, _ in batch_update.added:
if isinstance(params, SamplingParams) and (lb :=
params.logit_bias):
self.biases[index] = lb
needs_update = True
else:
self.biases.pop(index, None)
# Drop biases metadata at batch index
if self.biases.pop(index, None) is not None:
# If a new request replaces an old request which
# specified biases, we should update processor tensors
needs_update = True
if self.biases:
# Process removed requests.
@ -419,7 +424,6 @@ class MinTokensLogitsProcessor(LogitsProcessor):
if batch_update:
# Process added requests.
needs_update |= bool(batch_update.added)
for index, params, output_tok_ids in batch_update.added:
if (isinstance(params, SamplingParams)
and (min_tokens := params.min_tokens)
@ -427,9 +431,13 @@ class MinTokensLogitsProcessor(LogitsProcessor):
# Replace request metadata at batch index
self.min_toks[index] = (min_tokens, output_tok_ids,
params.all_stop_token_ids)
needs_update = True
else:
# Drop request metadata at batch index
self.min_toks.pop(index, None)
# Drop min_toks metadata at batch index
if self.min_toks.pop(index, None) is not None:
# If a new request replaces an old request which
# specified min_toks, we should update processor tensors
needs_update = True
if self.min_toks:
# Process removed requests.

View File

@ -0,0 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Some utilities for logprobs, including logits."""
import torch
@torch.compile(dynamic=True)
def batched_count_greater_than(x: torch.Tensor,
                               values: torch.Tensor) -> torch.Tensor:
    """
    Counts, per row, the elements of x that are greater than OR EQUAL to the
    corresponding value in values. Note: despite the name, the comparison is
    ``>=`` — when ``values`` holds an element taken from the matching row of
    ``x``, this yields that element's 1-based rank within the row.
    Use torch.compile to generate an optimized kernel for
    this function. otherwise, it will create additional copies of the input
    tensors and cause memory issues.

    Args:
        x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
        values (torch.Tensor): A 2D tensor of shape (batch_size, 1).

    Returns:
        torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
    """
    return (x >= values).sum(-1)

View File

@ -5,10 +5,12 @@
import torch
import torch.nn as nn
from vllm.config import LogprobsMode
from vllm.utils import is_pin_memory_available
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
@ -17,10 +19,11 @@ _SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
def __init__(self):
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
super().__init__()
self.topk_topp_sampler = TopKTopPSampler()
self.pin_memory = is_pin_memory_available()
self.logprobs_mode = logprobs_mode
def forward(
self,
@ -35,7 +38,10 @@ class Sampler(nn.Module):
# See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
raw_logprobs = self.compute_logprobs(logits)
if self.logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(logits)
elif self.logprobs_mode == "raw_logits":
raw_logprobs = logits.clone()
# Use float32 for the logits.
logits = logits.to(torch.float32)
@ -50,6 +56,14 @@ class Sampler(nn.Module):
# Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata)
# Get the process logprobs or logits.
if num_logprobs is not None:
if self.logprobs_mode == "processed_logprobs":
raw_logprobs = self.compute_logprobs(logits)
elif self.logprobs_mode == "processed_logits":
raw_logprobs = logits.clone()
# Sample the next token.
sampled = self.sample(logits, sampling_metadata)
# Convert sampled token ids to int64 (long) type to ensure compatibility
@ -174,7 +188,7 @@ class Sampler(nn.Module):
token_logprobs = logprobs.gather(-1, token_ids)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
token_ranks = batched_count_greater_than(logprobs, token_logprobs)
# Concatenate together with the topk.
indices = torch.cat((token_ids, topk_indices), dim=1)

View File

@ -15,6 +15,7 @@ _SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
def __init__(self):
# TODO(houseroad): Add support for logprobs_mode.
super().__init__()
self.topk_topp_sampler = TopKTopPSampler()

View File

@ -389,7 +389,7 @@ class InputBatch:
def remove_request(self, req_id: str) -> Optional[int]:
"""This method must always be followed by a call to condense().
Args:
req_id: request to remove
@ -590,7 +590,7 @@ class InputBatch:
def refresh_metadata(self):
"""Apply batch updates, reset input batch at end of step
* Apply batch add/remove/permute to logits procs' states
* If batch state is modified, update sampling metadata
"""

View File

@ -151,7 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.encoder_cache_size = encoder_cache_size
# Sampler
self.sampler = Sampler()
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
self.eplb_state: Optional[EplbState] = None
"""
@ -1996,7 +1996,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
This is to help balance expert-selection
- during profile_run
- during DP rank dummy run
- during DP rank dummy run
"""
dp_size = self.vllm_config.parallel_config.data_parallel_size
randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1

View File

@ -7,6 +7,7 @@ import torch.distributed
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import get_world_group
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
@ -155,7 +156,8 @@ class XPUWorker(Worker):
current_platform.dist_backend)
# global all_reduce needed for overall oneccl warm up
torch.distributed.all_reduce(torch.zeros(1).xpu())
torch.distributed.all_reduce(torch.zeros(1).xpu(),
group=get_world_group().device_group)
# Set random seed.
set_random_seed(self.model_config.seed)