mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-12 07:13:37 +08:00
Merge branch 'main' into woosuk/model-runner-v2
This commit is contained in:
commit
ad2cf805ad
59
.buildkite/scripts/run-prime-rl-test.sh
Executable file
59
.buildkite/scripts/run-prime-rl-test.sh
Executable file
@ -0,0 +1,59 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Setup script for Prime-RL integration tests
|
||||||
|
# This script prepares the environment for running Prime-RL tests with nightly vLLM
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"
|
||||||
|
PRIME_RL_REPO="https://github.com/PrimeIntellect-ai/prime-rl.git"
|
||||||
|
PRIME_RL_DIR="${REPO_ROOT}/prime-rl"
|
||||||
|
|
||||||
|
echo "Setting up Prime-RL integration test environment..."
|
||||||
|
|
||||||
|
# Clean up any existing Prime-RL directory
|
||||||
|
if [ -d "${PRIME_RL_DIR}" ]; then
|
||||||
|
echo "Removing existing Prime-RL directory..."
|
||||||
|
rm -rf "${PRIME_RL_DIR}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install UV if not available
|
||||||
|
if ! command -v uv &> /dev/null; then
|
||||||
|
echo "Installing UV package manager..."
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
source $HOME/.local/bin/env
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Clone Prime-RL repository at specific branch for reproducible tests
|
||||||
|
PRIME_RL_BRANCH="integ-vllm-main"
|
||||||
|
echo "Cloning Prime-RL repository at branch: ${PRIME_RL_BRANCH}..."
|
||||||
|
git clone --branch "${PRIME_RL_BRANCH}" --single-branch "${PRIME_RL_REPO}" "${PRIME_RL_DIR}"
|
||||||
|
cd "${PRIME_RL_DIR}"
|
||||||
|
|
||||||
|
echo "Setting up UV project environment..."
|
||||||
|
export UV_PROJECT_ENVIRONMENT=/usr/local
|
||||||
|
ln -s /usr/bin/python3 /usr/local/bin/python
|
||||||
|
|
||||||
|
# Remove vllm pin from pyproject.toml
|
||||||
|
echo "Removing vllm pin from pyproject.toml..."
|
||||||
|
sed -i '/vllm==/d' pyproject.toml
|
||||||
|
|
||||||
|
# Sync Prime-RL dependencies
|
||||||
|
echo "Installing Prime-RL dependencies..."
|
||||||
|
uv sync --inexact && uv sync --inexact --all-extras
|
||||||
|
|
||||||
|
# Verify installation
|
||||||
|
echo "Verifying installations..."
|
||||||
|
uv run python -c "import vllm; print(f'vLLM version: {vllm.__version__}')"
|
||||||
|
uv run python -c "import prime_rl; print('Prime-RL imported successfully')"
|
||||||
|
|
||||||
|
echo "Prime-RL integration test environment setup complete!"
|
||||||
|
|
||||||
|
echo "Running Prime-RL integration tests..."
|
||||||
|
export WANDB_MODE=offline # this makes this test not require a WANDB_API_KEY
|
||||||
|
uv run pytest -vs tests/integration/test_rl.py -m gpu
|
||||||
|
|
||||||
|
echo "Prime-RL integration tests completed!"
|
||||||
@ -887,6 +887,8 @@ steps:
|
|||||||
- tests/v1/test_external_lb_dp.py
|
- tests/v1/test_external_lb_dp.py
|
||||||
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
||||||
- vllm/v1/engine/
|
- vllm/v1/engine/
|
||||||
|
- vllm/v1/worker/
|
||||||
|
- tests/v1/worker/test_worker_memory_snapshot.py
|
||||||
commands:
|
commands:
|
||||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||||
@ -908,6 +910,7 @@ steps:
|
|||||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
||||||
- pytest -v -s models/multimodal/generation/test_maverick.py
|
- pytest -v -s models/multimodal/generation/test_maverick.py
|
||||||
|
- pytest -v -s v1/worker/test_worker_memory_snapshot.py
|
||||||
|
|
||||||
- label: Plugin Tests (2 GPUs) # 40min
|
- label: Plugin Tests (2 GPUs) # 40min
|
||||||
timeout_in_minutes: 60
|
timeout_in_minutes: 60
|
||||||
@ -1042,3 +1045,15 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||||
- pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py
|
- pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py
|
||||||
|
|
||||||
|
##### RL Integration Tests #####
|
||||||
|
- label: Prime-RL Integration Test # 15min
|
||||||
|
timeout_in_minutes: 30
|
||||||
|
optional: true
|
||||||
|
num_gpus: 2
|
||||||
|
working_dir: "/vllm-workspace"
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- .buildkite/scripts/run-prime-rl-test.sh
|
||||||
|
commands:
|
||||||
|
- bash .buildkite/scripts/run-prime-rl-test.sh
|
||||||
|
|||||||
@ -449,7 +449,8 @@ async def benchmark(
|
|||||||
def prepare_extra_body(request) -> dict:
|
def prepare_extra_body(request) -> dict:
|
||||||
extra_body = {}
|
extra_body = {}
|
||||||
# Add the schema to the extra_body
|
# Add the schema to the extra_body
|
||||||
extra_body[request.structure_type] = request.schema
|
extra_body["structured_outputs"] = {}
|
||||||
|
extra_body["structured_outputs"][request.structure_type] = request.schema
|
||||||
return extra_body
|
return extra_body
|
||||||
|
|
||||||
print("Starting initial single prompt test run...")
|
print("Starting initial single prompt test run...")
|
||||||
|
|||||||
406
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
Normal file
406
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
Normal file
@ -0,0 +1,406 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Benchmark the performance of the cutlass_moe_fp8 kernel vs the triton_moe
|
||||||
|
kernel. Both kernels take in fp8 quantized weights and 16-bit activations,
|
||||||
|
but use different quantization strategies and backends.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import nvtx
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
# Weight shapes for different models: [num_experts, topk, hidden_size,
|
||||||
|
# intermediate_size]
|
||||||
|
WEIGHT_SHAPES_MOE = {
|
||||||
|
"mixtral-8x7b": [
|
||||||
|
[8, 2, 4096, 14336],
|
||||||
|
],
|
||||||
|
"deepseek-v2": [
|
||||||
|
[160, 6, 5120, 12288],
|
||||||
|
],
|
||||||
|
"custom-small": [
|
||||||
|
[8, 2, 2048, 7168],
|
||||||
|
],
|
||||||
|
"glm45-fp8": [
|
||||||
|
[128, 8, 4096, 1408],
|
||||||
|
],
|
||||||
|
"Llama-4-Maverick-17B-128E-Instruct-FP8": [
|
||||||
|
[128, 1, 5120, 8192],
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFAULT_MODELS = [
|
||||||
|
"mixtral-8x7b",
|
||||||
|
]
|
||||||
|
|
||||||
|
DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
|
||||||
|
DEFAULT_TP_SIZES = [1]
|
||||||
|
|
||||||
|
PER_ACT_TOKEN_OPTS = [False, True]
|
||||||
|
PER_OUT_CH_OPTS = [False, True]
|
||||||
|
|
||||||
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
|
||||||
|
|
||||||
|
def bench_run(
|
||||||
|
results: list,
|
||||||
|
model: str,
|
||||||
|
num_experts: int,
|
||||||
|
topk: int,
|
||||||
|
per_act_token: bool,
|
||||||
|
per_out_ch: bool,
|
||||||
|
mkn: tuple[int, int, int],
|
||||||
|
):
|
||||||
|
(m, k, n) = mkn
|
||||||
|
|
||||||
|
dtype = torch.half
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
# Create input activations
|
||||||
|
a = torch.randn((m, k), device=device, dtype=dtype) / 10
|
||||||
|
|
||||||
|
# Create weights
|
||||||
|
w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
|
||||||
|
w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10
|
||||||
|
|
||||||
|
# Create FP8 quantized weights and scales for both kernels
|
||||||
|
w1_fp8q = torch.empty((num_experts, 2 * n, k), device=device, dtype=FP8_DTYPE)
|
||||||
|
w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=FP8_DTYPE)
|
||||||
|
|
||||||
|
# Create scales based on quantization strategy
|
||||||
|
if per_out_ch:
|
||||||
|
# Per-channel quantization
|
||||||
|
w1_scale = torch.empty(
|
||||||
|
(num_experts, 2 * n, 1), device=device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
w2_scale = torch.empty((num_experts, k, 1), device=device, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
# Per-tensor quantization
|
||||||
|
w1_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
|
||||||
|
w2_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Quantize weights
|
||||||
|
for expert in range(num_experts):
|
||||||
|
if per_out_ch:
|
||||||
|
# Per-channel quantization - not yet implemented properly
|
||||||
|
# For now, fall back to per-tensor quantization
|
||||||
|
w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert])
|
||||||
|
w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert])
|
||||||
|
# Expand scalar scales to the expected per-channel shape
|
||||||
|
w1_scale[expert] = w1_scale_temp.expand(2 * n, 1)
|
||||||
|
w2_scale[expert] = w2_scale_temp.expand(k, 1)
|
||||||
|
else:
|
||||||
|
# Per-tensor quantization
|
||||||
|
w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert])
|
||||||
|
w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert])
|
||||||
|
# Store scalar scales in [1, 1] tensors
|
||||||
|
w1_scale[expert, 0, 0] = w1_scale_temp
|
||||||
|
w2_scale[expert, 0, 0] = w2_scale_temp
|
||||||
|
|
||||||
|
# Prepare weights for CUTLASS (no transpose needed)
|
||||||
|
w1_fp8q_cutlass = w1_fp8q # Keep original [E, 2N, K]
|
||||||
|
w2_fp8q_cutlass = w2_fp8q # Keep original [E, K, N]
|
||||||
|
|
||||||
|
# Create router scores and get topk
|
||||||
|
score = torch.randn((m, num_experts), device=device, dtype=dtype)
|
||||||
|
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
|
||||||
|
|
||||||
|
# WORKAROUND: CUTLASS MoE FP8 has issues with per-token quantization
|
||||||
|
# Force per-tensor quantization for all cases to match working e2e setup
|
||||||
|
a1_scale = torch.full((), 1e-2, device=device, dtype=torch.float32)
|
||||||
|
a2_scale = torch.full((), 1e-2, device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Force per-tensor quantization for all cases
|
||||||
|
per_act_token = False
|
||||||
|
|
||||||
|
# Create stride tensors for CUTLASS
|
||||||
|
ab_strides1 = torch.full((num_experts,), k, dtype=torch.int64, device=device)
|
||||||
|
ab_strides2 = torch.full((num_experts,), n, dtype=torch.int64, device=device)
|
||||||
|
c_strides1 = torch.full((num_experts,), 2 * n, dtype=torch.int64, device=device)
|
||||||
|
c_strides2 = torch.full((num_experts,), k, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
|
def run_triton_moe(
|
||||||
|
a: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
a1_scale: torch.Tensor,
|
||||||
|
a2_scale: torch.Tensor,
|
||||||
|
num_repeats: int,
|
||||||
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
per_act_token_quant=per_act_token,
|
||||||
|
per_out_ch_quant=per_out_ch,
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(num_repeats):
|
||||||
|
fused_experts(
|
||||||
|
a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_cutlass_moe_fp8(
|
||||||
|
a: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
ab_strides1: torch.Tensor,
|
||||||
|
ab_strides2: torch.Tensor,
|
||||||
|
c_strides1: torch.Tensor,
|
||||||
|
c_strides2: torch.Tensor,
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
a1_scale: torch.Tensor,
|
||||||
|
a2_scale: torch.Tensor,
|
||||||
|
num_repeats: int,
|
||||||
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
per_act_token_quant=per_act_token,
|
||||||
|
per_out_ch_quant=per_out_ch,
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(num_repeats):
|
||||||
|
with nvtx.annotate("cutlass_moe_fp8", color="blue"):
|
||||||
|
cutlass_moe_fp8(
|
||||||
|
a=a,
|
||||||
|
w1_q=w1,
|
||||||
|
w2_q=w2,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
ab_strides1=ab_strides1,
|
||||||
|
ab_strides2=ab_strides2,
|
||||||
|
c_strides1=c_strides1,
|
||||||
|
c_strides2=c_strides2,
|
||||||
|
quant_config=quant_config,
|
||||||
|
activation="silu",
|
||||||
|
global_num_experts=num_experts,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pre-create quantization config to avoid creating it inside CUDA graph
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
per_act_token_quant=per_act_token,
|
||||||
|
per_out_ch_quant=per_out_ch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create CUDA graphs for CUTLASS (match benchmark_moe.py pattern exactly)
|
||||||
|
cutlass_stream = torch.cuda.Stream()
|
||||||
|
cutlass_graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
|
||||||
|
# Capture 10 invocations like benchmark_moe.py
|
||||||
|
for _ in range(10):
|
||||||
|
cutlass_moe_fp8(
|
||||||
|
a=a,
|
||||||
|
w1_q=w1_fp8q_cutlass,
|
||||||
|
w2_q=w2_fp8q_cutlass,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
ab_strides1=ab_strides1,
|
||||||
|
ab_strides2=ab_strides2,
|
||||||
|
c_strides1=c_strides1,
|
||||||
|
c_strides2=c_strides2,
|
||||||
|
quant_config=quant_config,
|
||||||
|
activation="silu",
|
||||||
|
global_num_experts=num_experts,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Create CUDA graphs for Triton (match benchmark_moe.py pattern exactly)
|
||||||
|
triton_stream = torch.cuda.Stream()
|
||||||
|
triton_graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(triton_graph, stream=triton_stream):
|
||||||
|
# Capture 10 invocations like benchmark_moe.py
|
||||||
|
for _ in range(10):
|
||||||
|
fused_experts(
|
||||||
|
a,
|
||||||
|
w1_fp8q,
|
||||||
|
w2_fp8q,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
def bench_cuda_graph(graph, num_warmup=5, num_iters=100):
|
||||||
|
"""Benchmark CUDA graph using events like benchmark_moe.py"""
|
||||||
|
# Warmup
|
||||||
|
for _ in range(num_warmup):
|
||||||
|
graph.replay()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Timing
|
||||||
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
|
latencies = []
|
||||||
|
for _ in range(num_iters):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start_event.record()
|
||||||
|
graph.replay()
|
||||||
|
end_event.record()
|
||||||
|
end_event.synchronize()
|
||||||
|
latencies.append(start_event.elapsed_time(end_event))
|
||||||
|
|
||||||
|
# Divide by 10 since graph contains 10 calls
|
||||||
|
return sum(latencies) / (num_iters * 10)
|
||||||
|
|
||||||
|
# Benchmark parameters
|
||||||
|
num_warmup = 5
|
||||||
|
num_iters = 100
|
||||||
|
|
||||||
|
# Benchmark only CUDA graphs (more reliable and faster)
|
||||||
|
# Benchmark Triton MoE with CUDA graphs
|
||||||
|
triton_graph_time = bench_cuda_graph(
|
||||||
|
triton_graph, num_warmup=num_warmup, num_iters=num_iters
|
||||||
|
)
|
||||||
|
|
||||||
|
# Benchmark CUTLASS MoE with CUDA graphs
|
||||||
|
cutlass_graph_time = bench_cuda_graph(
|
||||||
|
cutlass_graph, num_warmup=num_warmup, num_iters=num_iters
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert ms to us and return results
|
||||||
|
triton_time_us = triton_graph_time * 1000
|
||||||
|
cutlass_time_us = cutlass_graph_time * 1000
|
||||||
|
|
||||||
|
return {
|
||||||
|
"batch_size": m,
|
||||||
|
"triton_time_us": triton_time_us,
|
||||||
|
"cutlass_time_us": cutlass_time_us,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print("Benchmarking models:")
|
||||||
|
for i, model in enumerate(args.models):
|
||||||
|
print(f"[{i}] {model}")
|
||||||
|
|
||||||
|
all_results = []
|
||||||
|
|
||||||
|
for model in args.models:
|
||||||
|
for tp in args.tp_sizes:
|
||||||
|
for layer in WEIGHT_SHAPES_MOE[model]:
|
||||||
|
num_experts = layer[0]
|
||||||
|
topk = layer[1]
|
||||||
|
size_k = layer[2]
|
||||||
|
size_n = layer[3] // tp
|
||||||
|
|
||||||
|
if len(args.limit_k) > 0 and size_k not in args.limit_k:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(args.limit_n) > 0 and size_n not in args.limit_n:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for per_act_token in args.per_act_token_opts:
|
||||||
|
for per_out_ch in args.per_out_ch_opts:
|
||||||
|
print(
|
||||||
|
f"\n=== {model}, experts={num_experts}, topk={topk},"
|
||||||
|
f"per_act={per_act_token}, per_out_ch={per_out_ch} ==="
|
||||||
|
)
|
||||||
|
|
||||||
|
config_results = []
|
||||||
|
for size_m in args.batch_sizes:
|
||||||
|
mkn = (size_m, size_k, size_n)
|
||||||
|
result = bench_run(
|
||||||
|
[], # Not used anymore
|
||||||
|
model,
|
||||||
|
num_experts,
|
||||||
|
topk,
|
||||||
|
per_act_token,
|
||||||
|
per_out_ch,
|
||||||
|
mkn,
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
config_results.append(result)
|
||||||
|
|
||||||
|
# Print results table for this configuration
|
||||||
|
if config_results:
|
||||||
|
print(
|
||||||
|
f"\n{'Batch Size':<12}"
|
||||||
|
f"{'Triton (us)':<15}"
|
||||||
|
f"{'CUTLASS (us)':<15}"
|
||||||
|
)
|
||||||
|
print("-" * 45)
|
||||||
|
for result in config_results:
|
||||||
|
print(
|
||||||
|
f"{result['batch_size']:<12}"
|
||||||
|
f"{result['triton_time_us']:<15.2f}"
|
||||||
|
f"{result['cutlass_time_us']:<15.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
all_results.extend(config_results)
|
||||||
|
|
||||||
|
print(f"\nTotal benchmarks completed: {len(all_results)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="""Benchmark CUTLASS FP8 MOE vs Triton FP8 FUSED MOE
|
||||||
|
across specified models/shapes/batches
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
python benchmark_cutlass_moe_fp8.py \
|
||||||
|
--model "Llama-4-Maverick-17B-128E-Instruct-FP8" \
|
||||||
|
--tp-sizes 8 \
|
||||||
|
--batch-size 2 4 8 \
|
||||||
|
--per-act-token-opts false \
|
||||||
|
--per-out-ch-opts false
|
||||||
|
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_MODELS,
|
||||||
|
choices=WEIGHT_SHAPES_MOE.keys(),
|
||||||
|
)
|
||||||
|
parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||||
|
)
|
||||||
|
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument(
|
||||||
|
"--per-act-token-opts",
|
||||||
|
nargs="+",
|
||||||
|
type=lambda x: x.lower() == "true",
|
||||||
|
default=[False, True],
|
||||||
|
help="Per-activation token quantization options (true/false)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--per-out-ch-opts",
|
||||||
|
nargs="+",
|
||||||
|
type=lambda x: x.lower() == "true",
|
||||||
|
default=[False, True],
|
||||||
|
help="Per-output channel quantization options (true/false)",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
||||||
@ -258,7 +258,8 @@ set(VLLM_EXT_SRC
|
|||||||
"csrc/cpu/layernorm.cpp"
|
"csrc/cpu/layernorm.cpp"
|
||||||
"csrc/cpu/mla_decode.cpp"
|
"csrc/cpu/mla_decode.cpp"
|
||||||
"csrc/cpu/pos_encoding.cpp"
|
"csrc/cpu/pos_encoding.cpp"
|
||||||
"csrc/cpu/torch_bindings.cpp")
|
"csrc/cpu/torch_bindings.cpp"
|
||||||
|
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
|
||||||
|
|
||||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||||
set(VLLM_EXT_SRC
|
set(VLLM_EXT_SRC
|
||||||
|
|||||||
@ -135,10 +135,10 @@ public:
|
|||||||
max_splits = min(16, max_splits);
|
max_splits = min(16, max_splits);
|
||||||
|
|
||||||
// TODO: This avoids a hang when the batch size larger than 1 and
|
// TODO: This avoids a hang when the batch size larger than 1 and
|
||||||
// there is more than 4 kv_splits.
|
// there is more than 1 kv_splits.
|
||||||
// Discuss with NVIDIA how this can be fixed.
|
// Discuss with NVIDIA how this can be fixed.
|
||||||
if (B > 1) {
|
if (B > 1) {
|
||||||
max_splits = min(2, max_splits);
|
max_splits = min(1, max_splits);
|
||||||
}
|
}
|
||||||
|
|
||||||
// printf(" max_splits = %d\n", max_splits);
|
// printf(" max_splits = %d\n", max_splits);
|
||||||
|
|||||||
@ -88,8 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" int tp_rank, int blocksparse_local_blocks,"
|
" int tp_rank, int blocksparse_local_blocks,"
|
||||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||||
" int blocksparse_head_sliding_step) -> ()");
|
" int blocksparse_head_sliding_step) -> ()");
|
||||||
|
|
||||||
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
|
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
|
||||||
|
|
||||||
|
ops.def(
|
||||||
|
"dynamic_4bit_int_moe("
|
||||||
|
"Tensor x, Tensor topk_ids, Tensor topk_weights,"
|
||||||
|
"Tensor w13_packed, Tensor w2_packed, int H, int I, int I2,"
|
||||||
|
"int group_size, bool apply_router_weight_on_input, int activation_kind"
|
||||||
|
") -> Tensor");
|
||||||
|
|
||||||
|
ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);
|
||||||
|
|
||||||
// PagedAttention V2.
|
// PagedAttention V2.
|
||||||
ops.def(
|
ops.def(
|
||||||
"paged_attention_v2("
|
"paged_attention_v2("
|
||||||
|
|||||||
156
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
Normal file
156
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <ATen/Parallel.h>
|
||||||
|
#include <torch/all.h>
|
||||||
|
|
||||||
|
// _dyn_quant_matmul_4bit is only available on AArch64.
|
||||||
|
#if defined(__aarch64__)
|
||||||
|
#include <ATen/ops/_dyn_quant_matmul_4bit.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w,
|
||||||
|
int64_t group_size_eff, int64_t in_features,
|
||||||
|
int64_t out_features) {
|
||||||
|
#if defined(__aarch64__)
|
||||||
|
return at::_ops::_dyn_quant_matmul_4bit::call(a, packed_w, group_size_eff,
|
||||||
|
in_features, out_features);
|
||||||
|
#else
|
||||||
|
TORCH_CHECK(false,
|
||||||
|
"dynamic 4-bit int MoE path requires AArch64 (ARM64); "
|
||||||
|
"_dyn_quant_matmul_4bit is unavailable on this architecture");
|
||||||
|
return {};
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
enum ActivationKind : int64_t {
|
||||||
|
SwiGLU_Gu = 0, // act = SiLU(g) * u
|
||||||
|
SwiGLUOAI = 1, // act = SiLU(u) * g
|
||||||
|
SiLU = 2 // SiLU
|
||||||
|
};
|
||||||
|
|
||||||
|
torch::Tensor dynamic_4bit_int_moe_cpu(
|
||||||
|
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
|
||||||
|
torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
|
||||||
|
int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
|
||||||
|
int64_t activation_kind) {
|
||||||
|
TORCH_CHECK(x.dim() == 2, "x must be 2D");
|
||||||
|
TORCH_CHECK(topk_ids.dim() == 2 && topk_weights.dim() == 2,
|
||||||
|
"topk tensors must be [T, K]");
|
||||||
|
TORCH_CHECK(
|
||||||
|
w13_packed.size(0) == w2_packed.size(0),
|
||||||
|
"w13_packed and w2_packed must have same number of experts in dim 0");
|
||||||
|
TORCH_CHECK(I2 == 2 * I, "I2 must equal 2*I");
|
||||||
|
|
||||||
|
const int64_t T = x.size(0);
|
||||||
|
const int64_t K = topk_ids.size(1);
|
||||||
|
const int64_t E = w13_packed.size(0);
|
||||||
|
const int64_t N = T * K;
|
||||||
|
|
||||||
|
auto x_c = x.contiguous();
|
||||||
|
auto ids_c = topk_ids.contiguous();
|
||||||
|
auto gates_c = topk_weights.to(at::kFloat).contiguous();
|
||||||
|
|
||||||
|
// bucketing tokens -> experts
|
||||||
|
c10::SmallVector<int64_t, 64> counts(
|
||||||
|
E, 0); // Small vector uses stack allocation
|
||||||
|
{
|
||||||
|
const auto* ids_ptr = ids_c.data_ptr<int64_t>();
|
||||||
|
for (int64_t i = 0; i < N; ++i) {
|
||||||
|
const int64_t e_id = ids_ptr[i];
|
||||||
|
TORCH_CHECK(0 <= e_id && e_id < E, "expert id out of range");
|
||||||
|
counts[e_id]++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c10::SmallVector<int64_t, 65> offsets(E + 1, 0); // ( E +1 )
|
||||||
|
for (int64_t e = 0; e < E; ++e) offsets[e + 1] = offsets[e] + counts[e];
|
||||||
|
|
||||||
|
auto expert_tokens = at::empty({offsets[E]}, ids_c.options());
|
||||||
|
auto expert_gates = at::empty({offsets[E]}, gates_c.options());
|
||||||
|
{
|
||||||
|
c10::SmallVector<int64_t, 64> cursor(E, 0);
|
||||||
|
const auto* ids_ptr = ids_c.data_ptr<int64_t>();
|
||||||
|
const auto* gts_ptr = gates_c.data_ptr<float>();
|
||||||
|
auto* tok_ptr = expert_tokens.data_ptr<int64_t>();
|
||||||
|
auto* gate_ptr = expert_gates.data_ptr<float>();
|
||||||
|
|
||||||
|
for (int64_t t = 0; t < T; ++t) {
|
||||||
|
const int64_t base = t * K;
|
||||||
|
for (int64_t k = 0; k < K; ++k) {
|
||||||
|
const int64_t idx = base + k;
|
||||||
|
const int64_t e = ids_ptr[idx];
|
||||||
|
const int64_t p = offsets[e] + (cursor[e]++);
|
||||||
|
tok_ptr[p] = t;
|
||||||
|
gate_ptr[p] = gts_ptr[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
|
||||||
|
const int64_t g_eff_2 = (group_size != -1) ? group_size : I;
|
||||||
|
|
||||||
|
// Per-expert outputs filled in parallel
|
||||||
|
std::vector<torch::Tensor> y_list(E);
|
||||||
|
y_list.resize(E);
|
||||||
|
|
||||||
|
at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
|
||||||
|
for (int64_t e = e_begin; e < e_end; ++e) {
|
||||||
|
const int64_t te = counts[e];
|
||||||
|
if (te == 0) {
|
||||||
|
y_list[e] = at::empty({0, H}, x_c.options());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t start = offsets[e];
|
||||||
|
|
||||||
|
auto sel_tokens =
|
||||||
|
expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
|
||||||
|
auto gates_e =
|
||||||
|
expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
|
||||||
|
|
||||||
|
auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);
|
||||||
|
|
||||||
|
if (apply_router_weight_on_input) {
|
||||||
|
x_e = x_e.mul(gates_e.unsqueeze(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto w13_e = w13_packed.select(/*dim=*/0, e);
|
||||||
|
auto w2_e = w2_packed.select(/*dim=*/0, e);
|
||||||
|
|
||||||
|
// W13
|
||||||
|
auto y13 =
|
||||||
|
mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2);
|
||||||
|
|
||||||
|
auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
|
||||||
|
auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);
|
||||||
|
|
||||||
|
torch::Tensor act;
|
||||||
|
if (activation_kind == ActivationKind::SwiGLUOAI) { // SwiGLUOAI
|
||||||
|
constexpr double kAlpha = 1.702; // GPT-OSS default
|
||||||
|
constexpr double kLimit = 7.0; // GPT-OSS default
|
||||||
|
auto gate_c = at::clamp_max(g_part, kLimit);
|
||||||
|
auto up_c = at::clamp(u_part, -kLimit, kLimit);
|
||||||
|
auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
|
||||||
|
act = up_c.add(1.0).mul(glu);
|
||||||
|
} else { // SiLU , SwiGLU_GU, vLLM maps silu to SiluAndMul()
|
||||||
|
act = at::silu(g_part).mul(u_part);
|
||||||
|
}
|
||||||
|
|
||||||
|
// W2
|
||||||
|
auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);
|
||||||
|
|
||||||
|
if (!apply_router_weight_on_input) {
|
||||||
|
y = y.mul(gates_e.unsqueeze(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store per-expert result
|
||||||
|
y_list[e] = y;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Concatenate all expert outputs to match expert_tokens order
|
||||||
|
auto Y_all = at::cat(y_list, /*dim=*/0);
|
||||||
|
auto out = at::zeros({T, H}, x.options());
|
||||||
|
out =
|
||||||
|
at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
@ -328,6 +328,12 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
|
|||||||
const std::optional<torch::Tensor>& has_initial_state,
|
const std::optional<torch::Tensor>& has_initial_state,
|
||||||
const torch::Tensor& ssm_states, int64_t pad_slot_id);
|
const torch::Tensor& ssm_states, int64_t pad_slot_id);
|
||||||
|
|
||||||
|
torch::Tensor dynamic_4bit_int_moe_cpu(
|
||||||
|
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
|
||||||
|
torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
|
||||||
|
int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
|
||||||
|
int64_t activation_kind);
|
||||||
|
|
||||||
using fptr_t = int64_t;
|
using fptr_t = int64_t;
|
||||||
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||||
torch::Tensor& rank_data, int64_t rank,
|
torch::Tensor& rank_data, int64_t rank,
|
||||||
|
|||||||
@ -23,9 +23,14 @@
|
|||||||
typedef __hip_bfloat162 __nv_bfloat162;
|
typedef __hip_bfloat162 __nv_bfloat162;
|
||||||
typedef __hip_bfloat16 __nv_bfloat16;
|
typedef __hip_bfloat16 __nv_bfloat16;
|
||||||
typedef __hip_bfloat16_raw __nv_bfloat16_raw;
|
typedef __hip_bfloat16_raw __nv_bfloat16_raw;
|
||||||
|
#if defined(HIP_FP8_TYPE_OCP)
|
||||||
typedef __hip_fp8_e4m3 __nv_fp8_e4m3;
|
typedef __hip_fp8_e4m3 __nv_fp8_e4m3;
|
||||||
typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3;
|
typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3;
|
||||||
|
#else
|
||||||
|
// ROCm 6.2 fallback: only *_fnuz types exist
|
||||||
|
typedef __hip_fp8_e4m3_fnuz __nv_fp8_e4m3;
|
||||||
|
typedef __hip_fp8x4_e4m3_fnuz __nv_fp8x4_e4m3;
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "core/registration.h"
|
#include "core/registration.h"
|
||||||
|
|||||||
@ -25,6 +25,12 @@
|
|||||||
#include "../attention/dtype_fp8.cuh"
|
#include "../attention/dtype_fp8.cuh"
|
||||||
#include "../quantization/fp8/amd/quant_utils.cuh"
|
#include "../quantization/fp8/amd/quant_utils.cuh"
|
||||||
|
|
||||||
|
// ROCm 6.2 compatibility: map OCP fp8 types to FNUZ variants if OCP is absent
|
||||||
|
#if !defined(HIP_FP8_TYPE_OCP)
|
||||||
|
using __hip_fp8_e4m3 = __hip_fp8_e4m3_fnuz;
|
||||||
|
using __hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz;
|
||||||
|
#endif
|
||||||
|
|
||||||
#if defined(__HIPCC__) && \
|
#if defined(__HIPCC__) && \
|
||||||
(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
|
(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
|
||||||
#define __HIP__GFX9__
|
#define __HIP__GFX9__
|
||||||
|
|||||||
@ -34,7 +34,7 @@ Now supports 5 types of connectors:
|
|||||||
For NixlConnector, you may also specify one or multiple NIXL_Backend. Such as:
|
For NixlConnector, you may also specify one or multiple NIXL_Backend. Such as:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}'
|
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backends":["UCX", "GDS"]}}'
|
||||||
```
|
```
|
||||||
|
|
||||||
- **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker):
|
- **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker):
|
||||||
|
|||||||
@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok
|
|||||||
|
|
||||||
1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip.
|
1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip.
|
||||||
|
|
||||||
2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`. Noted, you may also specify one or multiple NIXL_Backend. Such as: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}'`
|
2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`. Noted, you may also specify one or multiple NIXL_Backend. Such as: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backends":["UCX", "GDS"]}}'`
|
||||||
|
|
||||||
3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.
|
3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.
|
||||||
|
|
||||||
|
|||||||
@ -24,7 +24,7 @@ outlines_core == 0.2.11
|
|||||||
# required for outlines backend disk cache
|
# required for outlines backend disk cache
|
||||||
diskcache == 5.6.3
|
diskcache == 5.6.3
|
||||||
lark == 1.2.2
|
lark == 1.2.2
|
||||||
xgrammar == 0.1.24; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
|
xgrammar == 0.1.25; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
|
||||||
typing_extensions >= 4.10
|
typing_extensions >= 4.10
|
||||||
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
|
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
|
||||||
partial-json-parser # used for parsing partial JSON outputs
|
partial-json-parser # used for parsing partial JSON outputs
|
||||||
|
|||||||
@ -1079,7 +1079,7 @@ def dummy_llava_path():
|
|||||||
local_dir=_dummy_llava_path,
|
local_dir=_dummy_llava_path,
|
||||||
ignore_patterns=[
|
ignore_patterns=[
|
||||||
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
||||||
"*.msgpack"
|
"*.msgpack", "*.safetensors"
|
||||||
])
|
])
|
||||||
assert os.path.exists(json_path)
|
assert os.path.exists(json_path)
|
||||||
with open(json_path) as f:
|
with open(json_path) as f:
|
||||||
@ -1098,7 +1098,7 @@ def dummy_gemma2_embedding_path():
|
|||||||
local_dir=_dummy_gemma2_embedding_path,
|
local_dir=_dummy_gemma2_embedding_path,
|
||||||
ignore_patterns=[
|
ignore_patterns=[
|
||||||
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
||||||
"*.msgpack"
|
"*.msgpack", "*.safetensors"
|
||||||
])
|
])
|
||||||
assert os.path.exists(json_path)
|
assert os.path.exists(json_path)
|
||||||
with open(json_path) as f:
|
with open(json_path) as f:
|
||||||
|
|||||||
@ -382,7 +382,6 @@ def test_tp_language_generation(
|
|||||||
test_options: PPTestOptions,
|
test_options: PPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
):
|
):
|
||||||
pytest.skip("Skipping the test until V1 passes it.")
|
|
||||||
_compare_tp(model_id,
|
_compare_tp(model_id,
|
||||||
parallel_setup,
|
parallel_setup,
|
||||||
distributed_backend,
|
distributed_backend,
|
||||||
@ -410,7 +409,6 @@ def test_tp_language_embedding(
|
|||||||
test_options: PPTestOptions,
|
test_options: PPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
):
|
):
|
||||||
pytest.skip("Skipping the test until V1 passes it.")
|
|
||||||
_compare_tp(model_id,
|
_compare_tp(model_id,
|
||||||
parallel_setup,
|
parallel_setup,
|
||||||
distributed_backend,
|
distributed_backend,
|
||||||
@ -438,7 +436,6 @@ def test_tp_multimodal_generation(
|
|||||||
test_options: PPTestOptions,
|
test_options: PPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
):
|
):
|
||||||
pytest.skip("Skipping the test until V1 passes it.")
|
|
||||||
_compare_tp(model_id,
|
_compare_tp(model_id,
|
||||||
parallel_setup,
|
parallel_setup,
|
||||||
distributed_backend,
|
distributed_backend,
|
||||||
|
|||||||
@ -523,6 +523,7 @@ async def test_function_calling(client: OpenAI, model_name: str):
|
|||||||
input="What's the weather like in Paris today?",
|
input="What's the weather like in Paris today?",
|
||||||
tools=tools,
|
tools=tools,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
|
extra_body={"request_id": "test_function_calling_non_resp"},
|
||||||
)
|
)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert response.status == "completed"
|
assert response.status == "completed"
|
||||||
|
|||||||
@ -194,6 +194,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
|
|||||||
assert tc.function is not None and tc.function.name == "get_current_weather"
|
assert tc.function is not None and tc.function.name == "get_current_weather"
|
||||||
args1 = tc.function.arguments
|
args1 = tc.function.arguments
|
||||||
assert args1 is not None and len(args1) > 0
|
assert args1 is not None and len(args1) > 0
|
||||||
|
assert not first_msg.content
|
||||||
|
|
||||||
messages.append({"role": "assistant", "content": args1})
|
messages.append({"role": "assistant", "content": args1})
|
||||||
messages.append({
|
messages.append({
|
||||||
|
|||||||
@ -85,8 +85,7 @@ def test_env(
|
|||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
with patch("vllm.attention.selector.current_platform",
|
with patch("vllm.attention.selector.current_platform",
|
||||||
CpuPlatform()):
|
CpuPlatform()):
|
||||||
backend = get_attn_backend(16, torch.float16, None, block_size,
|
backend = get_attn_backend(16, torch.float16, None, block_size)
|
||||||
False)
|
|
||||||
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||||
|
|
||||||
elif device == "hip":
|
elif device == "hip":
|
||||||
@ -106,7 +105,6 @@ def test_env(
|
|||||||
torch.float16,
|
torch.float16,
|
||||||
None,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
assert f"The selected backend, {name}" in str(
|
assert f"The selected backend, {name}" in str(
|
||||||
exc_info.value)
|
exc_info.value)
|
||||||
@ -117,7 +115,6 @@ def test_env(
|
|||||||
torch.float16,
|
torch.float16,
|
||||||
None,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
assert f"The selected backend, {name}" in str(
|
assert f"The selected backend, {name}" in str(
|
||||||
exc_info.value)
|
exc_info.value)
|
||||||
@ -127,7 +124,6 @@ def test_env(
|
|||||||
torch.float16,
|
torch.float16,
|
||||||
None,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
expected = f"{name}_VLLM_V1"
|
expected = f"{name}_VLLM_V1"
|
||||||
assert backend.get_name() == expected
|
assert backend.get_name() == expected
|
||||||
@ -136,7 +132,6 @@ def test_env(
|
|||||||
torch.float16,
|
torch.float16,
|
||||||
None,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
expected = "TRITON_ATTN_VLLM_V1"
|
expected = "TRITON_ATTN_VLLM_V1"
|
||||||
assert backend.get_name() == expected
|
assert backend.get_name() == expected
|
||||||
@ -164,7 +159,6 @@ def test_env(
|
|||||||
torch.float16,
|
torch.float16,
|
||||||
None,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
expected = "CUTLASS_MLA_VLLM_V1"
|
expected = "CUTLASS_MLA_VLLM_V1"
|
||||||
assert backend.get_name() == expected
|
assert backend.get_name() == expected
|
||||||
@ -179,7 +173,6 @@ def test_env(
|
|||||||
torch.float16,
|
torch.float16,
|
||||||
None,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
expected = "FLASHINFER_MLA"
|
expected = "FLASHINFER_MLA"
|
||||||
assert backend.get_name() == expected
|
assert backend.get_name() == expected
|
||||||
@ -199,7 +192,6 @@ def test_env(
|
|||||||
torch.float16,
|
torch.float16,
|
||||||
None,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
expected = f"{name}_VLLM_V1"
|
expected = f"{name}_VLLM_V1"
|
||||||
assert backend.get_name() == expected
|
assert backend.get_name() == expected
|
||||||
@ -208,7 +200,6 @@ def test_env(
|
|||||||
torch.float16,
|
torch.float16,
|
||||||
None,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
expected = "FLASH_ATTN_MLA"
|
expected = "FLASH_ATTN_MLA"
|
||||||
assert backend.get_name() == expected
|
assert backend.get_name() == expected
|
||||||
@ -218,7 +209,6 @@ def test_env(
|
|||||||
torch.float16,
|
torch.float16,
|
||||||
None,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
expected = "TRITON_MLA_VLLM_V1"
|
expected = "TRITON_MLA_VLLM_V1"
|
||||||
assert backend.get_name() == expected
|
assert backend.get_name() == expected
|
||||||
@ -227,7 +217,6 @@ def test_env(
|
|||||||
torch.float16,
|
torch.float16,
|
||||||
None,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
expected = "FLASHINFER_VLLM_V1"
|
expected = "FLASHINFER_VLLM_V1"
|
||||||
assert backend.get_name() == expected
|
assert backend.get_name() == expected
|
||||||
@ -236,7 +225,6 @@ def test_env(
|
|||||||
torch.float16,
|
torch.float16,
|
||||||
None,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
expected = "FLASH_ATTN_VLLM_V1"
|
expected = "FLASH_ATTN_VLLM_V1"
|
||||||
assert backend.get_name() == expected
|
assert backend.get_name() == expected
|
||||||
@ -245,7 +233,6 @@ def test_env(
|
|||||||
torch.float16,
|
torch.float16,
|
||||||
None,
|
None,
|
||||||
block_size,
|
block_size,
|
||||||
False,
|
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
assert backend.get_name() == "FLEX_ATTENTION", (
|
assert backend.get_name() == "FLEX_ATTENTION", (
|
||||||
"Should fallback to FlexAttention if head size is "
|
"Should fallback to FlexAttention if head size is "
|
||||||
@ -264,13 +251,13 @@ def test_fp32_fallback(
|
|||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
with patch("vllm.attention.selector.current_platform",
|
with patch("vllm.attention.selector.current_platform",
|
||||||
CpuPlatform()):
|
CpuPlatform()):
|
||||||
backend = get_attn_backend(16, torch.float32, None, 16, False)
|
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||||
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||||
|
|
||||||
elif device == "cuda":
|
elif device == "cuda":
|
||||||
with patch("vllm.attention.selector.current_platform",
|
with patch("vllm.attention.selector.current_platform",
|
||||||
CudaPlatform()):
|
CudaPlatform()):
|
||||||
backend = get_attn_backend(16, torch.float32, None, 16, False)
|
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||||
assert backend.get_name() == "FLEX_ATTENTION"
|
assert backend.get_name() == "FLEX_ATTENTION"
|
||||||
|
|
||||||
|
|
||||||
@ -286,29 +273,29 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
|||||||
monkeypatch.setattr(torch.cuda,
|
monkeypatch.setattr(torch.cuda,
|
||||||
"get_device_capability",
|
"get_device_capability",
|
||||||
lambda _=None: (7, 5))
|
lambda _=None: (7, 5))
|
||||||
backend = get_attn_backend(16, torch.float16, None, 16, False)
|
backend = get_attn_backend(16, torch.float16, None, 16)
|
||||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||||
|
|
||||||
# Reset the monkeypatch for subsequent tests
|
# Reset the monkeypatch for subsequent tests
|
||||||
monkeypatch.undo()
|
monkeypatch.undo()
|
||||||
|
|
||||||
# Unsupported data type
|
# Unsupported data type
|
||||||
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
|
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
|
||||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||||
|
|
||||||
# Unsupported kv cache data type
|
# Unsupported kv cache data type
|
||||||
backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
|
backend = get_attn_backend(16, torch.float16, "fp8", 16)
|
||||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||||
|
|
||||||
# Unsupported block size
|
# Unsupported block size
|
||||||
backend = get_attn_backend(16, torch.float16, None, 8, False)
|
backend = get_attn_backend(16, torch.float16, None, 8)
|
||||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||||
|
|
||||||
# flash-attn is not installed
|
# flash-attn is not installed
|
||||||
import sys
|
import sys
|
||||||
original_module = sys.modules.get('vllm_flash_attn')
|
original_module = sys.modules.get('vllm_flash_attn')
|
||||||
monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None)
|
monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None)
|
||||||
backend = get_attn_backend(16, torch.float16, None, 16, False)
|
backend = get_attn_backend(16, torch.float16, None, 16)
|
||||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||||
|
|
||||||
# Restore the original module if it existed
|
# Restore the original module if it existed
|
||||||
@ -319,11 +306,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
|||||||
monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False)
|
monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False)
|
||||||
|
|
||||||
# Unsupported head size
|
# Unsupported head size
|
||||||
backend = get_attn_backend(17, torch.float16, None, 16, False)
|
backend = get_attn_backend(17, torch.float16, None, 16)
|
||||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
|
||||||
|
|
||||||
# Attention-free models should bypass env and use PlaceholderAttention
|
|
||||||
backend = get_attn_backend(16, torch.float16, None, 16, True)
|
|
||||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||||
|
|
||||||
|
|
||||||
@ -336,5 +319,5 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
|
|||||||
|
|
||||||
# Should raise ValueError for invalid backend
|
# Should raise ValueError for invalid backend
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
get_attn_backend(32, torch.float16, None, 16, False)
|
get_attn_backend(32, torch.float16, None, 16)
|
||||||
assert "Invalid value 'INVALID'" in str(exc_info.value)
|
assert "Invalid value 'INVALID'" in str(exc_info.value)
|
||||||
|
|||||||
@ -12,11 +12,11 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.inputs import InputProcessingContext
|
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
|
||||||
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
|
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
|
||||||
from vllm.multimodal.inputs import MultiModalInputs
|
from vllm.multimodal.inputs import MultiModalInputs
|
||||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||||
|
InputProcessingContext)
|
||||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||||
cached_tokenizer_from_config,
|
cached_tokenizer_from_config,
|
||||||
encode_tokens)
|
encode_tokens)
|
||||||
|
|||||||
@ -18,10 +18,10 @@ from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
|
|||||||
from vllm.distributed import (cleanup_dist_env_and_memory,
|
from vllm.distributed import (cleanup_dist_env_and_memory,
|
||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
initialize_model_parallel)
|
initialize_model_parallel)
|
||||||
from vllm.inputs import InputProcessingContext
|
|
||||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
|
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
|
||||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||||
|
InputProcessingContext)
|
||||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils import is_list_of
|
||||||
|
|||||||
@ -42,7 +42,6 @@ def test_oot_registration_text_generation(
|
|||||||
assert rest == ""
|
assert rest == ""
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="This test is skipped because it failed on V1.")
|
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_oot_registration_embedding(
|
def test_oot_registration_embedding(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
@ -63,7 +62,6 @@ def test_oot_registration_embedding(
|
|||||||
image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
|
image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="This test is skipped because it failed on V1.")
|
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_oot_registration_multimodal(
|
def test_oot_registration_multimodal(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
|||||||
@ -11,8 +11,9 @@ import torch.nn.functional as F
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.config import ModelConfig, ModelDType, RunnerOption
|
from vllm.config import ModelConfig, ModelDType, RunnerOption
|
||||||
from vllm.inputs import InputContext
|
|
||||||
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
|
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
|
||||||
|
from vllm.multimodal.processing import InputProcessingContext
|
||||||
|
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||||
|
|
||||||
from .registry import HF_EXAMPLE_MODELS
|
from .registry import HF_EXAMPLE_MODELS
|
||||||
|
|
||||||
@ -264,7 +265,7 @@ def build_model_context(
|
|||||||
limit_mm_per_prompt: Optional[dict[str, int]] = None,
|
limit_mm_per_prompt: Optional[dict[str, int]] = None,
|
||||||
mm_processor_cache_gb: int = 0,
|
mm_processor_cache_gb: int = 0,
|
||||||
):
|
):
|
||||||
"""Creates an InputContext for a given model.
|
"""Creates an InputProcessingContext for a given model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_id: ID of the model being considered.
|
model_id: ID of the model being considered.
|
||||||
@ -273,7 +274,7 @@ def build_model_context(
|
|||||||
limit_mm_per_prompt: Multimodal limits.
|
limit_mm_per_prompt: Multimodal limits.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
InputContext for the model being considered.
|
InputProcessingContext for the model being considered.
|
||||||
"""
|
"""
|
||||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||||
model_info.check_available_online(on_fail="skip")
|
model_info.check_available_online(on_fail="skip")
|
||||||
@ -298,7 +299,11 @@ def build_model_context(
|
|||||||
enforce_eager=model_info.enforce_eager,
|
enforce_eager=model_info.enforce_eager,
|
||||||
**model_config_kwargs,
|
**model_config_kwargs,
|
||||||
)
|
)
|
||||||
return InputContext(model_config)
|
|
||||||
|
return InputProcessingContext(
|
||||||
|
model_config,
|
||||||
|
tokenizer=cached_tokenizer_from_config(model_config),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_embeddings_close(
|
def check_embeddings_close(
|
||||||
|
|||||||
@ -8,11 +8,11 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.inputs import InputProcessingContext
|
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
|
from vllm.multimodal.processing import (InputProcessingContext,
|
||||||
|
PlaceholderFeaturesInfo,
|
||||||
PromptIndexTargets, PromptInsertion,
|
PromptIndexTargets, PromptInsertion,
|
||||||
PromptReplacement, apply_text_matches,
|
PromptReplacement, apply_text_matches,
|
||||||
apply_token_matches,
|
apply_token_matches,
|
||||||
|
|||||||
392
tests/reasoning/test_base_thinking_reasoning_parser.py
Normal file
392
tests/reasoning/test_base_thinking_reasoning_parser.py
Normal file
@ -0,0 +1,392 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from tests.reasoning.utils import run_reasoning_extraction
|
||||||
|
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||||
|
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
|
||||||
|
|
||||||
|
|
||||||
|
# Create a concrete test implementation of BaseThinkingReasoningParser
|
||||||
|
class TestThinkingReasoningParser(BaseThinkingReasoningParser):
|
||||||
|
"""Test implementation of BaseThinkingReasoningParser."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_token(self) -> str:
|
||||||
|
return "<test:think>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def end_token(self) -> str:
|
||||||
|
return "</test:think>"
|
||||||
|
|
||||||
|
|
||||||
|
class TestThinkingReasoningParserAlt(BaseThinkingReasoningParser):
|
||||||
|
"""Alternative test implementation with different tokens."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_token(self) -> str:
|
||||||
|
return "<alt:start>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def end_token(self) -> str:
|
||||||
|
return "<alt:end>"
|
||||||
|
|
||||||
|
|
||||||
|
# Use a test model
|
||||||
|
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def test_tokenizer():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
|
||||||
|
# Add custom test tokens
|
||||||
|
test_tokens = ["<test:think>", "</test:think>", "<alt:start>", "<alt:end>"]
|
||||||
|
existing_tokens = set(tokenizer.get_vocab().keys())
|
||||||
|
new_tokens = [
|
||||||
|
token for token in test_tokens if token not in existing_tokens
|
||||||
|
]
|
||||||
|
if new_tokens:
|
||||||
|
tokenizer.add_tokens(new_tokens)
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseThinkingReasoningParserInit:
|
||||||
|
"""
|
||||||
|
Test initialization and basic properties of
|
||||||
|
BaseThinkingReasoningParser.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_successful_initialization(self, test_tokenizer):
|
||||||
|
"""Test successful initialization with valid tokens."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
assert parser.start_token == "<test:think>"
|
||||||
|
assert parser.end_token == "</test:think>"
|
||||||
|
assert parser.start_token_id is not None
|
||||||
|
assert parser.end_token_id is not None
|
||||||
|
|
||||||
|
def test_initialization_with_missing_tokenizer(self):
|
||||||
|
"""Test that initialization fails without tokenizer."""
|
||||||
|
with pytest.raises(ValueError, match="model tokenizer must be passed"):
|
||||||
|
TestThinkingReasoningParser(None)
|
||||||
|
|
||||||
|
def test_initialization_with_missing_tokens(self, test_tokenizer):
|
||||||
|
"""Test that initialization fails when tokens are not in vocabulary."""
|
||||||
|
|
||||||
|
# Create a parser with tokens not in vocabulary
|
||||||
|
class MissingTokenParser(BaseThinkingReasoningParser):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_token(self) -> str:
|
||||||
|
return "<missing:start>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def end_token(self) -> str:
|
||||||
|
return "<missing:end>"
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError,
|
||||||
|
match="could not locate think start/end tokens"):
|
||||||
|
MissingTokenParser(test_tokenizer)
|
||||||
|
|
||||||
|
def test_initialization_with_empty_tokens(self, test_tokenizer):
|
||||||
|
"""Test that initialization fails with empty token strings."""
|
||||||
|
|
||||||
|
class EmptyTokenParser(BaseThinkingReasoningParser):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_token(self) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def end_token(self) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
with pytest.raises(ValueError,
|
||||||
|
match="start_token and end_token must be defined"):
|
||||||
|
EmptyTokenParser(test_tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseThinkingReasoningParserMethods:
|
||||||
|
"""Test the methods of BaseThinkingReasoningParser."""
|
||||||
|
|
||||||
|
def test_is_reasoning_end(self, test_tokenizer):
|
||||||
|
"""Test the is_reasoning_end method."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
end_token_id = parser.end_token_id
|
||||||
|
|
||||||
|
# Test with end token present
|
||||||
|
assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True
|
||||||
|
|
||||||
|
# Test without end token
|
||||||
|
assert parser.is_reasoning_end([1, 2, 3, 4]) is False
|
||||||
|
|
||||||
|
# Test with empty list
|
||||||
|
assert parser.is_reasoning_end([]) is False
|
||||||
|
|
||||||
|
def test_extract_content_ids(self, test_tokenizer):
|
||||||
|
"""Test the extract_content_ids method."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
end_token_id = parser.end_token_id
|
||||||
|
|
||||||
|
# Test with end token in the middle
|
||||||
|
input_ids = [1, 2, end_token_id, 4, 5]
|
||||||
|
content_ids = parser.extract_content_ids(input_ids)
|
||||||
|
assert content_ids == [4, 5]
|
||||||
|
|
||||||
|
# Test with end token at the end
|
||||||
|
input_ids = [1, 2, 3, end_token_id]
|
||||||
|
content_ids = parser.extract_content_ids(input_ids)
|
||||||
|
assert content_ids == []
|
||||||
|
|
||||||
|
# Test without end token
|
||||||
|
input_ids = [1, 2, 3, 4]
|
||||||
|
content_ids = parser.extract_content_ids(input_ids)
|
||||||
|
assert content_ids == []
|
||||||
|
|
||||||
|
# Test with end token as last element (should not extract)
|
||||||
|
input_ids = [1, 2, 3, end_token_id]
|
||||||
|
content_ids = parser.extract_content_ids(input_ids)
|
||||||
|
assert content_ids == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseThinkingReasoningParserExtraction:
|
||||||
|
"""Test reasoning content extraction methods."""
|
||||||
|
|
||||||
|
def test_extract_reasoning_content_with_both_tokens(self, test_tokenizer):
|
||||||
|
"""Test extraction when both start and end tokens are present."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
request = ChatCompletionRequest(messages=[], model="test-model")
|
||||||
|
|
||||||
|
model_output = ("<test:think>This is reasoning"
|
||||||
|
"</test:think>This is content")
|
||||||
|
reasoning, content = parser.extract_reasoning_content(
|
||||||
|
model_output, request)
|
||||||
|
|
||||||
|
assert reasoning == "This is reasoning"
|
||||||
|
assert content == "This is content"
|
||||||
|
|
||||||
|
def test_extract_reasoning_content_only_end_token(self, test_tokenizer):
|
||||||
|
"""Test extraction when only end token is present."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
request = ChatCompletionRequest(messages=[], model="test-model")
|
||||||
|
|
||||||
|
model_output = ("This is reasoning</test:think>This is content")
|
||||||
|
reasoning, content = parser.extract_reasoning_content(
|
||||||
|
model_output, request)
|
||||||
|
|
||||||
|
assert reasoning == "This is reasoning"
|
||||||
|
assert content == "This is content"
|
||||||
|
|
||||||
|
def test_extract_reasoning_content_no_end_token(self, test_tokenizer):
|
||||||
|
"""Test extraction when no end token is present."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
request = ChatCompletionRequest(messages=[], model="test-model")
|
||||||
|
|
||||||
|
model_output = "This is just content"
|
||||||
|
reasoning, content = parser.extract_reasoning_content(
|
||||||
|
model_output, request)
|
||||||
|
|
||||||
|
assert reasoning == "This is just content"
|
||||||
|
assert content is None
|
||||||
|
|
||||||
|
def test_extract_reasoning_content_empty_output(self, test_tokenizer):
|
||||||
|
"""Test extraction with empty output."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
request = ChatCompletionRequest(messages=[], model="test-model")
|
||||||
|
|
||||||
|
model_output = ""
|
||||||
|
reasoning, content = parser.extract_reasoning_content(
|
||||||
|
model_output, request)
|
||||||
|
|
||||||
|
assert reasoning == ""
|
||||||
|
assert content is None
|
||||||
|
|
||||||
|
def test_extract_reasoning_content_only_tokens(self, test_tokenizer):
|
||||||
|
"""Test extraction with only tokens and no content."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
request = ChatCompletionRequest(messages=[], model="test-model")
|
||||||
|
|
||||||
|
model_output = ("<test:think></test:think>")
|
||||||
|
reasoning, content = parser.extract_reasoning_content(
|
||||||
|
model_output, request)
|
||||||
|
|
||||||
|
assert reasoning == ""
|
||||||
|
assert content is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseThinkingReasoningParserStreaming:
|
||||||
|
"""Test streaming functionality of BaseThinkingReasoningParser."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("streaming", [True, False])
|
||||||
|
def test_simple_reasoning_extraction(self, test_tokenizer, streaming):
|
||||||
|
"""
|
||||||
|
Test basic reasoning extraction in both
|
||||||
|
streaming and non-streaming modes.
|
||||||
|
"""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
|
||||||
|
model_output = [
|
||||||
|
"<test:think>", "Some ", "reasoning ", "content", "</test:think>",
|
||||||
|
"Final ", "answer"
|
||||||
|
]
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(parser,
|
||||||
|
model_output,
|
||||||
|
streaming=streaming)
|
||||||
|
|
||||||
|
assert reasoning == "Some reasoning content"
|
||||||
|
assert content == "Final answer"
|
||||||
|
|
||||||
|
def test_streaming_with_incremental_deltas(self, test_tokenizer):
|
||||||
|
"""Test streaming processing with small incremental deltas."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
|
||||||
|
deltas = [
|
||||||
|
"<test:think>",
|
||||||
|
"Some ",
|
||||||
|
"reasoning ",
|
||||||
|
"content",
|
||||||
|
"</test:think>",
|
||||||
|
"Final ",
|
||||||
|
"answer",
|
||||||
|
]
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(parser,
|
||||||
|
deltas,
|
||||||
|
streaming=True)
|
||||||
|
|
||||||
|
assert reasoning == "Some reasoning content"
|
||||||
|
assert content == "Final answer"
|
||||||
|
|
||||||
|
def test_streaming_with_start_token(self, test_tokenizer):
|
||||||
|
"""Test streaming with start token included."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
|
||||||
|
deltas = [
|
||||||
|
"<test:think>",
|
||||||
|
"Some ",
|
||||||
|
"reasoning",
|
||||||
|
"</test:think>",
|
||||||
|
"Answer",
|
||||||
|
]
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(parser,
|
||||||
|
deltas,
|
||||||
|
streaming=True)
|
||||||
|
|
||||||
|
assert reasoning == "Some reasoning"
|
||||||
|
assert content == "Answer"
|
||||||
|
|
||||||
|
def test_streaming_no_end_token(self, test_tokenizer):
|
||||||
|
"""Test streaming when no end token is encountered."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
|
||||||
|
deltas = [
|
||||||
|
"<test:think>",
|
||||||
|
"Some ",
|
||||||
|
"reasoning ",
|
||||||
|
"without ",
|
||||||
|
"end",
|
||||||
|
]
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(parser,
|
||||||
|
deltas,
|
||||||
|
streaming=True)
|
||||||
|
|
||||||
|
assert reasoning == "Some reasoning without end"
|
||||||
|
assert content is None
|
||||||
|
|
||||||
|
def test_streaming_only_end_token(self, test_tokenizer):
|
||||||
|
"""Test streaming when only end token appears."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
|
||||||
|
deltas = [
|
||||||
|
"<test:think>",
|
||||||
|
"Reasoning ",
|
||||||
|
"content",
|
||||||
|
"</test:think>",
|
||||||
|
"Final",
|
||||||
|
]
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(parser,
|
||||||
|
deltas,
|
||||||
|
streaming=True)
|
||||||
|
|
||||||
|
assert reasoning == "Reasoning content"
|
||||||
|
assert content == "Final"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseThinkingReasoningParserMultipleImplementations:
|
||||||
|
"""
|
||||||
|
Test that multiple implementations of
|
||||||
|
BaseThinkingReasoningParser work correctly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_different_token_implementations(self, test_tokenizer):
|
||||||
|
"""
|
||||||
|
Test that different implementations
|
||||||
|
with different tokens work independently.
|
||||||
|
"""
|
||||||
|
parser1 = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
parser2 = TestThinkingReasoningParserAlt(test_tokenizer)
|
||||||
|
|
||||||
|
# Test parser1
|
||||||
|
model_output1 = ("Reasoning1</test:think>Content1")
|
||||||
|
reasoning1, content1 = run_reasoning_extraction(
|
||||||
|
parser1, [model_output1])
|
||||||
|
assert reasoning1 == "Reasoning1"
|
||||||
|
assert content1 == "Content1"
|
||||||
|
|
||||||
|
# Test parser2
|
||||||
|
model_output2 = "Reasoning2<alt:end>Content2"
|
||||||
|
reasoning2, content2 = run_reasoning_extraction(
|
||||||
|
parser2, [model_output2])
|
||||||
|
assert reasoning2 == "Reasoning2"
|
||||||
|
assert content2 == "Content2"
|
||||||
|
|
||||||
|
# Verify tokens are different
|
||||||
|
assert parser1.start_token != parser2.start_token
|
||||||
|
assert parser1.end_token != parser2.end_token
|
||||||
|
assert parser1.start_token_id != parser2.start_token_id
|
||||||
|
assert parser1.end_token_id != parser2.end_token_id
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseThinkingReasoningParserEdgeCases:
|
||||||
|
"""Test edge cases and error conditions."""
|
||||||
|
|
||||||
|
def test_multiple_end_tokens(self, test_tokenizer):
|
||||||
|
"""Test behavior with multiple end tokens."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
|
||||||
|
model_output = ("First</test:think>Middle</test:think>Last")
|
||||||
|
reasoning, content = run_reasoning_extraction(parser, [model_output])
|
||||||
|
|
||||||
|
# Should stop at first end token
|
||||||
|
assert reasoning == "First"
|
||||||
|
assert content == "Middle</test:think>Last"
|
||||||
|
|
||||||
|
def test_nested_tokens(self, test_tokenizer):
|
||||||
|
"""Test behavior with nested-like token patterns."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
|
||||||
|
model_output = ("<test:think>Outer"
|
||||||
|
"<test:think>Inner</test:think>Content")
|
||||||
|
reasoning, content = run_reasoning_extraction(parser, [model_output])
|
||||||
|
|
||||||
|
# Should process normally, start from first start token
|
||||||
|
assert reasoning == "Outer<test:think>Inner"
|
||||||
|
assert content == "Content"
|
||||||
|
|
||||||
|
def test_malformed_tokens(self, test_tokenizer):
|
||||||
|
"""Test behavior with malformed token-like strings."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
|
||||||
|
model_output = ("<test:thinking>Not a real token"
|
||||||
|
"</test:thinking>Content")
|
||||||
|
reasoning, content = run_reasoning_extraction(parser, [model_output])
|
||||||
|
|
||||||
|
# Should treat as regular content since tokens don't match exactly
|
||||||
|
assert reasoning == ("<test:thinking>Not a real token"
|
||||||
|
"</test:thinking>Content")
|
||||||
|
assert content is None
|
||||||
237
tests/reasoning/test_seedoss_reasoning_parser.py
Normal file
237
tests/reasoning/test_seedoss_reasoning_parser.py
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from tests.reasoning.utils import run_reasoning_extraction
|
||||||
|
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||||
|
|
||||||
|
parser_name = "seed_oss"
|
||||||
|
start_token = "<seed:think>"
|
||||||
|
end_token = "</seed:think>"
|
||||||
|
|
||||||
|
# Use a test model that contains our custom tokens
|
||||||
|
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def seedoss_tokenizer():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
|
||||||
|
# Add custom SeedOSS tokens if they don't exist
|
||||||
|
if start_token not in tokenizer.get_vocab():
|
||||||
|
tokenizer.add_tokens([start_token, end_token])
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
SIMPLE_REASONING: dict[str, Any] = {
|
||||||
|
"output": "This is a reasoning section</seed:think>This is the rest",
|
||||||
|
"reasoning_content": "This is a reasoning section",
|
||||||
|
"content": "This is the rest",
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
COMPLETE_REASONING: dict[str, Any] = {
|
||||||
|
"output": "This is a reasoning section</seed:think>",
|
||||||
|
"reasoning_content": "This is a reasoning section",
|
||||||
|
"content": None,
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
NO_CONTENT: dict[str, Any] = {
|
||||||
|
"output": "This is content",
|
||||||
|
"reasoning_content": "This is content",
|
||||||
|
"content": None,
|
||||||
|
"is_reasoning_end": False,
|
||||||
|
}
|
||||||
|
NO_REASONING_STREAMING: dict[str, Any] = {
|
||||||
|
"output": "This is a reasoning section",
|
||||||
|
"reasoning_content": "This is a reasoning section",
|
||||||
|
"content": None,
|
||||||
|
"is_reasoning_end": False,
|
||||||
|
}
|
||||||
|
MULTIPLE_LINES: dict[str, Any] = {
|
||||||
|
"output": "This\nThat</seed:think>This is the rest\nThat",
|
||||||
|
"reasoning_content": "This\nThat",
|
||||||
|
"content": "This is the rest\nThat",
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
WITH_START_TOKEN: dict[str, Any] = {
|
||||||
|
"output": ("<seed:think>This is a reasoning section"
|
||||||
|
"</seed:think>This is the rest"),
|
||||||
|
"reasoning_content":
|
||||||
|
"This is a reasoning section",
|
||||||
|
"content":
|
||||||
|
"This is the rest",
|
||||||
|
"is_reasoning_end":
|
||||||
|
True,
|
||||||
|
}
|
||||||
|
ONLY_END_TOKEN: dict[str, Any] = {
|
||||||
|
"output": "Some reasoning</seed:think>This is the rest",
|
||||||
|
"reasoning_content": "Some reasoning",
|
||||||
|
"content": "This is the rest",
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
NO_TOKENS: dict[str, Any] = {
|
||||||
|
"output": "This is just content without any reasoning tokens",
|
||||||
|
"reasoning_content": "This is just content without any reasoning tokens",
|
||||||
|
"content": None,
|
||||||
|
"is_reasoning_end": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedoss_reasoning_parser_creation(seedoss_tokenizer):
|
||||||
|
"""Test that the SeedOSS reasoning parser can be created and registered."""
|
||||||
|
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||||
|
parser = parser_cls(seedoss_tokenizer)
|
||||||
|
assert isinstance(parser, ReasoningParser)
|
||||||
|
assert parser.start_token == start_token
|
||||||
|
assert parser.end_token == end_token
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("streaming", [True, False])
|
||||||
|
def test_simple_reasoning(seedoss_tokenizer, streaming):
|
||||||
|
"""Test basic reasoning extraction with both tokens."""
|
||||||
|
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||||
|
parser = parser_cls(seedoss_tokenizer)
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(
|
||||||
|
parser, [cast(str, SIMPLE_REASONING["output"])], streaming=streaming)
|
||||||
|
|
||||||
|
assert reasoning == SIMPLE_REASONING["reasoning_content"]
|
||||||
|
assert content == SIMPLE_REASONING["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("streaming", [True, False])
|
||||||
|
def test_complete_reasoning(seedoss_tokenizer, streaming):
|
||||||
|
"""Test reasoning extraction when there's no content after reasoning."""
|
||||||
|
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||||
|
parser = parser_cls(seedoss_tokenizer)
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(
|
||||||
|
parser, [cast(str, COMPLETE_REASONING["output"])], streaming=streaming)
|
||||||
|
|
||||||
|
assert reasoning == COMPLETE_REASONING["reasoning_content"]
|
||||||
|
assert content == COMPLETE_REASONING["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("streaming", [True, False])
|
||||||
|
def test_no_content(seedoss_tokenizer, streaming):
|
||||||
|
"""Test when there's no end token - everything is reasoning content."""
|
||||||
|
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||||
|
parser = parser_cls(seedoss_tokenizer)
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(
|
||||||
|
parser, [cast(str, NO_CONTENT["output"])], streaming=streaming)
|
||||||
|
|
||||||
|
assert reasoning == NO_CONTENT["reasoning_content"]
|
||||||
|
assert content == NO_CONTENT["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("streaming", [True, False])
|
||||||
|
def test_multiple_lines(seedoss_tokenizer, streaming):
|
||||||
|
"""Test reasoning extraction with multiline content."""
|
||||||
|
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||||
|
parser = parser_cls(seedoss_tokenizer)
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(
|
||||||
|
parser, [cast(str, MULTIPLE_LINES["output"])], streaming=streaming)
|
||||||
|
|
||||||
|
assert reasoning == MULTIPLE_LINES["reasoning_content"]
|
||||||
|
assert content == MULTIPLE_LINES["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("streaming", [True, False])
|
||||||
|
def test_with_start_token(seedoss_tokenizer, streaming):
|
||||||
|
"""Test reasoning extraction with both start and end tokens."""
|
||||||
|
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||||
|
parser = parser_cls(seedoss_tokenizer)
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(
|
||||||
|
parser, [cast(str, WITH_START_TOKEN["output"])], streaming=streaming)
|
||||||
|
|
||||||
|
assert reasoning == WITH_START_TOKEN["reasoning_content"]
|
||||||
|
assert content == WITH_START_TOKEN["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("streaming", [True, False])
|
||||||
|
def test_only_end_token(seedoss_tokenizer, streaming):
|
||||||
|
"""
|
||||||
|
Test reasoning extraction with only end token
|
||||||
|
(SeedOSS typical behavior).
|
||||||
|
"""
|
||||||
|
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||||
|
parser = parser_cls(seedoss_tokenizer)
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(
|
||||||
|
parser, [cast(str, ONLY_END_TOKEN["output"])], streaming=streaming)
|
||||||
|
|
||||||
|
assert reasoning == ONLY_END_TOKEN["reasoning_content"]
|
||||||
|
assert content == ONLY_END_TOKEN["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("streaming", [True, False])
|
||||||
|
def test_no_tokens(seedoss_tokenizer, streaming):
|
||||||
|
"""Test when there are no reasoning tokens at all."""
|
||||||
|
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||||
|
parser = parser_cls(seedoss_tokenizer)
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(
|
||||||
|
parser, [cast(str, NO_TOKENS["output"])], streaming=streaming)
|
||||||
|
|
||||||
|
assert reasoning == NO_TOKENS["reasoning_content"]
|
||||||
|
assert content == NO_TOKENS["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_reasoning_end(seedoss_tokenizer):
|
||||||
|
"""Test the is_reasoning_end method."""
|
||||||
|
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||||
|
parser = parser_cls(seedoss_tokenizer)
|
||||||
|
|
||||||
|
# Test with end token present
|
||||||
|
end_token_id = parser.end_token_id
|
||||||
|
assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True
|
||||||
|
|
||||||
|
# Test without end token
|
||||||
|
assert parser.is_reasoning_end([1, 2, 3, 4]) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_content_ids(seedoss_tokenizer):
|
||||||
|
"""Test the extract_content_ids method."""
|
||||||
|
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||||
|
parser = parser_cls(seedoss_tokenizer)
|
||||||
|
|
||||||
|
end_token_id = parser.end_token_id
|
||||||
|
|
||||||
|
# Test with end token in the middle
|
||||||
|
input_ids = [1, 2, end_token_id, 4, 5]
|
||||||
|
content_ids = parser.extract_content_ids(input_ids)
|
||||||
|
assert content_ids == [4, 5]
|
||||||
|
|
||||||
|
# Test with end token at the end
|
||||||
|
input_ids = [1, 2, 3, end_token_id]
|
||||||
|
content_ids = parser.extract_content_ids(input_ids)
|
||||||
|
assert content_ids == []
|
||||||
|
|
||||||
|
# Test without end token
|
||||||
|
input_ids = [1, 2, 3, 4]
|
||||||
|
content_ids = parser.extract_content_ids(input_ids)
|
||||||
|
assert content_ids == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_delta_processing(seedoss_tokenizer):
|
||||||
|
"""Test streaming processing with small deltas."""
|
||||||
|
parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
|
||||||
|
parser = parser_cls(seedoss_tokenizer)
|
||||||
|
|
||||||
|
# Test streaming with incremental tokens
|
||||||
|
deltas = [
|
||||||
|
"Some ", "reasoning ", "content", "</seed:think>", "Final ", "answer"
|
||||||
|
]
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(parser,
|
||||||
|
deltas,
|
||||||
|
streaming=True)
|
||||||
|
|
||||||
|
assert reasoning == "Some reasoning content"
|
||||||
|
assert content == "Final answer"
|
||||||
@ -1,7 +1,9 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import os
|
||||||
from dataclasses import MISSING, Field, asdict, dataclass, field
|
from dataclasses import MISSING, Field, asdict, dataclass, field
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -388,3 +390,108 @@ def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len,
|
|||||||
else:
|
else:
|
||||||
actual_max_len = model_config.get_and_verify_max_len(max_model_len)
|
actual_max_len = model_config.get_and_verify_max_len(max_model_len)
|
||||||
assert actual_max_len == expected_max_len
|
assert actual_max_len == expected_max_len
|
||||||
|
|
||||||
|
|
||||||
|
class MockConfig:
|
||||||
|
"""Simple mock object for testing maybe_pull_model_tokenizer_for_runai"""
|
||||||
|
|
||||||
|
def __init__(self, model: str, tokenizer: str):
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.model_weights = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("s3_url", [
|
||||||
|
"s3://example-bucket-1/model/",
|
||||||
|
"s3://example-bucket-2/model/",
|
||||||
|
])
|
||||||
|
@patch('vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files')
|
||||||
|
def test_s3_url_model_tokenizer_paths(mock_pull_files, s3_url):
|
||||||
|
"""Test that S3 URLs create deterministic local directories for model and
|
||||||
|
tokenizer."""
|
||||||
|
# Mock pull_files to avoid actually downloading files during tests
|
||||||
|
mock_pull_files.return_value = None
|
||||||
|
|
||||||
|
# Create first mock and run the method
|
||||||
|
config1 = MockConfig(model=s3_url, tokenizer=s3_url)
|
||||||
|
ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url, s3_url)
|
||||||
|
|
||||||
|
# Check that model and tokenizer point to existing directories
|
||||||
|
assert os.path.exists(
|
||||||
|
config1.model), f"Model directory does not exist: {config1.model}"
|
||||||
|
assert os.path.isdir(
|
||||||
|
config1.model), f"Model path is not a directory: {config1.model}"
|
||||||
|
assert os.path.exists(
|
||||||
|
config1.tokenizer
|
||||||
|
), f"Tokenizer directory does not exist: {config1.tokenizer}"
|
||||||
|
assert os.path.isdir(
|
||||||
|
config1.tokenizer
|
||||||
|
), f"Tokenizer path is not a directory: {config1.tokenizer}"
|
||||||
|
|
||||||
|
# Verify that the paths are different from the original S3 URL
|
||||||
|
assert config1.model != s3_url, (
|
||||||
|
"Model path should be converted to local directory")
|
||||||
|
assert config1.tokenizer != s3_url, (
|
||||||
|
"Tokenizer path should be converted to local directory")
|
||||||
|
|
||||||
|
# Store the original paths
|
||||||
|
created_model_dir = config1.model
|
||||||
|
create_tokenizer_dir = config1.tokenizer
|
||||||
|
|
||||||
|
# Create a new mock and run the method with the same S3 URL
|
||||||
|
config2 = MockConfig(model=s3_url, tokenizer=s3_url)
|
||||||
|
ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url, s3_url)
|
||||||
|
|
||||||
|
# Check that the new directories exist
|
||||||
|
assert os.path.exists(
|
||||||
|
config2.model), f"Model directory does not exist: {config2.model}"
|
||||||
|
assert os.path.isdir(
|
||||||
|
config2.model), f"Model path is not a directory: {config2.model}"
|
||||||
|
assert os.path.exists(
|
||||||
|
config2.tokenizer
|
||||||
|
), f"Tokenizer directory does not exist: {config2.tokenizer}"
|
||||||
|
assert os.path.isdir(
|
||||||
|
config2.tokenizer
|
||||||
|
), f"Tokenizer path is not a directory: {config2.tokenizer}"
|
||||||
|
|
||||||
|
# Verify that the paths are deterministic (same as before)
|
||||||
|
assert config2.model == created_model_dir, (
|
||||||
|
f"Model paths are not deterministic. "
|
||||||
|
f"Original: {created_model_dir}, New: {config2.model}")
|
||||||
|
assert config2.tokenizer == create_tokenizer_dir, (
|
||||||
|
f"Tokenizer paths are not deterministic. "
|
||||||
|
f"Original: {create_tokenizer_dir}, New: {config2.tokenizer}")
|
||||||
|
|
||||||
|
|
||||||
|
@patch('vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files')
|
||||||
|
def test_s3_url_different_models_create_different_directories(mock_pull_files):
|
||||||
|
"""Test that different S3 URLs create different local directories."""
|
||||||
|
# Mock pull_files to avoid actually downloading files during tests
|
||||||
|
mock_pull_files.return_value = None
|
||||||
|
|
||||||
|
s3_url1 = "s3://example-bucket-1/model/"
|
||||||
|
s3_url2 = "s3://example-bucket-2/model/"
|
||||||
|
|
||||||
|
# Create mocks with different S3 URLs and run the method
|
||||||
|
config1 = MockConfig(model=s3_url1, tokenizer=s3_url1)
|
||||||
|
ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url1, s3_url1)
|
||||||
|
|
||||||
|
config2 = MockConfig(model=s3_url2, tokenizer=s3_url2)
|
||||||
|
ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url2, s3_url2)
|
||||||
|
|
||||||
|
# Verify that different URLs produce different directories
|
||||||
|
assert config1.model != config2.model, (
|
||||||
|
f"Different S3 URLs should create different model directories. "
|
||||||
|
f"URL1 model: {config1.model}, URL2 model: {config2.model}")
|
||||||
|
assert config1.tokenizer != config2.tokenizer, (
|
||||||
|
f"Different S3 URLs should create different tokenizer directories. "
|
||||||
|
f"URL1 tokenizer: {config1.tokenizer}, "
|
||||||
|
f"URL2 tokenizer: {config2.tokenizer}")
|
||||||
|
|
||||||
|
# Verify that both sets of directories exist
|
||||||
|
assert os.path.exists(config1.model) and os.path.isdir(config1.model)
|
||||||
|
assert os.path.exists(config1.tokenizer) and os.path.isdir(
|
||||||
|
config1.tokenizer)
|
||||||
|
assert os.path.exists(config2.model) and os.path.isdir(config2.model)
|
||||||
|
assert os.path.exists(config2.tokenizer) and os.path.isdir(
|
||||||
|
config2.tokenizer)
|
||||||
|
|||||||
54
tests/tool_use/test_deepseekv31_tool_parser.py
Normal file
54
tests/tool_use/test_deepseekv31_tool_parser.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.tool_parsers import DeepSeekV31ToolParser
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
|
MODEL = "deepseek-ai/DeepSeek-V3.1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def deepseekv31_tokenizer():
|
||||||
|
return get_tokenizer(tokenizer_name=MODEL)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def parser(deepseekv31_tokenizer):
|
||||||
|
return DeepSeekV31ToolParser(deepseekv31_tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_with_tool(parser):
|
||||||
|
model_output = (
|
||||||
|
"normal text" + "<|tool▁calls▁begin|>" +
|
||||||
|
"<|tool▁call▁begin|>foo<|tool▁sep|>{\"x\":1}<|tool▁call▁end|>" +
|
||||||
|
"<|tool▁calls▁end|>")
|
||||||
|
result = parser.extract_tool_calls(model_output, None)
|
||||||
|
assert result.tools_called
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
assert result.tool_calls[0].function.name == "foo"
|
||||||
|
assert result.tool_calls[0].function.arguments == "{\"x\":1}"
|
||||||
|
assert result.content == "normal text"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_with_multiple_tools(parser):
|
||||||
|
model_output = (
|
||||||
|
"some prefix text" + "<|tool▁calls▁begin|>" +
|
||||||
|
"<|tool▁call▁begin|>foo<|tool▁sep|>{\"x\":1}<|tool▁call▁end|>" +
|
||||||
|
"<|tool▁call▁begin|>bar<|tool▁sep|>{\"y\":2}<|tool▁call▁end|>" +
|
||||||
|
"<|tool▁calls▁end|>" + " some suffix text")
|
||||||
|
|
||||||
|
result = parser.extract_tool_calls(model_output, None)
|
||||||
|
|
||||||
|
assert result.tools_called
|
||||||
|
assert len(result.tool_calls) == 2
|
||||||
|
|
||||||
|
assert result.tool_calls[0].function.name == "foo"
|
||||||
|
assert result.tool_calls[0].function.arguments == "{\"x\":1}"
|
||||||
|
|
||||||
|
assert result.tool_calls[1].function.name == "bar"
|
||||||
|
assert result.tool_calls[1].function.arguments == "{\"y\":2}"
|
||||||
|
|
||||||
|
# prefix is content
|
||||||
|
assert result.content == "some prefix text"
|
||||||
@ -70,7 +70,12 @@ def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
|
|||||||
assert extracted_info.content == "This is a test"
|
assert extracted_info.content == "This is a test"
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding):
|
@pytest.mark.parametrize("tool_args", [
|
||||||
|
'{"location": "Tokyo"}',
|
||||||
|
'{\n"location": "Tokyo"\n}',
|
||||||
|
])
|
||||||
|
def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding,
|
||||||
|
tool_args):
|
||||||
convo = Conversation.from_messages([
|
convo = Conversation.from_messages([
|
||||||
Message.from_role_and_content(Role.USER,
|
Message.from_role_and_content(Role.USER,
|
||||||
"What is the weather in Tokyo?"),
|
"What is the weather in Tokyo?"),
|
||||||
@ -80,7 +85,7 @@ def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding):
|
|||||||
).with_channel("analysis"),
|
).with_channel("analysis"),
|
||||||
Message.from_role_and_content(
|
Message.from_role_and_content(
|
||||||
Role.ASSISTANT,
|
Role.ASSISTANT,
|
||||||
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
tool_args).with_channel("commentary").with_recipient(
|
||||||
"functions.get_current_weather").with_content_type("json"),
|
"functions.get_current_weather").with_content_type("json"),
|
||||||
])
|
])
|
||||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||||
@ -121,6 +126,17 @@ def test_extract_tool_calls_multiple_tools(
|
|||||||
Role.ASSISTANT,
|
Role.ASSISTANT,
|
||||||
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||||
"functions.get_user_location").with_content_type("json"),
|
"functions.get_user_location").with_content_type("json"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT, '{"location": "Tokyo"}').with_channel(
|
||||||
|
"commentary").with_recipient("functions.no_content_type"),
|
||||||
|
Message.from_role_and_content(Role.ASSISTANT, "foo").with_channel(
|
||||||
|
"commentary").with_recipient("functions.not_json_no_content_type"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT, '{}').with_channel("commentary").with_recipient(
|
||||||
|
"functions.empty_args").with_content_type("json"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT, '').with_channel("commentary").with_recipient(
|
||||||
|
"functions.no_args").with_content_type("json"),
|
||||||
])
|
])
|
||||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||||
convo,
|
convo,
|
||||||
@ -141,7 +157,63 @@ def test_extract_tool_calls_multiple_tools(
|
|||||||
ToolCall(function=FunctionCall(
|
ToolCall(function=FunctionCall(
|
||||||
name="get_user_location",
|
name="get_user_location",
|
||||||
arguments=json.dumps({"location": "Tokyo"}),
|
arguments=json.dumps({"location": "Tokyo"}),
|
||||||
|
)),
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="no_content_type",
|
||||||
|
arguments=json.dumps({"location": "Tokyo"}),
|
||||||
|
)),
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="not_json_no_content_type",
|
||||||
|
arguments="foo",
|
||||||
|
)),
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="empty_args",
|
||||||
|
arguments=json.dumps({}),
|
||||||
|
)),
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="no_args",
|
||||||
|
arguments="",
|
||||||
))
|
))
|
||||||
]
|
]
|
||||||
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||||
assert extracted_info.content is None
|
assert extracted_info.content is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_with_content(
|
||||||
|
openai_tool_parser,
|
||||||
|
harmony_encoding,
|
||||||
|
):
|
||||||
|
final_content = "This tool call will get the weather."
|
||||||
|
convo = Conversation.from_messages([
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.USER, "What is the weather in Tokyo based on where I'm at?"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT,
|
||||||
|
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
|
||||||
|
).with_channel("analysis"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT,
|
||||||
|
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||||
|
"functions.get_current_weather").with_content_type("json"),
|
||||||
|
Message.from_role_and_content(Role.ASSISTANT,
|
||||||
|
final_content).with_channel("final"),
|
||||||
|
])
|
||||||
|
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||||
|
convo,
|
||||||
|
Role.ASSISTANT,
|
||||||
|
)
|
||||||
|
|
||||||
|
extracted_info = openai_tool_parser.extract_tool_calls(
|
||||||
|
"",
|
||||||
|
request=None,
|
||||||
|
token_ids=token_ids,
|
||||||
|
)
|
||||||
|
assert extracted_info.tools_called
|
||||||
|
expected_tool_calls = [
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_current_weather",
|
||||||
|
arguments=json.dumps({"location": "Tokyo"}),
|
||||||
|
)),
|
||||||
|
]
|
||||||
|
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||||
|
assert extracted_info.content == final_content
|
||||||
|
|||||||
@ -31,7 +31,6 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch):
|
|||||||
def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
|
def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
|
||||||
return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
|
return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
|
||||||
max_model_len=256,
|
max_model_len=256,
|
||||||
max_seq_len_to_capture=256,
|
|
||||||
max_num_seqs=8,
|
max_num_seqs=8,
|
||||||
tensor_parallel_size=tp,
|
tensor_parallel_size=tp,
|
||||||
enable_lora=True,
|
enable_lora=True,
|
||||||
|
|||||||
@ -9,8 +9,9 @@ from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
|||||||
from vllm.v1.attention.backends.utils import (UBatchSlice,
|
from vllm.v1.attention.backends.utils import (UBatchSlice,
|
||||||
_make_metadata_with_slice,
|
_make_metadata_with_slice,
|
||||||
slice_query_start_locs,
|
slice_query_start_locs,
|
||||||
split_attn_metadata)
|
split_attn_metadata,
|
||||||
from vllm.v1.worker.ubatch_utils import create_ubatch_slices
|
split_decodes_and_prefills)
|
||||||
|
from vllm.v1.worker.ubatch_splitting import create_ubatch_slices
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -158,6 +159,112 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
|
|||||||
assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
|
assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
|
||||||
|
|
||||||
|
|
||||||
|
def apply_split_decodes_and_prefills(query_lens: list[int],
|
||||||
|
decode_threshold: int,
|
||||||
|
require_uniform: bool):
|
||||||
|
"""Helper function to apply split_decodes_and_prefills and return
|
||||||
|
the results."""
|
||||||
|
device = torch.device("cpu")
|
||||||
|
seq_lens = [10 * (i + 1) for i in range(len(query_lens))]
|
||||||
|
common_metadata = create_common_attn_metadata(BatchSpec(
|
||||||
|
seq_lens=seq_lens, query_lens=query_lens),
|
||||||
|
block_size=16,
|
||||||
|
device=device)
|
||||||
|
return split_decodes_and_prefills(common_metadata,
|
||||||
|
decode_threshold=decode_threshold,
|
||||||
|
require_uniform=require_uniform)
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_decodes_and_prefills_nonuniform_all_ones():
|
||||||
|
query_lens = [1, 1, 1]
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
|
apply_split_decodes_and_prefills(query_lens, 1, False))
|
||||||
|
assert num_decodes == 3
|
||||||
|
assert num_prefills == 0
|
||||||
|
assert num_decode_tokens == 3
|
||||||
|
assert num_prefill_tokens == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
|
||||||
|
query_lens = [1, 2, 1, 3, 2, 1, 2]
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
|
apply_split_decodes_and_prefills(query_lens, 3, False))
|
||||||
|
assert num_decodes == 7
|
||||||
|
assert num_prefills == 0
|
||||||
|
assert num_decode_tokens == sum(query_lens)
|
||||||
|
assert num_prefill_tokens == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_decodes_and_prefills_nonuniform_all_prefills():
|
||||||
|
query_lens = [4, 5, 6, 7]
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
|
apply_split_decodes_and_prefills(query_lens, 3, False))
|
||||||
|
assert num_decodes == 0
|
||||||
|
assert num_prefills == 4
|
||||||
|
assert num_decode_tokens == 0
|
||||||
|
assert num_prefill_tokens == sum(query_lens)
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_decodes_and_prefills_nonuniform_mixed_batch():
|
||||||
|
query_lens = [2, 1, 3, 4, 5, 6, 7, 8]
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
|
apply_split_decodes_and_prefills(query_lens, 4, False))
|
||||||
|
assert num_decodes == 4 # 2, 1, 3, 4 are all <= 4
|
||||||
|
assert num_prefills == 4 # 5, 6, 7, 8 are all > 4
|
||||||
|
assert num_decode_tokens == 10 # 2 + 1 + 3 + 4
|
||||||
|
assert num_prefill_tokens == 26 # 5 + 6 + 7 + 8
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_decodes_and_prefills_uniform_all_ones():
|
||||||
|
query_lens = [1, 1, 1]
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
|
apply_split_decodes_and_prefills(query_lens, 1, True))
|
||||||
|
assert num_decodes == 3
|
||||||
|
assert num_prefills == 0
|
||||||
|
assert num_decode_tokens == 3
|
||||||
|
assert num_prefill_tokens == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_decodes_and_prefills_uniform_all_short_decodes():
|
||||||
|
query_lens = [2, 2, 1, 3, 2, 1, 2]
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
|
apply_split_decodes_and_prefills(query_lens, 3, True))
|
||||||
|
assert num_decodes == 2
|
||||||
|
assert num_prefills == 5
|
||||||
|
assert num_decode_tokens == 4
|
||||||
|
assert num_prefill_tokens == (1 + 3 + 2 + 1 + 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_decodes_and_prefills_uniform_all_prefills():
|
||||||
|
query_lens = [4, 5, 6, 7]
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
|
apply_split_decodes_and_prefills(query_lens, 3, True))
|
||||||
|
assert num_decodes == 0
|
||||||
|
assert num_prefills == 4
|
||||||
|
assert num_decode_tokens == 0
|
||||||
|
assert num_prefill_tokens == sum(query_lens)
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
|
||||||
|
query_lens = [2, 2, 2, 4, 5, 6, 7, 8]
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
|
apply_split_decodes_and_prefills(query_lens, 4, True))
|
||||||
|
assert num_decodes == 3 # 2, 2, 2 are all <= 4 and uniform
|
||||||
|
assert num_prefills == 5 # 4, 5, 6, 7, 8 are all > 4
|
||||||
|
assert num_decode_tokens == 6 # 2 + 2 + 2
|
||||||
|
assert num_prefill_tokens == 30 # 4 + 5 + 6 + 7 + 8
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
|
||||||
|
query_lens = [2, 1, 2, 4, 5, 6, 7, 8]
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
|
apply_split_decodes_and_prefills(query_lens, 4, True))
|
||||||
|
assert num_decodes == 1 # only the first 2 is taken as decode
|
||||||
|
assert num_prefills == 7 # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform
|
||||||
|
assert num_decode_tokens == 2 # only the first 2
|
||||||
|
assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
|
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
|
||||||
[
|
[
|
||||||
|
|||||||
@ -14,10 +14,11 @@ from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
|||||||
MultiModalKwargsItem, PlaceholderRange)
|
MultiModalKwargsItem, PlaceholderRange)
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.utils import sha256, sha256_cbor
|
from vllm.utils import sha256, sha256_cbor
|
||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||||
get_block_hash, get_group_id,
|
KVCacheBlock, get_block_hash,
|
||||||
|
get_group_id,
|
||||||
get_request_block_hasher,
|
get_request_block_hasher,
|
||||||
hash_block_tokens, init_none_hash,
|
hash_block_tokens, init_none_hash,
|
||||||
make_block_hash_with_group_id)
|
make_block_hash_with_group_id)
|
||||||
@ -138,7 +139,7 @@ def test_prefill(hash_fn):
|
|||||||
blocks = manager.allocate_slots(req0, 55,
|
blocks = manager.allocate_slots(req0, 55,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||||
|
|
||||||
# Check full block metadata
|
# Check full block metadata
|
||||||
parent_block_hash = None
|
parent_block_hash = None
|
||||||
@ -171,7 +172,7 @@ def test_prefill(hash_fn):
|
|||||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == ([5], )
|
assert blocks is not None and blocks.get_block_ids() == ([5], )
|
||||||
for block in computed_blocks.blocks[0]:
|
for block in computed_blocks.blocks[0]:
|
||||||
assert block.ref_cnt == 2
|
assert block.ref_cnt == 2
|
||||||
|
|
||||||
@ -207,7 +208,7 @@ def test_prefill(hash_fn):
|
|||||||
blocks = manager.allocate_slots(req2, num_new_tokens,
|
blocks = manager.allocate_slots(req2, num_new_tokens,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == ([6], )
|
assert blocks is not None and blocks.get_block_ids() == ([6], )
|
||||||
|
|
||||||
# Although we only have 6 free blocks, we have 8 blocks in
|
# Although we only have 6 free blocks, we have 8 blocks in
|
||||||
# the free block queue due to lazy removal.
|
# the free block queue due to lazy removal.
|
||||||
@ -227,7 +228,9 @@ def test_prefill(hash_fn):
|
|||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
# This block ID order also checks the eviction order.
|
# This block ID order also checks the eviction order.
|
||||||
assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
|
assert blocks is not None and blocks.get_block_ids() == ([
|
||||||
|
7, 8, 9, 10, 4, 5, 6, 3, 2, 1
|
||||||
|
], )
|
||||||
|
|
||||||
assert free_block_queue.num_free_blocks == 0
|
assert free_block_queue.num_free_blocks == 0
|
||||||
assert (free_block_queue.fake_free_list_head.next_free_block
|
assert (free_block_queue.fake_free_list_head.next_free_block
|
||||||
@ -261,8 +264,9 @@ def test_prefill_hybrid_model():
|
|||||||
blocks = manager.allocate_slots(req0, 55,
|
blocks = manager.allocate_slots(req0, 55,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
|
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], [
|
||||||
8], [9, 10, 11, 12])
|
5, 6, 7, 8
|
||||||
|
], [9, 10, 11, 12])
|
||||||
|
|
||||||
# Check full block metadata
|
# Check full block metadata
|
||||||
parent_block_hash = None
|
parent_block_hash = None
|
||||||
@ -298,7 +302,7 @@ def test_prefill_hybrid_model():
|
|||||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == ([13], [14], [15])
|
assert blocks is not None and blocks.get_block_ids() == ([13], [14], [15])
|
||||||
for block_per_group in computed_blocks.blocks:
|
for block_per_group in computed_blocks.blocks:
|
||||||
for block in block_per_group:
|
for block in block_per_group:
|
||||||
if block != manager.block_pool.null_block:
|
if block != manager.block_pool.null_block:
|
||||||
@ -309,14 +313,15 @@ def test_prefill_hybrid_model():
|
|||||||
manager.free(req1)
|
manager.free(req1)
|
||||||
|
|
||||||
cached_block_hash_to_block_bak = copy.copy(
|
cached_block_hash_to_block_bak = copy.copy(
|
||||||
manager.block_pool.cached_block_hash_to_block)
|
manager.block_pool.cached_block_hash_to_block._cache)
|
||||||
|
|
||||||
def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes],
|
def test_partial_request_hit(request_id: str,
|
||||||
|
hash_to_evict: list[BlockHashWithGroupId],
|
||||||
expect_hit_length: int):
|
expect_hit_length: int):
|
||||||
req = make_request(request_id, common_token_ids + unique_token_ids,
|
req = make_request(request_id, common_token_ids + unique_token_ids,
|
||||||
block_size, sha256)
|
block_size, sha256)
|
||||||
for hash_with_group_id in hash_to_evict:
|
for hash_with_group_id in hash_to_evict:
|
||||||
manager.block_pool.cached_block_hash_to_block.pop(
|
manager.block_pool.cached_block_hash_to_block._cache.pop(
|
||||||
hash_with_group_id)
|
hash_with_group_id)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||||
assert len(req.block_hashes) == 3
|
assert len(req.block_hashes) == 3
|
||||||
@ -324,7 +329,7 @@ def test_prefill_hybrid_model():
|
|||||||
for block_per_group in computed_blocks.blocks:
|
for block_per_group in computed_blocks.blocks:
|
||||||
assert len(block_per_group) == num_computed_tokens // block_size
|
assert len(block_per_group) == num_computed_tokens // block_size
|
||||||
for hash_with_group_id in hash_to_evict:
|
for hash_with_group_id in hash_to_evict:
|
||||||
manager.block_pool.cached_block_hash_to_block[
|
manager.block_pool.cached_block_hash_to_block._cache[
|
||||||
hash_with_group_id] = cached_block_hash_to_block_bak[
|
hash_with_group_id] = cached_block_hash_to_block_bak[
|
||||||
hash_with_group_id]
|
hash_with_group_id]
|
||||||
manager.free(req)
|
manager.free(req)
|
||||||
@ -362,7 +367,8 @@ def test_prefill_hybrid_model():
|
|||||||
# total cache miss.
|
# total cache miss.
|
||||||
# The cache hit length of full attention is 1 * block_size.
|
# The cache hit length of full attention is 1 * block_size.
|
||||||
# The cache hit length of sliding window is 2 * block_size.
|
# The cache hit length of sliding window is 2 * block_size.
|
||||||
# Then it is cache miss as the two type of layers have different hit length.
|
# Then it is cache miss as the two type of layers
|
||||||
|
# have different hit length.
|
||||||
test_partial_request_hit("8", [
|
test_partial_request_hit("8", [
|
||||||
make_block_hash_with_group_id(block_hashes[2], 0),
|
make_block_hash_with_group_id(block_hashes[2], 0),
|
||||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||||
@ -406,7 +412,7 @@ def test_prefill_plp():
|
|||||||
blocks = manager.allocate_slots(req0, 55,
|
blocks = manager.allocate_slots(req0, 55,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||||
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
|
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
|
||||||
|
|
||||||
# Check full block metadata
|
# Check full block metadata
|
||||||
@ -441,7 +447,7 @@ def test_prefill_plp():
|
|||||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == ([5], )
|
assert blocks is not None and blocks.get_block_ids() == ([5], )
|
||||||
for block in computed_blocks.blocks[0]:
|
for block in computed_blocks.blocks[0]:
|
||||||
assert block.ref_cnt == 2
|
assert block.ref_cnt == 2
|
||||||
|
|
||||||
@ -478,6 +484,7 @@ def test_prefill_plp():
|
|||||||
blocks = manager.allocate_slots(req2, 55,
|
blocks = manager.allocate_slots(req2, 55,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
|
assert blocks is not None
|
||||||
block_ids = blocks.get_block_ids()
|
block_ids = blocks.get_block_ids()
|
||||||
# Duplicate cached blocks have different ids but same hashes vs request #0
|
# Duplicate cached blocks have different ids but same hashes vs request #0
|
||||||
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
|
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
|
||||||
@ -513,7 +520,7 @@ def test_decode():
|
|||||||
blocks = manager.allocate_slots(req0, 55,
|
blocks = manager.allocate_slots(req0, 55,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||||
|
|
||||||
# Append slots without allocating a new block.
|
# Append slots without allocating a new block.
|
||||||
req0.num_computed_tokens = 55
|
req0.num_computed_tokens = 55
|
||||||
@ -558,7 +565,8 @@ def test_evict():
|
|||||||
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
|
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial
|
# 5 full + 1 partial
|
||||||
|
assert blocks is not None and len(blocks.blocks[0]) == 6
|
||||||
|
|
||||||
# 3 blocks.
|
# 3 blocks.
|
||||||
req1 = make_request("1", list(range(last_token_id,
|
req1 = make_request("1", list(range(last_token_id,
|
||||||
@ -570,7 +578,7 @@ def test_evict():
|
|||||||
blocks = manager.allocate_slots(req1, 3 * 16,
|
blocks = manager.allocate_slots(req1, 3 * 16,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert len(blocks.blocks[0]) == 3 # 3 full blocks
|
assert blocks is not None and len(blocks.blocks[0]) == 3 # 3 full blocks
|
||||||
last_token_id += 3 * 16
|
last_token_id += 3 * 16
|
||||||
|
|
||||||
# 10 - (6 + 3) == 1
|
# 10 - (6 + 3) == 1
|
||||||
@ -592,7 +600,7 @@ def test_evict():
|
|||||||
blocks = manager.allocate_slots(req2, 3,
|
blocks = manager.allocate_slots(req2, 3,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == ([10], )
|
assert blocks is not None and blocks.get_block_ids() == ([10], )
|
||||||
assert manager.block_pool.free_block_queue.num_free_blocks == 7
|
assert manager.block_pool.free_block_queue.num_free_blocks == 7
|
||||||
|
|
||||||
|
|
||||||
@ -617,7 +625,7 @@ def test_hash_block_correct_reuse():
|
|||||||
blocks = manager.allocate_slots(req, num_tokens,
|
blocks = manager.allocate_slots(req, num_tokens,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert len(blocks.blocks[0]) == 1
|
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||||
|
|
||||||
# Deallocate the block.
|
# Deallocate the block.
|
||||||
manager.free(req)
|
manager.free(req)
|
||||||
@ -631,7 +639,7 @@ def test_hash_block_correct_reuse():
|
|||||||
blocks = manager.allocate_slots(req, num_tokens - 1,
|
blocks = manager.allocate_slots(req, num_tokens - 1,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert len(blocks.blocks[0]) == 1
|
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||||
|
|
||||||
assert manager.block_pool.blocks[blocks.blocks[0]
|
assert manager.block_pool.blocks[blocks.blocks[0]
|
||||||
[0].block_id].block_hash is None
|
[0].block_id].block_hash is None
|
||||||
@ -658,7 +666,7 @@ def test_computed_blocks_not_evicted():
|
|||||||
blocks = manager.allocate_slots(req0, num_tokens,
|
blocks = manager.allocate_slots(req0, num_tokens,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert len(blocks.blocks[0]) == 1
|
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||||
assert blocks.blocks[0][0].block_id == 1
|
assert blocks.blocks[0][0].block_id == 1
|
||||||
|
|
||||||
# Allocate another block.
|
# Allocate another block.
|
||||||
@ -670,7 +678,7 @@ def test_computed_blocks_not_evicted():
|
|||||||
blocks = manager.allocate_slots(req1, num_tokens,
|
blocks = manager.allocate_slots(req1, num_tokens,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert len(blocks.blocks[0]) == 1
|
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||||
assert blocks.blocks[0][0].block_id == 2
|
assert blocks.blocks[0][0].block_id == 2
|
||||||
|
|
||||||
# Free the blocks.
|
# Free the blocks.
|
||||||
@ -688,7 +696,7 @@ def test_computed_blocks_not_evicted():
|
|||||||
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
|
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert len(blocks.blocks[0]) == 1
|
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||||
assert blocks.blocks[0][0].block_id == 2
|
assert blocks.blocks[0][0].block_id == 2
|
||||||
|
|
||||||
|
|
||||||
@ -712,7 +720,7 @@ def test_basic_prefix_caching_disabled():
|
|||||||
blocks = manager.allocate_slots(req1, 10,
|
blocks = manager.allocate_slots(req1, 10,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert len(blocks.blocks[0]) == 3
|
assert blocks is not None and len(blocks.blocks[0]) == 3
|
||||||
|
|
||||||
# Free the blocks.
|
# Free the blocks.
|
||||||
manager.free(req1)
|
manager.free(req1)
|
||||||
@ -726,7 +734,7 @@ def test_basic_prefix_caching_disabled():
|
|||||||
blocks = manager.allocate_slots(req2, 16,
|
blocks = manager.allocate_slots(req2, 16,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert len(blocks.blocks[0]) == 4
|
assert blocks is not None and len(blocks.blocks[0]) == 4
|
||||||
|
|
||||||
# New requests should not have any blocks.
|
# New requests should not have any blocks.
|
||||||
req3 = make_request("3", list(range(4)), block_size, sha256)
|
req3 = make_request("3", list(range(4)), block_size, sha256)
|
||||||
@ -773,7 +781,8 @@ def test_cache_blocks(hash_fn):
|
|||||||
assert len(block_pool.cached_block_hash_to_block) == 2
|
assert len(block_pool.cached_block_hash_to_block) == 2
|
||||||
assert all([block.block_hash is not None for block in blocks])
|
assert all([block.block_hash is not None for block in blocks])
|
||||||
|
|
||||||
# Test that blocks that don't start from the beginning are cached correctly.
|
# Test that blocks that don't start from the beginning are cached
|
||||||
|
# correctly.
|
||||||
blocks += [KVCacheBlock(block_id=2)]
|
blocks += [KVCacheBlock(block_id=2)]
|
||||||
block_pool.cache_full_blocks(
|
block_pool.cache_full_blocks(
|
||||||
request=req,
|
request=req,
|
||||||
@ -1101,7 +1110,7 @@ def test_reset_prefix_cache():
|
|||||||
all_token_ids = full_block_token_ids + unique_token_ids
|
all_token_ids = full_block_token_ids + unique_token_ids
|
||||||
req0 = make_request("0", all_token_ids, block_size, sha256)
|
req0 = make_request("0", all_token_ids, block_size, sha256)
|
||||||
blocks = manager.allocate_slots(req0, 55)
|
blocks = manager.allocate_slots(req0, 55)
|
||||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||||
|
|
||||||
unique_token_ids = [4] * 7
|
unique_token_ids = [4] * 7
|
||||||
all_token_ids = full_block_token_ids + unique_token_ids
|
all_token_ids = full_block_token_ids + unique_token_ids
|
||||||
@ -1112,7 +1121,7 @@ def test_reset_prefix_cache():
|
|||||||
blocks = manager.allocate_slots(req1, 7,
|
blocks = manager.allocate_slots(req1, 7,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == ([5], )
|
assert blocks is not None and blocks.get_block_ids() == ([5], )
|
||||||
|
|
||||||
# Failed to reset prefix cache because some blocks are not freed yet.
|
# Failed to reset prefix cache because some blocks are not freed yet.
|
||||||
assert not manager.reset_prefix_cache()
|
assert not manager.reset_prefix_cache()
|
||||||
@ -1168,49 +1177,41 @@ def test_maybe_evict_cached_block():
|
|||||||
# Manually add all blocks to cached_blocks
|
# Manually add all blocks to cached_blocks
|
||||||
for block, block_hash in zip(pool.blocks, block_hashes):
|
for block, block_hash in zip(pool.blocks, block_hashes):
|
||||||
block.block_hash = block_hash
|
block.block_hash = block_hash
|
||||||
pool.cached_block_hash_to_block[block_hash][block.block_id] = block
|
pool.cached_block_hash_to_block.insert(block_hash, block)
|
||||||
|
|
||||||
block0, block1, block2, block3 = pool.blocks
|
block0, block1, block2, block3 = pool.blocks
|
||||||
assert pool.cached_block_hash_to_block == {
|
assert pool.cached_block_hash_to_block._cache == {
|
||||||
block_hash0: {
|
block_hash0: {
|
||||||
block0.block_id: block0,
|
block0.block_id: block0,
|
||||||
block3.block_id: block3
|
block3.block_id: block3,
|
||||||
},
|
},
|
||||||
block_hash1: {
|
block_hash1: block1,
|
||||||
block1.block_id: block1
|
block_hash2: block2,
|
||||||
},
|
|
||||||
block_hash2: {
|
|
||||||
block2.block_id: block2
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
# Evict block1
|
# Evict block1
|
||||||
pool._maybe_evict_cached_block(block1)
|
pool._maybe_evict_cached_block(block1)
|
||||||
assert pool.cached_block_hash_to_block == {
|
assert pool.cached_block_hash_to_block._cache == {
|
||||||
block_hash0: {
|
block_hash0: {
|
||||||
block0.block_id: block0,
|
block0.block_id: block0,
|
||||||
block3.block_id: block3
|
block3.block_id: block3
|
||||||
},
|
},
|
||||||
block_hash2: {
|
block_hash2: block2,
|
||||||
block2.block_id: block2
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
# Evict block0: block_hash0 entry should NOT be removed, as block3
|
# Evict block0: block_hash0 entry should NOT be removed, as block3
|
||||||
# also use the same hash
|
# also use the same hash
|
||||||
pool._maybe_evict_cached_block(block0)
|
pool._maybe_evict_cached_block(block0)
|
||||||
assert pool.cached_block_hash_to_block == {
|
assert pool.cached_block_hash_to_block._cache == {
|
||||||
block_hash0: {
|
block_hash0: {
|
||||||
block3.block_id: block3
|
block3.block_id: block3
|
||||||
},
|
},
|
||||||
block_hash2: {
|
block_hash2: block2,
|
||||||
block2.block_id: block2
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
# Evict block2
|
# Evict block2
|
||||||
pool._maybe_evict_cached_block(block2)
|
pool._maybe_evict_cached_block(block2)
|
||||||
assert pool.cached_block_hash_to_block == {block_hash0: {3: block3}}
|
assert pool.cached_block_hash_to_block._cache == {block_hash0: {3: block3}}
|
||||||
# Evict block3
|
# Evict block3
|
||||||
pool._maybe_evict_cached_block(block3)
|
pool._maybe_evict_cached_block(block3)
|
||||||
assert pool.cached_block_hash_to_block == {}
|
assert pool.cached_block_hash_to_block._cache == {}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
|
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
|
||||||
@ -1374,7 +1375,7 @@ def test_eagle_with_sliding_window():
|
|||||||
# Evict the first block in the request
|
# Evict the first block in the request
|
||||||
assert manager.block_pool.get_cached_block(
|
assert manager.block_pool.get_cached_block(
|
||||||
block_hash_first_block, kv_cache_group_ids=[0]) is not None
|
block_hash_first_block, kv_cache_group_ids=[0]) is not None
|
||||||
manager.block_pool.cached_block_hash_to_block.pop(
|
manager.block_pool.cached_block_hash_to_block._cache.pop(
|
||||||
make_block_hash_with_group_id(block_hash_first_block, 0))
|
make_block_hash_with_group_id(block_hash_first_block, 0))
|
||||||
|
|
||||||
# New request
|
# New request
|
||||||
@ -1386,3 +1387,78 @@ def test_eagle_with_sliding_window():
|
|||||||
# there will be no matched prefix.
|
# there will be no matched prefix.
|
||||||
assert len(computed_blocks.blocks[0]) == 0
|
assert len(computed_blocks.blocks[0]) == 0
|
||||||
assert num_tokens == 0
|
assert num_tokens == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_block_lookup_cache_single_block_per_key():
|
||||||
|
cache = BlockHashToBlockMap()
|
||||||
|
key0 = BlockHashWithGroupId(b"hash0")
|
||||||
|
key1 = BlockHashWithGroupId(b"hash1")
|
||||||
|
key2 = BlockHashWithGroupId(b"hash2")
|
||||||
|
block0 = KVCacheBlock(0)
|
||||||
|
block1 = KVCacheBlock(1)
|
||||||
|
|
||||||
|
assert cache.get_one_block(key0) is None
|
||||||
|
assert cache.get_one_block(key1) is None
|
||||||
|
assert cache.get_one_block(key2) is None
|
||||||
|
# key0 inserted
|
||||||
|
cache.insert(key0, block0)
|
||||||
|
assert cache.get_one_block(key0) is block0
|
||||||
|
assert cache.get_one_block(key1) is None
|
||||||
|
assert cache.get_one_block(key2) is None
|
||||||
|
# key1 inserted
|
||||||
|
cache.insert(key1, block1)
|
||||||
|
assert cache.get_one_block(key0) is block0
|
||||||
|
assert cache.get_one_block(key1) is block1
|
||||||
|
assert cache.get_one_block(key2) is None
|
||||||
|
# No block poped due to block_id mismatch
|
||||||
|
assert cache.pop(key0, 100) is None
|
||||||
|
assert cache.get_one_block(key0) is block0
|
||||||
|
assert cache.get_one_block(key1) is block1
|
||||||
|
assert cache.get_one_block(key2) is None
|
||||||
|
# block poped with (key0, block ID 0)
|
||||||
|
assert cache.pop(key0, 0) is block0
|
||||||
|
assert cache.get_one_block(key0) is None
|
||||||
|
assert cache.get_one_block(key1) is block1
|
||||||
|
assert cache.get_one_block(key2) is None
|
||||||
|
# No block poped due to block_id mismatch
|
||||||
|
assert cache.pop(key0, 1) is None
|
||||||
|
assert cache.get_one_block(key0) is None
|
||||||
|
assert cache.get_one_block(key1) is block1
|
||||||
|
assert cache.get_one_block(key2) is None
|
||||||
|
# block poped with (key1, block ID 1)
|
||||||
|
assert cache.pop(key1, 1) is block1
|
||||||
|
assert cache.get_one_block(key0) is None
|
||||||
|
assert cache.get_one_block(key1) is None
|
||||||
|
assert cache.get_one_block(key2) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_block_lookup_cache_multi_blocks_per_key():
|
||||||
|
cache = BlockHashToBlockMap()
|
||||||
|
key0 = BlockHashWithGroupId(b"hash0")
|
||||||
|
key1 = BlockHashWithGroupId(b"hash1")
|
||||||
|
block00 = KVCacheBlock(0)
|
||||||
|
block01 = KVCacheBlock(1)
|
||||||
|
block10 = KVCacheBlock(10)
|
||||||
|
block11 = KVCacheBlock(11)
|
||||||
|
|
||||||
|
assert cache.get_one_block(key0) is None
|
||||||
|
assert cache.get_one_block(key1) is None
|
||||||
|
|
||||||
|
cache.insert(key0, block00)
|
||||||
|
cache.insert(key0, block01)
|
||||||
|
cache.insert(key1, block10)
|
||||||
|
cache.insert(key1, block11)
|
||||||
|
|
||||||
|
assert cache.get_one_block(key0) is block00
|
||||||
|
assert cache.pop(key0, 0) is block00
|
||||||
|
assert cache.get_one_block(key0) is block01
|
||||||
|
assert cache.pop(key0, 1) is block01
|
||||||
|
assert cache.get_one_block(key0) is None
|
||||||
|
assert cache.pop(key0, 2) is None
|
||||||
|
|
||||||
|
assert cache.get_one_block(key1) is block10
|
||||||
|
assert cache.pop(key1, 10) is block10
|
||||||
|
assert cache.get_one_block(key1) is block11
|
||||||
|
assert cache.pop(key1, 11) is block11
|
||||||
|
assert cache.get_one_block(key1) is None
|
||||||
|
assert cache.pop(key1, 12) is None
|
||||||
|
|||||||
@ -47,16 +47,15 @@ def test_chunked_local_attention_possible_cached_prefix():
|
|||||||
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
|
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
|
||||||
]
|
]
|
||||||
|
|
||||||
block_pool.cached_block_hash_to_block.clear()
|
block_pool.cached_block_hash_to_block._cache.clear()
|
||||||
|
|
||||||
# Mock the block pool with the cached blocks
|
# Mock the block pool with the cached blocks
|
||||||
for i, (block_hash,
|
for i, (block_hash,
|
||||||
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
||||||
if is_cached:
|
if is_cached:
|
||||||
block_pool.cached_block_hash_to_block[
|
block_pool.cached_block_hash_to_block.insert(
|
||||||
make_block_hash_with_group_id(block_hash, 0)] = {
|
make_block_hash_with_group_id(block_hash, 0),
|
||||||
i: block_pool.blocks[i + 10],
|
block_pool.blocks[i + 10])
|
||||||
}
|
|
||||||
|
|
||||||
computed_blocks = manager.find_longest_cache_hit(
|
computed_blocks = manager.find_longest_cache_hit(
|
||||||
block_hashes=block_hash_list,
|
block_hashes=block_hash_list,
|
||||||
@ -112,16 +111,15 @@ def test_sliding_window_possible_cached_prefix():
|
|||||||
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
|
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
|
||||||
]
|
]
|
||||||
|
|
||||||
block_pool.cached_block_hash_to_block.clear()
|
block_pool.cached_block_hash_to_block._cache.clear()
|
||||||
|
|
||||||
# Mock the block pool with the cached blocks
|
# Mock the block pool with the cached blocks
|
||||||
for i, (block_hash,
|
for i, (block_hash,
|
||||||
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
||||||
if is_cached:
|
if is_cached:
|
||||||
block_pool.cached_block_hash_to_block[
|
block_pool.cached_block_hash_to_block.insert(
|
||||||
make_block_hash_with_group_id(block_hash, 0)] = {
|
make_block_hash_with_group_id(block_hash, 0),
|
||||||
i: block_pool.blocks[i + 10],
|
block_pool.blocks[i + 10])
|
||||||
}
|
|
||||||
|
|
||||||
computed_blocks = manager.find_longest_cache_hit(
|
computed_blocks = manager.find_longest_cache_hit(
|
||||||
block_hashes=block_hash_list,
|
block_hashes=block_hash_list,
|
||||||
|
|||||||
@ -81,16 +81,6 @@ class CarDescription(BaseModel):
|
|||||||
car_type: CarType
|
car_type: CarType
|
||||||
|
|
||||||
|
|
||||||
def _load_json(s: str, backend: str) -> str:
|
|
||||||
if backend != "xgrammar":
|
|
||||||
return json.loads(s)
|
|
||||||
|
|
||||||
# xgrammar specific workarounds
|
|
||||||
# https://github.com/mlc-ai/xgrammar/issues/286
|
|
||||||
s = re.sub(r'[\x00-\x1F\x7F-\xFF]', '', s)
|
|
||||||
return json.loads(s)
|
|
||||||
|
|
||||||
|
|
||||||
def test_guided_decoding_deprecated():
|
def test_guided_decoding_deprecated():
|
||||||
with pytest.warns(DeprecationWarning,
|
with pytest.warns(DeprecationWarning,
|
||||||
match="GuidedDecodingParams is deprecated.*"):
|
match="GuidedDecodingParams is deprecated.*"):
|
||||||
@ -177,7 +167,12 @@ def test_structured_output(
|
|||||||
if backend != 'lm-format-enforcer':
|
if backend != 'lm-format-enforcer':
|
||||||
assert "\n" not in generated_text
|
assert "\n" not in generated_text
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
output_json = json.loads(generated_text)
|
try:
|
||||||
|
output_json = json.loads(generated_text)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
pytest.fail(
|
||||||
|
f"Invalid JSON from backend={backend}: {generated_text!r}\n"
|
||||||
|
f"Schema: {sample_json_schema}\nError: {e}")
|
||||||
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
||||||
|
|
||||||
#
|
#
|
||||||
@ -425,7 +420,12 @@ def test_structured_output(
|
|||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
assert generated_text is not None
|
assert generated_text is not None
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
output_json = json.loads(generated_text)
|
try:
|
||||||
|
output_json = json.loads(generated_text)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
pytest.fail(
|
||||||
|
f"Invalid JSON from backend={backend}: {generated_text!r}\n"
|
||||||
|
f"Schema: {json_schema}\nError: {e}")
|
||||||
jsonschema.validate(instance=output_json, schema=json_schema)
|
jsonschema.validate(instance=output_json, schema=json_schema)
|
||||||
|
|
||||||
#
|
#
|
||||||
@ -468,7 +468,12 @@ def test_structured_output(
|
|||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
assert generated_text is not None
|
assert generated_text is not None
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
output_json = json.loads(generated_text)
|
try:
|
||||||
|
output_json = json.loads(generated_text)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
pytest.fail(
|
||||||
|
f"Invalid JSON from backend={backend}: {generated_text!r}\n"
|
||||||
|
f"Schema: {json_schema}\nError: {e}")
|
||||||
jsonschema.validate(instance=output_json, schema=json_schema)
|
jsonschema.validate(instance=output_json, schema=json_schema)
|
||||||
|
|
||||||
if backend not in ["outlines", "lm-format-enforcer"]:
|
if backend not in ["outlines", "lm-format-enforcer"]:
|
||||||
|
|||||||
@ -48,13 +48,9 @@ def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
|
|||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
"A robot may not injure a human being",
|
"A robot may not injure a human being",
|
||||||
"It is only with the heart that one can see rightly;",
|
|
||||||
"The greatest glory in living lies not in never falling,",
|
|
||||||
]
|
]
|
||||||
answers = [
|
answers = [
|
||||||
"or, being injured, not kill, except in",
|
"or kill a human being",
|
||||||
"without the heart, one can only see wrongly.",
|
|
||||||
"but in rising every time we fall. - Nelson"
|
|
||||||
]
|
]
|
||||||
|
|
||||||
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
|
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
|
||||||
|
|||||||
174
tests/v1/worker/test_worker_memory_snapshot.py
Normal file
174
tests/v1/worker/test_worker_memory_snapshot.py
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from multiprocessing import Queue
|
||||||
|
from typing import Optional
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
|
from vllm.utils import MemorySnapshot
|
||||||
|
from vllm.v1.worker.gpu_worker import (Worker,
|
||||||
|
init_worker_distributed_environment)
|
||||||
|
|
||||||
|
# Global queue to track operation order across processes
|
||||||
|
_QUEUE: Optional[Queue] = None
|
||||||
|
|
||||||
|
|
||||||
|
def track_operation(operation: str, rank: int):
|
||||||
|
"""Track when an operation happens and its rank."""
|
||||||
|
if _QUEUE is not None:
|
||||||
|
_QUEUE.put((operation, rank))
|
||||||
|
|
||||||
|
|
||||||
|
def make_operation_tracker(operation_name: str, original_func):
|
||||||
|
"""Create a mock function that tracks when an operation is called.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operation_name: Name to use when tracking this operation
|
||||||
|
original_func: The original function to wrap
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A wrapper function that tracks the operation and calls the original
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
rank = int(os.environ.get("RANK", "-1"))
|
||||||
|
track_operation(operation_name, rank)
|
||||||
|
return original_func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def worker_process(rank: int, world_size: int, distributed_init_method: str,
|
||||||
|
queue: Queue, error_queue: Queue):
|
||||||
|
"""Worker process that initializes a GPU worker with proper tracking."""
|
||||||
|
global _QUEUE
|
||||||
|
_QUEUE = queue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Set environment variables
|
||||||
|
os.environ["RANK"] = str(rank)
|
||||||
|
os.environ["LOCAL_RANK"] = str(rank)
|
||||||
|
os.environ["WORLD_SIZE"] = str(world_size)
|
||||||
|
|
||||||
|
# Create vLLM config with small model
|
||||||
|
vllm_config = EngineArgs(model="facebook/opt-125m",
|
||||||
|
tensor_parallel_size=2,
|
||||||
|
load_format="dummy").create_engine_config()
|
||||||
|
|
||||||
|
# Create worker
|
||||||
|
worker = Worker(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
local_rank=rank,
|
||||||
|
rank=rank,
|
||||||
|
distributed_init_method=distributed_init_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get original functions before patching
|
||||||
|
original_init_worker = init_worker_distributed_environment
|
||||||
|
original_memory_snapshot_init = MemorySnapshot.__init__
|
||||||
|
original_all_reduce = torch.distributed.all_reduce
|
||||||
|
|
||||||
|
# Apply minimal patches to track operation order
|
||||||
|
init_patch = patch(
|
||||||
|
'vllm.v1.worker.gpu_worker.init_worker_distributed_environment',
|
||||||
|
side_effect=make_operation_tracker("init_distributed",
|
||||||
|
original_init_worker))
|
||||||
|
memory_patch = patch.object(
|
||||||
|
MemorySnapshot, '__init__',
|
||||||
|
make_operation_tracker("memory_snapshot",
|
||||||
|
original_memory_snapshot_init))
|
||||||
|
all_reduce_patch = patch('torch.distributed.all_reduce',
|
||||||
|
side_effect=make_operation_tracker(
|
||||||
|
"nccl_all_reduce", original_all_reduce))
|
||||||
|
|
||||||
|
with init_patch, memory_patch, all_reduce_patch:
|
||||||
|
|
||||||
|
# Initialize device (this is where we test the order)
|
||||||
|
worker.init_device()
|
||||||
|
|
||||||
|
# Load model to ensure everything works
|
||||||
|
worker.load_model()
|
||||||
|
|
||||||
|
# Signal success
|
||||||
|
queue.put(("success", rank))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_queue.put((rank, str(e), type(e).__name__))
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||||
|
reason="Need at least 2 GPUs for tensor parallelism")
|
||||||
|
def test_init_distributed_is_called_before_memory_snapshot():
|
||||||
|
"""Test that distributed env is setup before memory snapshot.
|
||||||
|
|
||||||
|
This test makes sure during worker initialization, the initial memory
|
||||||
|
snapshot is taken after distributed env is setup to include all the buffers
|
||||||
|
allocated by distributed env.
|
||||||
|
"""
|
||||||
|
world_size = 2
|
||||||
|
|
||||||
|
# Create a temporary file for distributed init
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False) as f:
|
||||||
|
distributed_init_method = f"file://{f.name}"
|
||||||
|
|
||||||
|
# Create queues for inter-process communication
|
||||||
|
ctx = mp.get_context("spawn")
|
||||||
|
operation_queue = ctx.Queue()
|
||||||
|
error_queue = ctx.Queue()
|
||||||
|
|
||||||
|
# Start worker processes
|
||||||
|
processes = []
|
||||||
|
for rank in range(world_size):
|
||||||
|
p = ctx.Process(target=worker_process,
|
||||||
|
args=(rank, world_size, distributed_init_method,
|
||||||
|
operation_queue, error_queue))
|
||||||
|
p.start()
|
||||||
|
processes.append(p)
|
||||||
|
|
||||||
|
# Wait for all processes to complete
|
||||||
|
for p in processes:
|
||||||
|
p.join(timeout=60) # 60 second timeout
|
||||||
|
|
||||||
|
# Check for errors
|
||||||
|
errors = []
|
||||||
|
while not error_queue.empty():
|
||||||
|
rank, error_msg, error_type = error_queue.get()
|
||||||
|
errors.append(f"Rank {rank}: {error_type}: {error_msg}")
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
pytest.fail("Worker processes failed:\n" + "\n".join(errors))
|
||||||
|
|
||||||
|
# Collect all operations from the queue
|
||||||
|
operations = []
|
||||||
|
while not operation_queue.empty():
|
||||||
|
operations.append(operation_queue.get())
|
||||||
|
|
||||||
|
# Verify we got operations from both ranks
|
||||||
|
print(f"Collected operations: {operations}")
|
||||||
|
|
||||||
|
# Check operations for each rank
|
||||||
|
for rank in range(world_size):
|
||||||
|
rank_ops = [op for op, r in operations if r == rank]
|
||||||
|
print(f"\nRank {rank} operations: {rank_ops}")
|
||||||
|
|
||||||
|
# Raises ValueError if the operation is not found
|
||||||
|
init_distributed = rank_ops.index("init_distributed")
|
||||||
|
nccl_all_reduce = rank_ops.index("nccl_all_reduce")
|
||||||
|
memory_snapshot = rank_ops.index("memory_snapshot")
|
||||||
|
|
||||||
|
# Verify order: init_distributed should happen before memory_snapshot
|
||||||
|
assert init_distributed < nccl_all_reduce < memory_snapshot, (
|
||||||
|
f"Rank {rank}: init_distributed (index {init_distributed}) "
|
||||||
|
f"must happen before nccl_all_reduce (index {nccl_all_reduce}) "
|
||||||
|
f"and memory_snapshot (index {memory_snapshot})")
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
os.unlink(distributed_init_method.replace("file://", ""))
|
||||||
@ -1,314 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from itertools import accumulate
|
|
||||||
from typing import List, Optional, Tuple, Type
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|
||||||
AttentionMetadata,
|
|
||||||
AttentionMetadataBuilder)
|
|
||||||
from vllm.attention.backends.utils import CommonAttentionState
|
|
||||||
from vllm.utils import async_tensor_h2d
|
|
||||||
|
|
||||||
# Placeholder attention backend for models like Mamba and pooling models that
|
|
||||||
# lack attention.
|
|
||||||
|
|
||||||
|
|
||||||
class PlaceholderAttentionBackend(AttentionBackend):
|
|
||||||
"""Placeholder backend for when no attention is needed."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_name() -> str:
|
|
||||||
return "NO_ATTENTION"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
|
|
||||||
return PlaceholderAttentionImpl
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
|
|
||||||
return PlaceholderAttentionMetadataBuilder
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
|
|
||||||
return PlaceholderAttentionMetadata
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
|
||||||
return CommonAttentionState
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_kv_cache_shape(
|
|
||||||
num_blocks: int,
|
|
||||||
block_size: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
) -> Tuple[int, ...]:
|
|
||||||
return (1, 1, 1, 1, 1)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def swap_blocks(
|
|
||||||
src_kv_cache: torch.Tensor,
|
|
||||||
dst_kv_cache: torch.Tensor,
|
|
||||||
src_to_dst: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def copy_blocks(
|
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
src_to_dists: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PlaceholderAttentionMetadata(AttentionMetadata):
|
|
||||||
"""Attention metadata for prefill and decode batched together."""
|
|
||||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
|
||||||
# the computed tokens + new tokens None if it is a decoding.
|
|
||||||
seq_lens: Optional[List[int]]
|
|
||||||
# seq_lens stored as a tensor.
|
|
||||||
seq_lens_tensor: Optional[torch.Tensor]
|
|
||||||
|
|
||||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
|
||||||
# requests only.
|
|
||||||
max_prefill_seq_len: int
|
|
||||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
|
||||||
# requests only.
|
|
||||||
max_decode_seq_len: int
|
|
||||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
|
||||||
# so far).
|
|
||||||
context_lens_tensor: Optional[torch.Tensor]
|
|
||||||
|
|
||||||
# Whether or not if cuda graph is enabled.
|
|
||||||
# Cuda-graph is currently enabled for decoding only.
|
|
||||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
|
||||||
use_cuda_graph: bool
|
|
||||||
|
|
||||||
# Maximum query length in the batch.
|
|
||||||
max_query_len: Optional[int]
|
|
||||||
|
|
||||||
# Max number of query tokens among request in the batch.
|
|
||||||
max_decode_query_len: Optional[int]
|
|
||||||
|
|
||||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
|
||||||
# the batch, used to index into subquery. E.g., if the subquery length
|
|
||||||
# is [4, 6], it is [0, 4, 10].
|
|
||||||
query_start_loc: Optional[torch.Tensor] = None
|
|
||||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
|
||||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
|
||||||
# [4, 6], it is [0, 4, 10].
|
|
||||||
seq_start_loc: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
# Placeholder.
|
|
||||||
block_tables: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
_cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
|
|
||||||
_cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
|
|
||||||
if self.num_prefills == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if self._cached_prefill_metadata is not None:
|
|
||||||
return self._cached_prefill_metadata
|
|
||||||
|
|
||||||
# Compute some attn_metadata fields which default to None
|
|
||||||
query_start_loc = (None if self.query_start_loc is None else
|
|
||||||
self.query_start_loc[:self.num_prefills + 1])
|
|
||||||
seq_lens = (None if self.seq_lens is None else
|
|
||||||
self.seq_lens[:self.num_prefills])
|
|
||||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
|
||||||
self.seq_lens_tensor[:self.num_prefills])
|
|
||||||
seq_start_loc = (None if self.seq_start_loc is None else
|
|
||||||
self.seq_start_loc[:self.num_prefills + 1])
|
|
||||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
|
||||||
self.context_lens_tensor[:self.num_prefills])
|
|
||||||
|
|
||||||
# Placeholders
|
|
||||||
slot_mapping = torch.empty(0)
|
|
||||||
block_tables = torch.empty(0)
|
|
||||||
|
|
||||||
self._cached_prefill_metadata = PlaceholderAttentionMetadata(
|
|
||||||
num_prefills=self.num_prefills,
|
|
||||||
num_prefill_tokens=self.num_prefill_tokens,
|
|
||||||
num_decode_tokens=0,
|
|
||||||
slot_mapping=slot_mapping,
|
|
||||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
|
||||||
max_decode_query_len=0,
|
|
||||||
max_query_len=self.max_query_len,
|
|
||||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
|
||||||
max_decode_seq_len=0,
|
|
||||||
query_start_loc=query_start_loc,
|
|
||||||
seq_start_loc=seq_start_loc,
|
|
||||||
context_lens_tensor=context_lens_tensor,
|
|
||||||
block_tables=block_tables,
|
|
||||||
use_cuda_graph=False,
|
|
||||||
)
|
|
||||||
return self._cached_prefill_metadata
|
|
||||||
|
|
||||||
@property
|
|
||||||
def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
|
|
||||||
if self.num_decode_tokens == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if self._cached_decode_metadata is not None:
|
|
||||||
return self._cached_decode_metadata
|
|
||||||
assert self.seq_lens_tensor is not None
|
|
||||||
|
|
||||||
# Placeholders
|
|
||||||
slot_mapping = torch.empty(0)
|
|
||||||
block_tables = torch.empty(0)
|
|
||||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
|
||||||
self.seq_lens_tensor[self.num_prefills:])
|
|
||||||
|
|
||||||
self._cached_decode_metadata = PlaceholderAttentionMetadata(
|
|
||||||
num_prefills=0,
|
|
||||||
num_prefill_tokens=0,
|
|
||||||
num_decode_tokens=self.num_decode_tokens,
|
|
||||||
slot_mapping=slot_mapping,
|
|
||||||
enable_kv_scales_calculation=True,
|
|
||||||
seq_lens=None,
|
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
|
||||||
max_decode_query_len=self.max_decode_query_len,
|
|
||||||
max_query_len=None,
|
|
||||||
max_prefill_seq_len=0,
|
|
||||||
max_decode_seq_len=self.max_decode_seq_len,
|
|
||||||
query_start_loc=(self.query_start_loc[self.num_prefills:] -
|
|
||||||
self.query_start_loc[self.num_prefills])
|
|
||||||
if self.query_start_loc is not None else None,
|
|
||||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
|
||||||
if self.seq_start_loc is not None else None,
|
|
||||||
context_lens_tensor=None,
|
|
||||||
block_tables=block_tables,
|
|
||||||
use_cuda_graph=self.use_cuda_graph,
|
|
||||||
)
|
|
||||||
return self._cached_decode_metadata
|
|
||||||
|
|
||||||
|
|
||||||
class PlaceholderAttentionMetadataBuilder(
|
|
||||||
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
|
|
||||||
|
|
||||||
def __init__(self, input_builder):
|
|
||||||
|
|
||||||
self.input_builder = input_builder
|
|
||||||
self.runner = input_builder.runner
|
|
||||||
|
|
||||||
def prepare(self):
|
|
||||||
self.prefill_seq_lens: List[int] = []
|
|
||||||
self.context_lens: List[int] = []
|
|
||||||
self.curr_seq_lens: List[int] = []
|
|
||||||
self.num_prefills = 0
|
|
||||||
self.num_prefill_tokens = 0
|
|
||||||
self.num_decode_tokens = 0
|
|
||||||
|
|
||||||
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool):
|
|
||||||
"""Add a sequence group to the metadata. Specifically update/append
|
|
||||||
1. context length.
|
|
||||||
"""
|
|
||||||
is_prompt = inter_data.is_prompt
|
|
||||||
|
|
||||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
|
||||||
curr_sliding_window_block) in zip(
|
|
||||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
|
||||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
|
||||||
inter_data.query_lens, inter_data.context_lens,
|
|
||||||
inter_data.curr_sliding_window_blocks):
|
|
||||||
self.context_lens.append(context_len)
|
|
||||||
|
|
||||||
if is_prompt:
|
|
||||||
self.num_prefills += 1
|
|
||||||
self.num_prefill_tokens += token_len
|
|
||||||
self.prefill_seq_lens.append(seq_len)
|
|
||||||
else:
|
|
||||||
self.num_decode_tokens += query_len
|
|
||||||
self.curr_seq_lens.append(curr_seq_len)
|
|
||||||
|
|
||||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
|
||||||
cuda_graph_pad_size: int, batch_size: int):
|
|
||||||
"""Build attention metadata with on-device tensors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seq_lens: The maybe padded sequence lengths of the input sequences.
|
|
||||||
query_lens: The query lengths of the input sequences.
|
|
||||||
cuda_graph_pad_size: The padding size for cuda graph.
|
|
||||||
-1 if cuda graph is not used.
|
|
||||||
batch_size: The maybe padded batch size.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Some input builders such as ModelInputForCPUBuilder do not have the
|
|
||||||
# "inter_data_list" attribute.
|
|
||||||
# Let's check inter_data_list exists before we reference it.
|
|
||||||
if hasattr(self.input_builder, "inter_data_list"):
|
|
||||||
for inter_data in self.input_builder.inter_data_list:
|
|
||||||
self._add_seq_group(inter_data,
|
|
||||||
self.input_builder.chunked_prefill_enabled)
|
|
||||||
|
|
||||||
device = self.runner.device
|
|
||||||
use_captured_graph = cuda_graph_pad_size != -1
|
|
||||||
|
|
||||||
max_query_len = max(query_lens)
|
|
||||||
decode_query_lens = query_lens[self.num_prefills:]
|
|
||||||
if len(decode_query_lens) > 0:
|
|
||||||
max_decode_query_len = max(decode_query_lens)
|
|
||||||
else:
|
|
||||||
max_decode_query_len = 1
|
|
||||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
|
||||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
|
||||||
num_decode_tokens = self.num_decode_tokens
|
|
||||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
|
||||||
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
|
||||||
|
|
||||||
if use_captured_graph:
|
|
||||||
num_decode_tokens = batch_size - self.num_prefill_tokens
|
|
||||||
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
|
|
||||||
|
|
||||||
assert device is not None
|
|
||||||
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
|
|
||||||
device, self.runner.pin_memory)
|
|
||||||
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
|
||||||
self.runner.pin_memory)
|
|
||||||
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
|
|
||||||
device,
|
|
||||||
self.runner.pin_memory)
|
|
||||||
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
|
||||||
device, self.runner.pin_memory)
|
|
||||||
|
|
||||||
# Placeholders
|
|
||||||
slot_mapping_tensor = torch.empty(0)
|
|
||||||
block_tables = torch.empty(0)
|
|
||||||
|
|
||||||
return PlaceholderAttentionMetadata(
|
|
||||||
num_prefills=self.num_prefills,
|
|
||||||
slot_mapping=slot_mapping_tensor,
|
|
||||||
enable_kv_scales_calculation=True,
|
|
||||||
num_prefill_tokens=self.num_prefill_tokens,
|
|
||||||
num_decode_tokens=num_decode_tokens,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
|
||||||
max_query_len=max_query_len,
|
|
||||||
max_decode_query_len=max_decode_query_len,
|
|
||||||
max_prefill_seq_len=max_prefill_seq_len,
|
|
||||||
max_decode_seq_len=max_decode_seq_len,
|
|
||||||
query_start_loc=query_start_loc_tensor,
|
|
||||||
seq_start_loc=seq_start_loc_tensor,
|
|
||||||
context_lens_tensor=context_lens_tensor,
|
|
||||||
block_tables=block_tables,
|
|
||||||
use_cuda_graph=use_captured_graph,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PlaceholderAttentionImpl(AttentionImpl):
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs) -> torch.Tensor:
|
|
||||||
raise NotImplementedError
|
|
||||||
@ -304,7 +304,7 @@ class CommonAttentionState(AttentionState):
|
|||||||
max_query_len=1,
|
max_query_len=1,
|
||||||
max_decode_query_len=1,
|
max_decode_query_len=1,
|
||||||
max_prefill_seq_len=0,
|
max_prefill_seq_len=0,
|
||||||
max_decode_seq_len=self.runner.max_seq_len_to_capture,
|
max_decode_seq_len=self.runner.max_model_len,
|
||||||
query_start_loc=None,
|
query_start_loc=None,
|
||||||
seq_start_loc=None,
|
seq_start_loc=None,
|
||||||
context_lens_tensor=None,
|
context_lens_tensor=None,
|
||||||
@ -390,7 +390,7 @@ class CommonAttentionState(AttentionState):
|
|||||||
dtype=torch.int).cuda()
|
dtype=torch.int).cuda()
|
||||||
attn_metadata.encoder_seq_lens_tensor = torch.full(
|
attn_metadata.encoder_seq_lens_tensor = torch.full(
|
||||||
(batch_size, ), 1, dtype=torch.int).cuda()
|
(batch_size, ), 1, dtype=torch.int).cuda()
|
||||||
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
|
attn_metadata.max_encoder_seq_len = self.runner.max_model_len
|
||||||
attn_metadata.num_encoder_tokens = 0
|
attn_metadata.num_encoder_tokens = 0
|
||||||
|
|
||||||
def _add_additional_input_buffers_for_enc_dec_model(
|
def _add_additional_input_buffers_for_enc_dec_model(
|
||||||
|
|||||||
@ -115,12 +115,10 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
if cache_config is not None:
|
if cache_config is not None:
|
||||||
kv_cache_dtype = cache_config.cache_dtype
|
kv_cache_dtype = cache_config.cache_dtype
|
||||||
block_size = cache_config.block_size
|
block_size = cache_config.block_size
|
||||||
is_attention_free = cache_config.is_attention_free
|
|
||||||
calculate_kv_scales = cache_config.calculate_kv_scales
|
calculate_kv_scales = cache_config.calculate_kv_scales
|
||||||
else:
|
else:
|
||||||
kv_cache_dtype = "auto"
|
kv_cache_dtype = "auto"
|
||||||
block_size = 16
|
block_size = 16
|
||||||
is_attention_free = False
|
|
||||||
calculate_kv_scales = False
|
calculate_kv_scales = False
|
||||||
if num_kv_heads is None:
|
if num_kv_heads is None:
|
||||||
num_kv_heads = num_heads
|
num_kv_heads = num_heads
|
||||||
@ -185,7 +183,6 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
dtype,
|
dtype,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
block_size,
|
block_size,
|
||||||
is_attention_free,
|
|
||||||
use_mla=use_mla,
|
use_mla=use_mla,
|
||||||
has_sink=self.has_sink)
|
has_sink=self.has_sink)
|
||||||
else:
|
else:
|
||||||
@ -578,9 +575,7 @@ def unified_attention_fake(
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="unified_attention",
|
op_name="unified_attention",
|
||||||
op_func=unified_attention,
|
op_func=unified_attention,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=unified_attention_fake,
|
fake_impl=unified_attention_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
tags=tag_cudagraph_unsafe,
|
tags=tag_cudagraph_unsafe,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -631,6 +626,5 @@ direct_register_custom_op(
|
|||||||
op_func=unified_attention_with_output,
|
op_func=unified_attention_with_output,
|
||||||
mutates_args=["output", "output_block_scale"],
|
mutates_args=["output", "output_block_scale"],
|
||||||
fake_impl=unified_attention_with_output_fake,
|
fake_impl=unified_attention_with_output_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
tags=tag_cudagraph_unsafe,
|
tags=tag_cudagraph_unsafe,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -2,10 +2,9 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
|
|||||||
@ -142,7 +142,6 @@ def get_attn_backend(
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
kv_cache_dtype: Optional[str],
|
kv_cache_dtype: Optional[str],
|
||||||
block_size: int,
|
block_size: int,
|
||||||
is_attention_free: bool = False,
|
|
||||||
use_mla: bool = False,
|
use_mla: bool = False,
|
||||||
has_sink: bool = False,
|
has_sink: bool = False,
|
||||||
) -> type[AttentionBackend]:
|
) -> type[AttentionBackend]:
|
||||||
@ -156,7 +155,6 @@ def get_attn_backend(
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
kv_cache_dtype=kv_cache_dtype,
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
is_attention_free=is_attention_free,
|
|
||||||
use_v1=envs.VLLM_USE_V1,
|
use_v1=envs.VLLM_USE_V1,
|
||||||
use_mla=use_mla,
|
use_mla=use_mla,
|
||||||
has_sink=has_sink,
|
has_sink=has_sink,
|
||||||
@ -169,17 +167,10 @@ def _cached_get_attn_backend(
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
kv_cache_dtype: Optional[str],
|
kv_cache_dtype: Optional[str],
|
||||||
block_size: int,
|
block_size: int,
|
||||||
is_attention_free: bool,
|
|
||||||
use_v1: bool = False,
|
use_v1: bool = False,
|
||||||
use_mla: bool = False,
|
use_mla: bool = False,
|
||||||
has_sink: bool = False,
|
has_sink: bool = False,
|
||||||
) -> type[AttentionBackend]:
|
) -> type[AttentionBackend]:
|
||||||
# If there are no attention layers (e.g. we are running Mamba),
|
|
||||||
# use the placeholder NO_ATTENTION
|
|
||||||
if is_attention_free:
|
|
||||||
from vllm.attention.backends.placeholder_attn import (
|
|
||||||
PlaceholderAttentionBackend)
|
|
||||||
return PlaceholderAttentionBackend
|
|
||||||
|
|
||||||
# Check whether a particular choice of backend was
|
# Check whether a particular choice of backend was
|
||||||
# previously forced.
|
# previously forced.
|
||||||
|
|||||||
@ -547,7 +547,6 @@ if flashinfer_comm is not None:
|
|||||||
"scale_out",
|
"scale_out",
|
||||||
],
|
],
|
||||||
fake_impl=call_trtllm_fused_allreduce_norm_fake,
|
fake_impl=call_trtllm_fused_allreduce_norm_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
flashinfer_trtllm_fused_allreduce_norm = (
|
flashinfer_trtllm_fused_allreduce_norm = (
|
||||||
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default)
|
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default)
|
||||||
|
|||||||
@ -551,8 +551,9 @@ def set_inductor_config(config, runtime_shape):
|
|||||||
if isinstance(runtime_shape, int):
|
if isinstance(runtime_shape, int):
|
||||||
# for a specific batchsize, tuning triton kernel parameters
|
# for a specific batchsize, tuning triton kernel parameters
|
||||||
# can be beneficial
|
# can be beneficial
|
||||||
config["max_autotune"] = True
|
config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
|
||||||
config["coordinate_descent_tuning"] = True
|
config["coordinate_descent_tuning"] = (
|
||||||
|
envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING)
|
||||||
|
|
||||||
|
|
||||||
class EagerAdaptor(CompilerInterface):
|
class EagerAdaptor(CompilerInterface):
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from packaging import version
|
||||||
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
@ -300,13 +301,13 @@ def _support_torch_compile(
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
"enable_cpp_symbolic_shape_guards config not available")
|
"enable_cpp_symbolic_shape_guards config not available")
|
||||||
|
|
||||||
with patch.object(InliningInstructionTranslator, 'inline_call',
|
with patch.object(
|
||||||
patched_inline_call), torch._dynamo.config.patch(
|
InliningInstructionTranslator, "inline_call",
|
||||||
**dynamo_config_patches
|
patched_inline_call), torch._dynamo.config.patch(
|
||||||
), maybe_use_cudagraph_partition_wrapper(
|
**dynamo_config_patches
|
||||||
self.vllm_config):
|
), maybe_use_cudagraph_partition_wrapper(
|
||||||
|
self.vllm_config), _torch27_patch_tensor_subclasses():
|
||||||
output = self.compiled_callable(*args, **kwargs)
|
output = self.compiled_callable(*args, **kwargs)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
# usually, capturing the model once is enough, and then we can
|
# usually, capturing the model once is enough, and then we can
|
||||||
@ -367,3 +368,33 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
|
|||||||
if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||||
and compilation_config.use_inductor_graph_partition):
|
and compilation_config.use_inductor_graph_partition):
|
||||||
torch._inductor.utils.set_customized_partition_wrappers(None)
|
torch._inductor.utils.set_customized_partition_wrappers(None)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _torch27_patch_tensor_subclasses():
|
||||||
|
"""
|
||||||
|
Add support for using tensor subclasses (ie `BasevLLMParameter`, ect) when
|
||||||
|
using torch 2.7.0. This enables using weight_loader_v2 and the use of
|
||||||
|
`BasevLLMParameters` without having to replace them with regular tensors
|
||||||
|
before `torch.compile`-time.
|
||||||
|
"""
|
||||||
|
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||||
|
ModelWeightParameter,
|
||||||
|
RowvLLMParameter,
|
||||||
|
_ColumnvLLMParameter)
|
||||||
|
|
||||||
|
def return_false(*args, **kwargs):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if version.parse("2.7") <= version.parse(
|
||||||
|
torch.__version__) < version.parse("2.8"):
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
with (torch._dynamo.config.patch("traceable_tensor_subclasses", [
|
||||||
|
BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter,
|
||||||
|
RowvLLMParameter
|
||||||
|
]),
|
||||||
|
patch("torch._dynamo.variables.torch.can_dispatch_torch_function",
|
||||||
|
return_false)):
|
||||||
|
yield
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
|
|||||||
PrefixCachingHashAlgo)
|
PrefixCachingHashAlgo)
|
||||||
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
||||||
CUDAGraphMode, PassConfig)
|
CUDAGraphMode, PassConfig)
|
||||||
|
from vllm.config.device import Device, DeviceConfig
|
||||||
from vllm.config.kv_events import KVEventsConfig
|
from vllm.config.kv_events import KVEventsConfig
|
||||||
from vllm.config.kv_transfer import KVTransferConfig
|
from vllm.config.kv_transfer import KVTransferConfig
|
||||||
from vllm.config.load import LoadConfig
|
from vllm.config.load import LoadConfig
|
||||||
@ -38,11 +39,13 @@ from vllm.config.model import (ConvertOption, HfOverrides, LogprobsMode,
|
|||||||
try_match_architecture_defaults)
|
try_match_architecture_defaults)
|
||||||
from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
|
from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
|
||||||
MultiModalConfig)
|
MultiModalConfig)
|
||||||
|
from vllm.config.observability import DetailedTraceModules, ObservabilityConfig
|
||||||
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
||||||
ParallelConfig)
|
ParallelConfig)
|
||||||
from vllm.config.pooler import PoolerConfig
|
from vllm.config.pooler import PoolerConfig
|
||||||
from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy
|
from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy
|
||||||
from vllm.config.speculative import SpeculativeConfig
|
from vllm.config.speculative import SpeculativeConfig
|
||||||
|
from vllm.config.speech_to_text import SpeechToTextConfig
|
||||||
from vllm.config.structured_outputs import StructuredOutputsConfig
|
from vllm.config.structured_outputs import StructuredOutputsConfig
|
||||||
from vllm.config.utils import ConfigType, config, get_attr_docs, is_init_field
|
from vllm.config.utils import ConfigType, config, get_attr_docs, is_init_field
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -81,158 +84,6 @@ class SupportsMetricsInfo(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
|
||||||
|
|
||||||
|
|
||||||
@config
|
|
||||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
|
||||||
class DeviceConfig:
|
|
||||||
"""Configuration for the device to use for vLLM execution."""
|
|
||||||
|
|
||||||
device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto"
|
|
||||||
"""Device type for vLLM execution.
|
|
||||||
This parameter is deprecated and will be
|
|
||||||
removed in a future release.
|
|
||||||
It will now be set automatically based
|
|
||||||
on the current platform."""
|
|
||||||
device_type: str = field(init=False)
|
|
||||||
"""Device type from the current platform. This is set in
|
|
||||||
`__post_init__`."""
|
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
|
||||||
"""
|
|
||||||
WARNING: Whenever a new field is added to this config,
|
|
||||||
ensure that it is included in the factors list if
|
|
||||||
it affects the computation graph.
|
|
||||||
|
|
||||||
Provide a hash that uniquely identifies all the configs
|
|
||||||
that affect the structure of the computation
|
|
||||||
graph from input ids/embeddings to the final hidden states,
|
|
||||||
excluding anything before input ids/embeddings and after
|
|
||||||
the final hidden states.
|
|
||||||
"""
|
|
||||||
# no factors to consider.
|
|
||||||
# the device/platform information will be summarized
|
|
||||||
# by torch/vllm automatically.
|
|
||||||
factors: list[Any] = []
|
|
||||||
hash_str = hashlib.md5(str(factors).encode(),
|
|
||||||
usedforsecurity=False).hexdigest()
|
|
||||||
return hash_str
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.device == "auto":
|
|
||||||
# Automated device type detection
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
self.device_type = current_platform.device_type
|
|
||||||
if not self.device_type:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Failed to infer device type, please set "
|
|
||||||
"the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
|
|
||||||
"to turn on verbose logging to help debug the issue.")
|
|
||||||
else:
|
|
||||||
# Device type is assigned explicitly
|
|
||||||
if isinstance(self.device, str):
|
|
||||||
self.device_type = self.device
|
|
||||||
elif isinstance(self.device, torch.device):
|
|
||||||
self.device_type = self.device.type
|
|
||||||
|
|
||||||
# Some device types require processing inputs on CPU
|
|
||||||
if self.device_type in ["tpu"]:
|
|
||||||
self.device = None
|
|
||||||
else:
|
|
||||||
# Set device with device type
|
|
||||||
self.device = torch.device(self.device_type)
|
|
||||||
|
|
||||||
|
|
||||||
DetailedTraceModules = Literal["model", "worker", "all"]
|
|
||||||
|
|
||||||
|
|
||||||
@config
|
|
||||||
@dataclass
|
|
||||||
class ObservabilityConfig:
|
|
||||||
"""Configuration for observability - metrics and tracing."""
|
|
||||||
|
|
||||||
show_hidden_metrics_for_version: Optional[str] = None
|
|
||||||
"""Enable deprecated Prometheus metrics that have been hidden since the
|
|
||||||
specified version. For example, if a previously deprecated metric has been
|
|
||||||
hidden since the v0.7.0 release, you use
|
|
||||||
`--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
|
|
||||||
you migrate to new metrics. The metric is likely to be removed completely
|
|
||||||
in an upcoming release."""
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def show_hidden_metrics(self) -> bool:
|
|
||||||
"""Check if the hidden metrics should be shown."""
|
|
||||||
if self.show_hidden_metrics_for_version is None:
|
|
||||||
return False
|
|
||||||
return version._prev_minor_version_was(
|
|
||||||
self.show_hidden_metrics_for_version)
|
|
||||||
|
|
||||||
otlp_traces_endpoint: Optional[str] = None
|
|
||||||
"""Target URL to which OpenTelemetry traces will be sent."""
|
|
||||||
|
|
||||||
collect_detailed_traces: Optional[list[DetailedTraceModules]] = None
|
|
||||||
"""It makes sense to set this only if `--otlp-traces-endpoint` is set. If
|
|
||||||
set, it will collect detailed traces for the specified modules. This
|
|
||||||
involves use of possibly costly and or blocking operations and hence might
|
|
||||||
have a performance impact.
|
|
||||||
|
|
||||||
Note that collecting detailed timing information for each request can be
|
|
||||||
expensive."""
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def collect_model_forward_time(self) -> bool:
|
|
||||||
"""Whether to collect model forward time for the request."""
|
|
||||||
return (self.collect_detailed_traces is not None
|
|
||||||
and ("model" in self.collect_detailed_traces
|
|
||||||
or "all" in self.collect_detailed_traces))
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def collect_model_execute_time(self) -> bool:
|
|
||||||
"""Whether to collect model execute time for the request."""
|
|
||||||
return (self.collect_detailed_traces is not None
|
|
||||||
and ("worker" in self.collect_detailed_traces
|
|
||||||
or "all" in self.collect_detailed_traces))
|
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
|
||||||
"""
|
|
||||||
WARNING: Whenever a new field is added to this config,
|
|
||||||
ensure that it is included in the factors list if
|
|
||||||
it affects the computation graph.
|
|
||||||
|
|
||||||
Provide a hash that uniquely identifies all the configs
|
|
||||||
that affect the structure of the computation
|
|
||||||
graph from input ids/embeddings to the final hidden states,
|
|
||||||
excluding anything before input ids/embeddings and after
|
|
||||||
the final hidden states.
|
|
||||||
"""
|
|
||||||
# no factors to consider.
|
|
||||||
# this config will not affect the computation graph.
|
|
||||||
factors: list[Any] = []
|
|
||||||
hash_str = hashlib.md5(str(factors).encode(),
|
|
||||||
usedforsecurity=False).hexdigest()
|
|
||||||
return hash_str
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if (self.collect_detailed_traces is not None
|
|
||||||
and len(self.collect_detailed_traces) == 1
|
|
||||||
and "," in self.collect_detailed_traces[0]):
|
|
||||||
self._parse_collect_detailed_traces()
|
|
||||||
|
|
||||||
from vllm.tracing import is_otel_available, otel_import_error_traceback
|
|
||||||
if not is_otel_available() and self.otlp_traces_endpoint is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"OpenTelemetry is not available. Unable to configure "
|
|
||||||
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
|
|
||||||
f"installed. Original error:\n{otel_import_error_traceback}")
|
|
||||||
|
|
||||||
def _parse_collect_detailed_traces(self):
|
|
||||||
assert isinstance(self.collect_detailed_traces, list)
|
|
||||||
self.collect_detailed_traces = cast(
|
|
||||||
list[DetailedTraceModules],
|
|
||||||
self.collect_detailed_traces[0].split(","))
|
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||||
class VllmConfig:
|
class VllmConfig:
|
||||||
@ -1009,37 +860,6 @@ def get_layers_from_vllm_config(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@config
|
|
||||||
@dataclass
|
|
||||||
class SpeechToTextConfig:
|
|
||||||
"""Configuration for speech-to-text models."""
|
|
||||||
|
|
||||||
sample_rate: float = 16_000
|
|
||||||
"""Sample rate (Hz) to resample input audio to. Most speech models expect
|
|
||||||
16kHz audio input. The input audio will be automatically resampled to this
|
|
||||||
rate before processing."""
|
|
||||||
|
|
||||||
max_audio_clip_s: int = 30
|
|
||||||
"""Maximum duration in seconds for a single audio clip without chunking.
|
|
||||||
Audio longer than this will be split into smaller chunks if
|
|
||||||
`allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
|
|
||||||
|
|
||||||
overlap_chunk_second: int = 1
|
|
||||||
"""Overlap duration in seconds between consecutive audio chunks when
|
|
||||||
splitting long audio. This helps maintain context across chunk boundaries
|
|
||||||
and improves transcription quality at split points."""
|
|
||||||
|
|
||||||
min_energy_split_window_size: Optional[int] = 1600
|
|
||||||
"""Window size in samples for finding low-energy (quiet) regions to split
|
|
||||||
audio chunks. The algorithm looks for the quietest moment within this
|
|
||||||
window to minimize cutting through speech. Default 1600 samples ≈ 100ms
|
|
||||||
at 16kHz. If None, no chunking will be done."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def allow_audio_chunking(self) -> bool:
|
|
||||||
return self.min_energy_split_window_size is not None
|
|
||||||
|
|
||||||
|
|
||||||
def update_config(config: DataclassInstanceT,
|
def update_config(config: DataclassInstanceT,
|
||||||
overrides: dict[str, Any]) -> DataclassInstanceT:
|
overrides: dict[str, Any]) -> DataclassInstanceT:
|
||||||
processed_overrides = {}
|
processed_overrides = {}
|
||||||
|
|||||||
74
vllm/config/device.py
Normal file
74
vllm/config/device.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from dataclasses import field
|
||||||
|
from typing import Any, Literal, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pydantic import ConfigDict, SkipValidation
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
|
from vllm.config.utils import config
|
||||||
|
|
||||||
|
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||||
|
class DeviceConfig:
|
||||||
|
"""Configuration for the device to use for vLLM execution."""
|
||||||
|
|
||||||
|
device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto"
|
||||||
|
"""Device type for vLLM execution.
|
||||||
|
This parameter is deprecated and will be
|
||||||
|
removed in a future release.
|
||||||
|
It will now be set automatically based
|
||||||
|
on the current platform."""
|
||||||
|
device_type: str = field(init=False)
|
||||||
|
"""Device type from the current platform. This is set in
|
||||||
|
`__post_init__`."""
|
||||||
|
|
||||||
|
def compute_hash(self) -> str:
|
||||||
|
"""
|
||||||
|
WARNING: Whenever a new field is added to this config,
|
||||||
|
ensure that it is included in the factors list if
|
||||||
|
it affects the computation graph.
|
||||||
|
|
||||||
|
Provide a hash that uniquely identifies all the configs
|
||||||
|
that affect the structure of the computation
|
||||||
|
graph from input ids/embeddings to the final hidden states,
|
||||||
|
excluding anything before input ids/embeddings and after
|
||||||
|
the final hidden states.
|
||||||
|
"""
|
||||||
|
# no factors to consider.
|
||||||
|
# the device/platform information will be summarized
|
||||||
|
# by torch/vllm automatically.
|
||||||
|
factors: list[Any] = []
|
||||||
|
hash_str = hashlib.md5(str(factors).encode(),
|
||||||
|
usedforsecurity=False).hexdigest()
|
||||||
|
return hash_str
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.device == "auto":
|
||||||
|
# Automated device type detection
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
self.device_type = current_platform.device_type
|
||||||
|
if not self.device_type:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Failed to infer device type, please set "
|
||||||
|
"the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
|
||||||
|
"to turn on verbose logging to help debug the issue.")
|
||||||
|
else:
|
||||||
|
# Device type is assigned explicitly
|
||||||
|
if isinstance(self.device, str):
|
||||||
|
self.device_type = self.device
|
||||||
|
elif isinstance(self.device, torch.device):
|
||||||
|
self.device_type = self.device.type
|
||||||
|
|
||||||
|
# Some device types require processing inputs on CPU
|
||||||
|
if self.device_type in ["tpu"]:
|
||||||
|
self.device = None
|
||||||
|
else:
|
||||||
|
# Set device with device type
|
||||||
|
self.device = torch.device(self.device_type)
|
||||||
@ -177,11 +177,6 @@ class ModelConfig:
|
|||||||
graph and always execute the model in eager mode. If False, we will use
|
graph and always execute the model in eager mode. If False, we will use
|
||||||
CUDA graph and eager execution in hybrid for maximal performance and
|
CUDA graph and eager execution in hybrid for maximal performance and
|
||||||
flexibility."""
|
flexibility."""
|
||||||
max_seq_len_to_capture: int = 8192
|
|
||||||
"""Maximum sequence len covered by CUDA graphs. When a sequence has context
|
|
||||||
length larger than this, we fall back to eager mode. Additionally for
|
|
||||||
encoder-decoder models, if the sequence length of the encoder input is
|
|
||||||
larger than this, we fall back to the eager mode."""
|
|
||||||
max_logprobs: int = 20
|
max_logprobs: int = 20
|
||||||
"""Maximum number of log probabilities to return when `logprobs` is
|
"""Maximum number of log probabilities to return when `logprobs` is
|
||||||
specified in `SamplingParams`. The default value comes the default for the
|
specified in `SamplingParams`. The default value comes the default for the
|
||||||
@ -699,11 +694,12 @@ class ModelConfig:
|
|||||||
model: Model name or path
|
model: Model name or path
|
||||||
tokenizer: Tokenizer name or path
|
tokenizer: Tokenizer name or path
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not (is_runai_obj_uri(model) or is_runai_obj_uri(tokenizer)):
|
if not (is_runai_obj_uri(model) or is_runai_obj_uri(tokenizer)):
|
||||||
return
|
return
|
||||||
|
|
||||||
if is_runai_obj_uri(model):
|
if is_runai_obj_uri(model):
|
||||||
object_storage_model = ObjectStorageModel()
|
object_storage_model = ObjectStorageModel(url=model)
|
||||||
object_storage_model.pull_files(
|
object_storage_model.pull_files(
|
||||||
model, allow_pattern=["*.model", "*.py", "*.json"])
|
model, allow_pattern=["*.model", "*.py", "*.json"])
|
||||||
self.model_weights = model
|
self.model_weights = model
|
||||||
@ -722,7 +718,7 @@ class ModelConfig:
|
|||||||
|
|
||||||
# Only download tokenizer if needed and not already handled
|
# Only download tokenizer if needed and not already handled
|
||||||
if is_runai_obj_uri(tokenizer):
|
if is_runai_obj_uri(tokenizer):
|
||||||
object_storage_tokenizer = ObjectStorageModel()
|
object_storage_tokenizer = ObjectStorageModel(url=tokenizer)
|
||||||
object_storage_tokenizer.pull_files(model,
|
object_storage_tokenizer.pull_files(model,
|
||||||
ignore_pattern=[
|
ignore_pattern=[
|
||||||
"*.pt", "*.safetensors",
|
"*.pt", "*.safetensors",
|
||||||
@ -1023,21 +1019,8 @@ class ModelConfig:
|
|||||||
current_platform.verify_quantization(self.quantization)
|
current_platform.verify_quantization(self.quantization)
|
||||||
|
|
||||||
def _verify_cuda_graph(self) -> None:
|
def _verify_cuda_graph(self) -> None:
|
||||||
# The `max_seq_len_to_capture` was incorrectly
|
|
||||||
# based on the encoder's input length (448)
|
|
||||||
# but not the decoder's larger input length (1500).
|
|
||||||
# This change ensures the CUDA Graph captures the correct,
|
|
||||||
# larger sequence length, allowing it to work as intended.
|
|
||||||
effective_max_seq_len = self.max_model_len
|
|
||||||
if self.is_encoder_decoder:
|
|
||||||
effective_max_seq_len = max(
|
|
||||||
effective_max_seq_len,
|
|
||||||
getattr(self.hf_config, "max_source_positions", 0))
|
|
||||||
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
|
|
||||||
effective_max_seq_len)
|
|
||||||
# CUDAGraph capture not supported for encoder-decoder models on ROCm
|
# CUDAGraph capture not supported for encoder-decoder models on ROCm
|
||||||
unsupported_rocm = self.is_encoder_decoder
|
unsupported_rocm = self.is_encoder_decoder
|
||||||
|
|
||||||
if (unsupported_rocm and not self.enforce_eager
|
if (unsupported_rocm and not self.enforce_eager
|
||||||
and current_platform.is_rocm()):
|
and current_platform.is_rocm()):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
99
vllm/config/observability.py
Normal file
99
vllm/config/observability.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Any, Literal, Optional, cast
|
||||||
|
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
|
from vllm import version
|
||||||
|
from vllm.config.utils import config
|
||||||
|
|
||||||
|
DetailedTraceModules = Literal["model", "worker", "all"]
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass
|
||||||
|
class ObservabilityConfig:
|
||||||
|
"""Configuration for observability - metrics and tracing."""
|
||||||
|
|
||||||
|
show_hidden_metrics_for_version: Optional[str] = None
|
||||||
|
"""Enable deprecated Prometheus metrics that have been hidden since the
|
||||||
|
specified version. For example, if a previously deprecated metric has been
|
||||||
|
hidden since the v0.7.0 release, you use
|
||||||
|
`--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
|
||||||
|
you migrate to new metrics. The metric is likely to be removed completely
|
||||||
|
in an upcoming release."""
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def show_hidden_metrics(self) -> bool:
|
||||||
|
"""Check if the hidden metrics should be shown."""
|
||||||
|
if self.show_hidden_metrics_for_version is None:
|
||||||
|
return False
|
||||||
|
return version._prev_minor_version_was(
|
||||||
|
self.show_hidden_metrics_for_version)
|
||||||
|
|
||||||
|
otlp_traces_endpoint: Optional[str] = None
|
||||||
|
"""Target URL to which OpenTelemetry traces will be sent."""
|
||||||
|
|
||||||
|
collect_detailed_traces: Optional[list[DetailedTraceModules]] = None
|
||||||
|
"""It makes sense to set this only if `--otlp-traces-endpoint` is set. If
|
||||||
|
set, it will collect detailed traces for the specified modules. This
|
||||||
|
involves use of possibly costly and or blocking operations and hence might
|
||||||
|
have a performance impact.
|
||||||
|
|
||||||
|
Note that collecting detailed timing information for each request can be
|
||||||
|
expensive."""
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def collect_model_forward_time(self) -> bool:
|
||||||
|
"""Whether to collect model forward time for the request."""
|
||||||
|
return (self.collect_detailed_traces is not None
|
||||||
|
and ("model" in self.collect_detailed_traces
|
||||||
|
or "all" in self.collect_detailed_traces))
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def collect_model_execute_time(self) -> bool:
|
||||||
|
"""Whether to collect model execute time for the request."""
|
||||||
|
return (self.collect_detailed_traces is not None
|
||||||
|
and ("worker" in self.collect_detailed_traces
|
||||||
|
or "all" in self.collect_detailed_traces))
|
||||||
|
|
||||||
|
def compute_hash(self) -> str:
|
||||||
|
"""
|
||||||
|
WARNING: Whenever a new field is added to this config,
|
||||||
|
ensure that it is included in the factors list if
|
||||||
|
it affects the computation graph.
|
||||||
|
|
||||||
|
Provide a hash that uniquely identifies all the configs
|
||||||
|
that affect the structure of the computation
|
||||||
|
graph from input ids/embeddings to the final hidden states,
|
||||||
|
excluding anything before input ids/embeddings and after
|
||||||
|
the final hidden states.
|
||||||
|
"""
|
||||||
|
# no factors to consider.
|
||||||
|
# this config will not affect the computation graph.
|
||||||
|
factors: list[Any] = []
|
||||||
|
hash_str = hashlib.md5(str(factors).encode(),
|
||||||
|
usedforsecurity=False).hexdigest()
|
||||||
|
return hash_str
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if (self.collect_detailed_traces is not None
|
||||||
|
and len(self.collect_detailed_traces) == 1
|
||||||
|
and "," in self.collect_detailed_traces[0]):
|
||||||
|
self._parse_collect_detailed_traces()
|
||||||
|
|
||||||
|
from vllm.tracing import is_otel_available, otel_import_error_traceback
|
||||||
|
if not is_otel_available() and self.otlp_traces_endpoint is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"OpenTelemetry is not available. Unable to configure "
|
||||||
|
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
|
||||||
|
f"installed. Original error:\n{otel_import_error_traceback}")
|
||||||
|
|
||||||
|
def _parse_collect_detailed_traces(self):
|
||||||
|
assert isinstance(self.collect_detailed_traces, list)
|
||||||
|
self.collect_detailed_traces = cast(
|
||||||
|
list[DetailedTraceModules],
|
||||||
|
self.collect_detailed_traces[0].split(","))
|
||||||
@ -285,8 +285,6 @@ class SpeculativeConfig:
|
|||||||
max_model_len,
|
max_model_len,
|
||||||
quantization=self.quantization,
|
quantization=self.quantization,
|
||||||
enforce_eager=self.target_model_config.enforce_eager,
|
enforce_eager=self.target_model_config.enforce_eager,
|
||||||
max_seq_len_to_capture=self.target_model_config.
|
|
||||||
max_seq_len_to_capture,
|
|
||||||
max_logprobs=self.target_model_config.max_logprobs,
|
max_logprobs=self.target_model_config.max_logprobs,
|
||||||
hf_overrides=SpeculativeConfig.hf_config_override,
|
hf_overrides=SpeculativeConfig.hf_config_override,
|
||||||
)
|
)
|
||||||
|
|||||||
39
vllm/config/speech_to_text.py
Normal file
39
vllm/config/speech_to_text.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
|
from vllm.config.utils import config
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass
|
||||||
|
class SpeechToTextConfig:
|
||||||
|
"""Configuration for speech-to-text models."""
|
||||||
|
|
||||||
|
sample_rate: float = 16_000
|
||||||
|
"""Sample rate (Hz) to resample input audio to. Most speech models expect
|
||||||
|
16kHz audio input. The input audio will be automatically resampled to this
|
||||||
|
rate before processing."""
|
||||||
|
|
||||||
|
max_audio_clip_s: int = 30
|
||||||
|
"""Maximum duration in seconds for a single audio clip without chunking.
|
||||||
|
Audio longer than this will be split into smaller chunks if
|
||||||
|
`allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
|
||||||
|
|
||||||
|
overlap_chunk_second: int = 1
|
||||||
|
"""Overlap duration in seconds between consecutive audio chunks when
|
||||||
|
splitting long audio. This helps maintain context across chunk boundaries
|
||||||
|
and improves transcription quality at split points."""
|
||||||
|
|
||||||
|
min_energy_split_window_size: Optional[int] = 1600
|
||||||
|
"""Window size in samples for finding low-energy (quiet) regions to split
|
||||||
|
audio chunks. The algorithm looks for the quietest moment within this
|
||||||
|
window to minimize cutting through speech. Default 1600 samples ≈ 100ms
|
||||||
|
at 16kHz. If None, no chunking will be done."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allow_audio_chunking(self) -> bool:
|
||||||
|
return self.min_energy_split_window_size is not None
|
||||||
@ -46,7 +46,6 @@ def register_nccl_symmetric_ops(pynccl_comm):
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="all_reduce_symmetric_with_copy",
|
op_name="all_reduce_symmetric_with_copy",
|
||||||
op_func=all_reduce_symmetric_with_copy_impl,
|
op_func=all_reduce_symmetric_with_copy_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=all_reduce_symmetric_with_copy_fake,
|
fake_impl=all_reduce_symmetric_with_copy_fake,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -392,7 +392,8 @@ class MessageQueue:
|
|||||||
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
|
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
("No available shared memory broadcast block found"
|
("No available shared memory broadcast block found"
|
||||||
" in %s second."),
|
" in %s seconds. This typically happens when some"
|
||||||
|
" processes are hanging."),
|
||||||
VLLM_RINGBUFFER_WARNING_INTERVAL,
|
VLLM_RINGBUFFER_WARNING_INTERVAL,
|
||||||
)
|
)
|
||||||
n_warning += 1
|
n_warning += 1
|
||||||
@ -455,7 +456,8 @@ class MessageQueue:
|
|||||||
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
|
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
("No available shared memory broadcast block found"
|
("No available shared memory broadcast block found"
|
||||||
" in %s second."),
|
" in %s seconds. This typically happens when some"
|
||||||
|
" processes are hanging."),
|
||||||
VLLM_RINGBUFFER_WARNING_INTERVAL,
|
VLLM_RINGBUFFER_WARNING_INTERVAL,
|
||||||
)
|
)
|
||||||
n_warning += 1
|
n_warning += 1
|
||||||
|
|||||||
@ -574,7 +574,6 @@ class NixlConnectorWorker:
|
|||||||
self.model_config.dtype,
|
self.model_config.dtype,
|
||||||
self.cache_config.cache_dtype,
|
self.cache_config.cache_dtype,
|
||||||
self.block_size,
|
self.block_size,
|
||||||
self.model_config.is_attention_free,
|
|
||||||
use_mla=self.use_mla)
|
use_mla=self.use_mla)
|
||||||
self.backend_name = backend.get_name()
|
self.backend_name = backend.get_name()
|
||||||
attn_backend = backend_name_to_enum(self.backend_name)
|
attn_backend = backend_name_to_enum(self.backend_name)
|
||||||
|
|||||||
@ -149,29 +149,22 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
|
|||||||
|
|
||||||
|
|
||||||
if supports_custom_op():
|
if supports_custom_op():
|
||||||
from vllm.platforms import current_platform
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="all_reduce",
|
op_name="all_reduce",
|
||||||
op_func=all_reduce,
|
op_func=all_reduce,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=all_reduce_fake,
|
fake_impl=all_reduce_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="reduce_scatter",
|
op_name="reduce_scatter",
|
||||||
op_func=reduce_scatter,
|
op_func=reduce_scatter,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=reduce_scatter_fake,
|
fake_impl=reduce_scatter_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="all_gather",
|
op_name="all_gather",
|
||||||
op_func=all_gather,
|
op_func=all_gather,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=all_gather_fake,
|
fake_impl=all_gather_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -373,7 +373,6 @@ class EngineArgs:
|
|||||||
tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
|
tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
|
||||||
quantization: Optional[QuantizationMethods] = ModelConfig.quantization
|
quantization: Optional[QuantizationMethods] = ModelConfig.quantization
|
||||||
enforce_eager: bool = ModelConfig.enforce_eager
|
enforce_eager: bool = ModelConfig.enforce_eager
|
||||||
max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
|
|
||||||
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
|
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
|
||||||
limit_mm_per_prompt: dict[str, int] = \
|
limit_mm_per_prompt: dict[str, int] = \
|
||||||
get_field(MultiModalConfig, "limit_per_prompt")
|
get_field(MultiModalConfig, "limit_per_prompt")
|
||||||
@ -545,8 +544,6 @@ class EngineArgs:
|
|||||||
**model_kwargs["quantization"])
|
**model_kwargs["quantization"])
|
||||||
model_group.add_argument("--enforce-eager",
|
model_group.add_argument("--enforce-eager",
|
||||||
**model_kwargs["enforce_eager"])
|
**model_kwargs["enforce_eager"])
|
||||||
model_group.add_argument("--max-seq-len-to-capture",
|
|
||||||
**model_kwargs["max_seq_len_to_capture"])
|
|
||||||
model_group.add_argument("--max-logprobs",
|
model_group.add_argument("--max-logprobs",
|
||||||
**model_kwargs["max_logprobs"])
|
**model_kwargs["max_logprobs"])
|
||||||
model_group.add_argument("--logprobs-mode",
|
model_group.add_argument("--logprobs-mode",
|
||||||
@ -1008,7 +1005,6 @@ class EngineArgs:
|
|||||||
max_model_len=self.max_model_len,
|
max_model_len=self.max_model_len,
|
||||||
quantization=self.quantization,
|
quantization=self.quantization,
|
||||||
enforce_eager=self.enforce_eager,
|
enforce_eager=self.enforce_eager,
|
||||||
max_seq_len_to_capture=self.max_seq_len_to_capture,
|
|
||||||
max_logprobs=self.max_logprobs,
|
max_logprobs=self.max_logprobs,
|
||||||
logprobs_mode=self.logprobs_mode,
|
logprobs_mode=self.logprobs_mode,
|
||||||
disable_sliding_window=self.disable_sliding_window,
|
disable_sliding_window=self.disable_sliding_window,
|
||||||
|
|||||||
@ -130,11 +130,6 @@ class LLM:
|
|||||||
enforce_eager: Whether to enforce eager execution. If True, we will
|
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||||
disable CUDA graph and always execute the model in eager mode.
|
disable CUDA graph and always execute the model in eager mode.
|
||||||
If False, we will use CUDA graph and eager execution in hybrid.
|
If False, we will use CUDA graph and eager execution in hybrid.
|
||||||
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
|
|
||||||
When a sequence has context length larger than this, we fall back
|
|
||||||
to eager mode. Additionally for encoder-decoder models, if the
|
|
||||||
sequence length of the encoder input is larger than this, we fall
|
|
||||||
back to the eager mode.
|
|
||||||
disable_custom_all_reduce: See
|
disable_custom_all_reduce: See
|
||||||
[ParallelConfig][vllm.config.ParallelConfig].
|
[ParallelConfig][vllm.config.ParallelConfig].
|
||||||
hf_token: The token to use as HTTP bearer authorization for remote files
|
hf_token: The token to use as HTTP bearer authorization for remote files
|
||||||
@ -184,7 +179,6 @@ class LLM:
|
|||||||
swap_space: float = 4,
|
swap_space: float = 4,
|
||||||
cpu_offload_gb: float = 0,
|
cpu_offload_gb: float = 0,
|
||||||
enforce_eager: bool = False,
|
enforce_eager: bool = False,
|
||||||
max_seq_len_to_capture: int = 8192,
|
|
||||||
disable_custom_all_reduce: bool = False,
|
disable_custom_all_reduce: bool = False,
|
||||||
hf_token: Optional[Union[bool, str]] = None,
|
hf_token: Optional[Union[bool, str]] = None,
|
||||||
hf_overrides: Optional[HfOverrides] = None,
|
hf_overrides: Optional[HfOverrides] = None,
|
||||||
@ -281,7 +275,6 @@ class LLM:
|
|||||||
swap_space=swap_space,
|
swap_space=swap_space,
|
||||||
cpu_offload_gb=cpu_offload_gb,
|
cpu_offload_gb=cpu_offload_gb,
|
||||||
enforce_eager=enforce_eager,
|
enforce_eager=enforce_eager,
|
||||||
max_seq_len_to_capture=max_seq_len_to_capture,
|
|
||||||
disable_custom_all_reduce=disable_custom_all_reduce,
|
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||||
hf_token=hf_token,
|
hf_token=hf_token,
|
||||||
hf_overrides=hf_overrides,
|
hf_overrides=hf_overrides,
|
||||||
|
|||||||
@ -1186,6 +1186,10 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
logprobs = None
|
logprobs = None
|
||||||
|
|
||||||
if self.use_harmony:
|
if self.use_harmony:
|
||||||
|
reasoning_content, content, _ = parse_chat_output(token_ids)
|
||||||
|
if not request.include_reasoning:
|
||||||
|
reasoning_content = None
|
||||||
|
|
||||||
if self.tool_parser is not None:
|
if self.tool_parser is not None:
|
||||||
tool_parser = self.tool_parser(tokenizer)
|
tool_parser = self.tool_parser(tokenizer)
|
||||||
# NOTE: We use token_ids for openai tool parser
|
# NOTE: We use token_ids for openai tool parser
|
||||||
@ -1194,10 +1198,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
request=request,
|
request=request,
|
||||||
token_ids=token_ids, # type: ignore
|
token_ids=token_ids, # type: ignore
|
||||||
)
|
)
|
||||||
reasoning_content, content = None, tool_call_info.content
|
content = tool_call_info.content
|
||||||
if request.include_reasoning:
|
|
||||||
reasoning_content, content, _ = parse_chat_output(
|
|
||||||
token_ids)
|
|
||||||
message = ChatMessage(
|
message = ChatMessage(
|
||||||
role=role,
|
role=role,
|
||||||
reasoning_content=reasoning_content,
|
reasoning_content=reasoning_content,
|
||||||
@ -1205,10 +1206,6 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
tool_calls=tool_call_info.tool_calls,
|
tool_calls=tool_call_info.tool_calls,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
reasoning_content, content, _ = parse_chat_output(
|
|
||||||
token_ids)
|
|
||||||
if not request.include_reasoning:
|
|
||||||
reasoning_content = None
|
|
||||||
message = ChatMessage(
|
message = ChatMessage(
|
||||||
role=role,
|
role=role,
|
||||||
reasoning_content=reasoning_content,
|
reasoning_content=reasoning_content,
|
||||||
|
|||||||
@ -235,8 +235,6 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
# Handle the previous response ID.
|
# Handle the previous response ID.
|
||||||
prev_response_id = request.previous_response_id
|
prev_response_id = request.previous_response_id
|
||||||
if prev_response_id is not None:
|
if prev_response_id is not None:
|
||||||
if not prev_response_id.startswith("resp_"):
|
|
||||||
return self._make_invalid_id_error(prev_response_id)
|
|
||||||
async with self.response_store_lock:
|
async with self.response_store_lock:
|
||||||
prev_response = self.response_store.get(prev_response_id)
|
prev_response = self.response_store.get(prev_response_id)
|
||||||
if prev_response is None:
|
if prev_response is None:
|
||||||
@ -924,9 +922,6 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
stream: Optional[bool],
|
stream: Optional[bool],
|
||||||
) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[
|
) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[
|
||||||
StreamingResponsesResponse, None]]:
|
StreamingResponsesResponse, None]]:
|
||||||
if not response_id.startswith("resp_"):
|
|
||||||
return self._make_invalid_id_error(response_id)
|
|
||||||
|
|
||||||
async with self.response_store_lock:
|
async with self.response_store_lock:
|
||||||
response = self.response_store.get(response_id)
|
response = self.response_store.get(response_id)
|
||||||
|
|
||||||
@ -944,9 +939,6 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
self,
|
self,
|
||||||
response_id: str,
|
response_id: str,
|
||||||
) -> Union[ErrorResponse, ResponsesResponse]:
|
) -> Union[ErrorResponse, ResponsesResponse]:
|
||||||
if not response_id.startswith("resp_"):
|
|
||||||
return self._make_invalid_id_error(response_id)
|
|
||||||
|
|
||||||
async with self.response_store_lock:
|
async with self.response_store_lock:
|
||||||
response = self.response_store.get(response_id)
|
response = self.response_store.get(response_id)
|
||||||
if response is None:
|
if response is None:
|
||||||
@ -972,13 +964,6 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
response_id)
|
response_id)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _make_invalid_id_error(self, response_id: str) -> ErrorResponse:
|
|
||||||
return self.create_error_response(
|
|
||||||
err_type="invalid_request_error",
|
|
||||||
message=(f"Invalid 'response_id': '{response_id}'. "
|
|
||||||
"Expected an ID that begins with 'resp'."),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _make_not_found_error(self, response_id: str) -> ErrorResponse:
|
def _make_not_found_error(self, response_id: str) -> ErrorResponse:
|
||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
err_type="invalid_request_error",
|
err_type="invalid_request_error",
|
||||||
|
|||||||
@ -39,7 +39,7 @@ class DeepSeekV31ToolParser(ToolParser):
|
|||||||
self.tool_call_end_token: str = "<|tool▁call▁end|>"
|
self.tool_call_end_token: str = "<|tool▁call▁end|>"
|
||||||
|
|
||||||
self.tool_call_regex = re.compile(
|
self.tool_call_regex = re.compile(
|
||||||
r"<|tool▁call▁begin|>(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)<|tool▁call▁end|>"
|
r"<|tool▁call▁begin|>(?P<function_name>.*?)<|tool▁sep|>(?P<function_arguments>.*?)<|tool▁call▁end|>"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.stream_tool_call_portion_regex = re.compile(
|
self.stream_tool_call_portion_regex = re.compile(
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
@ -12,10 +13,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
|||||||
FunctionCall, ToolCall)
|
FunctionCall, ToolCall)
|
||||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||||
ToolParser, ToolParserManager)
|
ToolParser, ToolParserManager)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ToolParserManager.register_module("openai")
|
@ToolParserManager.register_module("openai")
|
||||||
class OpenAIToolParser(ToolParser):
|
class OpenAIToolParser(ToolParser):
|
||||||
@ -40,17 +44,33 @@ class OpenAIToolParser(ToolParser):
|
|||||||
|
|
||||||
if len(parser.messages) > 0:
|
if len(parser.messages) > 0:
|
||||||
for msg in parser.messages:
|
for msg in parser.messages:
|
||||||
|
if len(msg.content) < 1:
|
||||||
|
continue
|
||||||
|
msg_text = msg.content[0].text
|
||||||
if msg.recipient and msg.recipient.startswith("functions."):
|
if msg.recipient and msg.recipient.startswith("functions."):
|
||||||
|
# If no content-type is given assume JSON, as that's the
|
||||||
|
# most common case with gpt-oss models.
|
||||||
|
if not msg.content_type or "json" in msg.content_type:
|
||||||
|
# load and dump the JSON text to check validity and
|
||||||
|
# remove any extra newlines or other odd formatting
|
||||||
|
try:
|
||||||
|
tool_args = json.dumps(json.loads(msg_text))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.exception(
|
||||||
|
"Error decoding JSON tool call from response.")
|
||||||
|
tool_args = msg_text
|
||||||
|
else:
|
||||||
|
tool_args = msg_text
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
ToolCall(
|
ToolCall(
|
||||||
type="function",
|
type="function",
|
||||||
function=FunctionCall(
|
function=FunctionCall(
|
||||||
name=msg.recipient.split("functions.")[1],
|
name=msg.recipient.split("functions.")[1],
|
||||||
arguments=msg.content[0].text,
|
arguments=tool_args,
|
||||||
),
|
),
|
||||||
))
|
))
|
||||||
elif msg.channel == "final":
|
elif msg.channel == "final":
|
||||||
final_content = msg.content[0].text
|
final_content = msg_text
|
||||||
|
|
||||||
return ExtractedToolCallInformation(
|
return ExtractedToolCallInformation(
|
||||||
tools_called=len(tool_calls) > 0,
|
tools_called=len(tool_calls) > 0,
|
||||||
|
|||||||
24
vllm/envs.py
24
vllm/envs.py
@ -119,7 +119,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_SERVER_DEV_MODE: bool = False
|
VLLM_SERVER_DEV_MODE: bool = False
|
||||||
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
|
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
|
||||||
VLLM_MLA_DISABLE: bool = False
|
VLLM_MLA_DISABLE: bool = False
|
||||||
VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16
|
VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 32
|
||||||
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
|
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
|
||||||
VLLM_RAY_BUNDLE_INDICES: str = ""
|
VLLM_RAY_BUNDLE_INDICES: str = ""
|
||||||
VLLM_CUDART_SO_PATH: Optional[str] = None
|
VLLM_CUDART_SO_PATH: Optional[str] = None
|
||||||
@ -187,12 +187,15 @@ if TYPE_CHECKING:
|
|||||||
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
|
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
|
||||||
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
||||||
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
||||||
|
VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False
|
||||||
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
|
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
|
||||||
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
|
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
|
||||||
VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024
|
VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024
|
||||||
VLLM_DBO_COMM_SMS: int = 20
|
VLLM_DBO_COMM_SMS: int = 20
|
||||||
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
|
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
|
||||||
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
|
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
|
||||||
|
VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE: bool = True
|
||||||
|
VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True
|
||||||
VLLM_USE_NCCL_SYMM_MEM: bool = False
|
VLLM_USE_NCCL_SYMM_MEM: bool = False
|
||||||
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
|
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
|
||||||
|
|
||||||
@ -1014,7 +1017,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# max number splits for cuda graph decode
|
# max number splits for cuda graph decode
|
||||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
|
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
|
||||||
lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
|
lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
|
||||||
"16")),
|
"32")),
|
||||||
|
|
||||||
# Number of GPUs per worker in Ray, if it is set to be a fraction,
|
# Number of GPUs per worker in Ray, if it is set to be a fraction,
|
||||||
# it allows ray to schedule multiple actors on a single GPU,
|
# it allows ray to schedule multiple actors on a single GPU,
|
||||||
@ -1385,6 +1388,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_CUSTOM_SCOPES_FOR_PROFILING":
|
"VLLM_CUSTOM_SCOPES_FOR_PROFILING":
|
||||||
lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
|
lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
|
||||||
|
|
||||||
|
# Add optional nvtx scopes for profiling, disable to avoid overheads
|
||||||
|
"VLLM_NVTX_SCOPES_FOR_PROFILING":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_NVTX_SCOPES_FOR_PROFILING", "0"))),
|
||||||
|
|
||||||
# Represent block hashes in KV cache events as 64-bit integers instead of
|
# Represent block hashes in KV cache events as 64-bit integers instead of
|
||||||
# raw bytes. Defaults to True for backward compatibility.
|
# raw bytes. Defaults to True for backward compatibility.
|
||||||
"VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES":
|
"VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES":
|
||||||
@ -1413,6 +1420,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"code_interpreter",
|
"code_interpreter",
|
||||||
"web_search_preview"]),
|
"web_search_preview"]),
|
||||||
|
|
||||||
|
# Enable max_autotune & coordinate_descent_tuning in inductor_config
|
||||||
|
# to compile static shapes passed from compile_sizes in compilation_config
|
||||||
|
# If set to 1, enable max_autotune; By default, this is enabled (1)
|
||||||
|
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", "1"))),
|
||||||
|
# If set to 1, enable coordinate_descent_tuning;
|
||||||
|
# By default, this is enabled (1)
|
||||||
|
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
|
||||||
|
"1"))),
|
||||||
|
|
||||||
# Flag to enable NCCL symmetric memory allocation and registration
|
# Flag to enable NCCL symmetric memory allocation and registration
|
||||||
"VLLM_USE_NCCL_SYMM_MEM":
|
"VLLM_USE_NCCL_SYMM_MEM":
|
||||||
lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))),
|
lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))),
|
||||||
@ -1513,6 +1531,8 @@ def compute_hash() -> str:
|
|||||||
"VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
|
"VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
|
||||||
"VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
|
"VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
|
||||||
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
|
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
|
||||||
|
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
|
||||||
|
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
|
||||||
]
|
]
|
||||||
for key in environment_variables_to_hash:
|
for key in environment_variables_to_hash:
|
||||||
# if this goes out of sync with environment_variables,
|
# if this goes out of sync with environment_variables,
|
||||||
|
|||||||
@ -7,7 +7,6 @@ from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
|
|||||||
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
|
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
|
||||||
build_explicit_enc_dec_prompt, embeds_inputs,
|
build_explicit_enc_dec_prompt, embeds_inputs,
|
||||||
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
|
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
|
||||||
from .registry import InputContext, InputProcessingContext
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DataPrompt",
|
"DataPrompt",
|
||||||
@ -28,6 +27,4 @@ __all__ = [
|
|||||||
"build_explicit_enc_dec_prompt",
|
"build_explicit_enc_dec_prompt",
|
||||||
"to_enc_dec_tuple_list",
|
"to_enc_dec_tuple_list",
|
||||||
"zip_enc_dec_prompts",
|
"zip_enc_dec_prompts",
|
||||||
"InputContext",
|
|
||||||
"InputProcessingContext",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,186 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TYPE_CHECKING, Any, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
|
|
||||||
from typing_extensions import TypeVar
|
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
|
||||||
from vllm.utils import get_allowed_kwarg_only_overrides
|
|
||||||
from vllm.utils.jsontree import JSONTree, json_map_leaves
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.config import ModelConfig
|
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
|
||||||
else:
|
|
||||||
ModelConfig = Any
|
|
||||||
AnyTokenizer = Any
|
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
|
||||||
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
|
|
||||||
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class InputContext:
|
|
||||||
"""
|
|
||||||
Contains information about the model which may be used to
|
|
||||||
modify the inputs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config: ModelConfig
|
|
||||||
"""The configuration of the model."""
|
|
||||||
|
|
||||||
def get_hf_config(
|
|
||||||
self,
|
|
||||||
typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig,
|
|
||||||
/,
|
|
||||||
) -> _C:
|
|
||||||
"""
|
|
||||||
Get the HuggingFace configuration
|
|
||||||
(`transformers.PretrainedConfig`) of the model,
|
|
||||||
additionally checking its type.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If the configuration is not of the specified type.
|
|
||||||
"""
|
|
||||||
hf_config = self.model_config.hf_config
|
|
||||||
if not isinstance(hf_config, typ):
|
|
||||||
raise TypeError("Invalid type of HuggingFace config. "
|
|
||||||
f"Expected type: {typ}, but "
|
|
||||||
f"found type: {type(hf_config)}")
|
|
||||||
|
|
||||||
return hf_config
|
|
||||||
|
|
||||||
def get_hf_image_processor_config(self) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Get the HuggingFace image processor configuration of the model.
|
|
||||||
"""
|
|
||||||
return self.model_config.hf_image_processor_config
|
|
||||||
|
|
||||||
def get_mm_config(self):
|
|
||||||
"""
|
|
||||||
Get the multimodal config of the model.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If the model is not a multimodal model.
|
|
||||||
"""
|
|
||||||
mm_config = self.model_config.multimodal_config
|
|
||||||
if mm_config is None:
|
|
||||||
raise RuntimeError("Not a multimodal model")
|
|
||||||
|
|
||||||
return mm_config
|
|
||||||
|
|
||||||
def get_hf_processor(
|
|
||||||
self,
|
|
||||||
typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
|
|
||||||
/,
|
|
||||||
**kwargs: object,
|
|
||||||
) -> _P:
|
|
||||||
"""
|
|
||||||
Get the HuggingFace processor
|
|
||||||
(`transformers.ProcessorMixin`) of the model,
|
|
||||||
additionally checking its type.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If the processor is not of the specified type.
|
|
||||||
"""
|
|
||||||
return cached_processor_from_config(
|
|
||||||
self.model_config,
|
|
||||||
processor_cls=typ,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_processor(
|
|
||||||
self,
|
|
||||||
typ: type[_T],
|
|
||||||
/,
|
|
||||||
**kwargs: object,
|
|
||||||
) -> _T:
|
|
||||||
"""
|
|
||||||
Initialize a HuggingFace-like processor class, merging the
|
|
||||||
keyword arguments with those in the model's configuration.
|
|
||||||
"""
|
|
||||||
mm_config = self.model_config.get_multimodal_config()
|
|
||||||
base_kwargs = mm_config.mm_processor_kwargs
|
|
||||||
if base_kwargs is None:
|
|
||||||
base_kwargs = {}
|
|
||||||
|
|
||||||
merged_kwargs = {**base_kwargs, **kwargs}
|
|
||||||
|
|
||||||
return typ(**merged_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class InputProcessingContext(InputContext):
|
|
||||||
tokenizer: AnyTokenizer
|
|
||||||
"""The tokenizer used to tokenize the inputs."""
|
|
||||||
|
|
||||||
def get_hf_processor(
|
|
||||||
self,
|
|
||||||
typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
|
|
||||||
/,
|
|
||||||
**kwargs: object,
|
|
||||||
) -> _P:
|
|
||||||
return super().get_hf_processor(
|
|
||||||
typ,
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def call_hf_processor(
|
|
||||||
self,
|
|
||||||
hf_processor: ProcessorMixin,
|
|
||||||
data: Mapping[str, object],
|
|
||||||
kwargs: Mapping[str, object] = {},
|
|
||||||
) -> Union[BatchFeature, JSONTree]:
|
|
||||||
"""
|
|
||||||
Call `hf_processor` on the prompt `data`
|
|
||||||
(text, image, audio...) with configurable options `kwargs`.
|
|
||||||
"""
|
|
||||||
assert callable(hf_processor)
|
|
||||||
|
|
||||||
mm_config = self.model_config.get_multimodal_config()
|
|
||||||
merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)
|
|
||||||
|
|
||||||
allowed_kwargs = get_allowed_kwarg_only_overrides(
|
|
||||||
hf_processor,
|
|
||||||
merged_kwargs,
|
|
||||||
requires_kw_only=False,
|
|
||||||
allow_var_kwargs=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def maybe_cast_dtype(x):
|
|
||||||
# This mimics the behavior of transformers.BatchFeature
|
|
||||||
if isinstance(x, torch.Tensor) and x.is_floating_point():
|
|
||||||
return x.to(dtype=self.model_config.dtype)
|
|
||||||
return x
|
|
||||||
|
|
||||||
try:
|
|
||||||
output = hf_processor(**data,
|
|
||||||
**allowed_kwargs,
|
|
||||||
return_tensors="pt")
|
|
||||||
# this emulates output.to(dtype=self.model_config.dtype)
|
|
||||||
if isinstance(output, BatchFeature):
|
|
||||||
cast_output = json_map_leaves(maybe_cast_dtype, output.data)
|
|
||||||
return BatchFeature(cast_output)
|
|
||||||
|
|
||||||
cast_output = json_map_leaves(maybe_cast_dtype, output)
|
|
||||||
|
|
||||||
logger.warning_once(
|
|
||||||
f"{type(hf_processor).__name__} did not return `BatchFeature`. "
|
|
||||||
"Make sure to match the behaviour of `ProcessorMixin` when "
|
|
||||||
"implementing custom processors.")
|
|
||||||
return cast_output
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
msg = (f"Failed to apply {type(hf_processor).__name__} "
|
|
||||||
f"on data={data} with kwargs={allowed_kwargs}")
|
|
||||||
|
|
||||||
raise ValueError(msg) from exc
|
|
||||||
@ -11,7 +11,6 @@ import torch
|
|||||||
|
|
||||||
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
|
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
|
||||||
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
|
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
@ -283,7 +282,6 @@ try:
|
|||||||
op_func=_lora_expand,
|
op_func=_lora_expand,
|
||||||
mutates_args=["output_tensor"],
|
mutates_args=["output_tensor"],
|
||||||
fake_impl=_lora_expand_fake,
|
fake_impl=_lora_expand_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
lora_expand = torch.ops.vllm.lora_expand
|
lora_expand = torch.ops.vllm.lora_expand
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,6 @@ import torch
|
|||||||
|
|
||||||
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
|
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
|
||||||
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
|
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
@ -237,7 +236,6 @@ try:
|
|||||||
op_func=_lora_shrink,
|
op_func=_lora_shrink,
|
||||||
mutates_args=["output_tensor"],
|
mutates_args=["output_tensor"],
|
||||||
fake_impl=_lora_shrink_fake,
|
fake_impl=_lora_shrink_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
lora_shrink = torch.ops.vllm.lora_shrink
|
lora_shrink = torch.ops.vllm.lora_shrink
|
||||||
|
|
||||||
|
|||||||
@ -40,8 +40,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
|||||||
ssm_state_indices,
|
ssm_state_indices,
|
||||||
num_accepted_tokens,
|
num_accepted_tokens,
|
||||||
scale,
|
scale,
|
||||||
N: tl.constexpr, # num of sequences
|
N: tl.int64, # num of sequences
|
||||||
T: tl.constexpr, # num of tokens
|
T: tl.int64, # num of tokens
|
||||||
B: tl.constexpr,
|
B: tl.constexpr,
|
||||||
H: tl.constexpr,
|
H: tl.constexpr,
|
||||||
HV: tl.constexpr,
|
HV: tl.constexpr,
|
||||||
|
|||||||
@ -0,0 +1,123 @@
|
|||||||
|
{
|
||||||
|
"triton_version": "3.4.0",
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 64,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 32,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 32,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"8192": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"16384": {
|
||||||
|
"BLOCK_SIZE_M": 128,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -98,13 +98,16 @@ def select_experts(
|
|||||||
e_score_correction_bias=e_score_correction_bias)
|
e_score_correction_bias=e_score_correction_bias)
|
||||||
elif custom_routing_function is None:
|
elif custom_routing_function is None:
|
||||||
assert scoring_func == "softmax"
|
assert scoring_func == "softmax"
|
||||||
topk_weights = torch.nn.functional.softmax(router_logits,
|
topk_logit_vals, topk_idx = torch.topk(router_logits,
|
||||||
dim=1,
|
k=top_k,
|
||||||
dtype=torch.float32)
|
dim=-1,
|
||||||
topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
|
sorted=False)
|
||||||
if renormalize:
|
if renormalize:
|
||||||
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
|
topk_vals = torch.softmax(topk_logit_vals, dim=-1)
|
||||||
return topk_weights, topk_ids.to(torch.int32)
|
else:
|
||||||
|
logZ = torch.logsumexp(router_logits, dim=-1, keepdim=True)
|
||||||
|
topk_vals = (topk_logit_vals - logZ).exp()
|
||||||
|
return topk_vals.to(torch.float32), topk_idx.to(torch.int32)
|
||||||
else:
|
else:
|
||||||
return custom_routing_function(hidden_states=hidden_states,
|
return custom_routing_function(hidden_states=hidden_states,
|
||||||
gating_output=router_logits,
|
gating_output=router_logits,
|
||||||
|
|||||||
@ -92,7 +92,6 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="flashinfer_fused_moe_blockscale_fp8",
|
op_name="flashinfer_fused_moe_blockscale_fp8",
|
||||||
op_func=flashinfer_fused_moe_blockscale_fp8,
|
op_func=flashinfer_fused_moe_blockscale_fp8,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
|
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
|
||||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -235,6 +235,5 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="fused_marlin_moe",
|
op_name="fused_marlin_moe",
|
||||||
op_func=fused_marlin_moe,
|
op_func=fused_marlin_moe,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=fused_marlin_moe_fake,
|
fake_impl=fused_marlin_moe_fake,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1256,7 +1256,6 @@ def outplace_fused_experts_fake(
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="outplace_fused_experts",
|
op_name="outplace_fused_experts",
|
||||||
op_func=outplace_fused_experts,
|
op_func=outplace_fused_experts,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=outplace_fused_experts_fake,
|
fake_impl=outplace_fused_experts_fake,
|
||||||
tags=(() if is_torch_equal_or_newer("2.7.0") else
|
tags=(() if is_torch_equal_or_newer("2.7.0") else
|
||||||
(torch.Tag.needs_fixed_stride_order, )),
|
(torch.Tag.needs_fixed_stride_order, )),
|
||||||
|
|||||||
@ -23,7 +23,7 @@ if has_triton_kernels():
|
|||||||
from triton_kernels.routing import (RoutingData, routing,
|
from triton_kernels.routing import (RoutingData, routing,
|
||||||
routing_from_bitmatrix)
|
routing_from_bitmatrix)
|
||||||
from triton_kernels.tensor import Bitmatrix
|
from triton_kernels.tensor import Bitmatrix
|
||||||
except (ModuleNotFoundError, AttributeError) as e:
|
except (AttributeError, ImportError) as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Failed to import Triton kernels. Please make sure your triton "
|
"Failed to import Triton kernels. Please make sure your triton "
|
||||||
"version is compatible. Error: %s", e)
|
"version is compatible. Error: %s", e)
|
||||||
|
|||||||
@ -69,8 +69,6 @@ else:
|
|||||||
if is_rocm_aiter_moe_enabled():
|
if is_rocm_aiter_moe_enabled():
|
||||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||||
rocm_aiter_grouped_topk as grouped_topk)
|
rocm_aiter_grouped_topk as grouped_topk)
|
||||||
elif current_platform.is_cpu():
|
|
||||||
pass
|
|
||||||
else:
|
else:
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
|
||||||
if current_platform.is_tpu():
|
if current_platform.is_tpu():
|
||||||
@ -2040,7 +2038,6 @@ direct_register_custom_op(
|
|||||||
op_func=moe_forward,
|
op_func=moe_forward,
|
||||||
mutates_args=["hidden_states"],
|
mutates_args=["hidden_states"],
|
||||||
fake_impl=moe_forward_fake,
|
fake_impl=moe_forward_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2071,7 +2068,6 @@ direct_register_custom_op(
|
|||||||
op_func=moe_forward_shared,
|
op_func=moe_forward_shared,
|
||||||
mutates_args=["hidden_states"],
|
mutates_args=["hidden_states"],
|
||||||
fake_impl=moe_forward_shared_fake,
|
fake_impl=moe_forward_shared_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -223,17 +223,13 @@ if current_platform.is_rocm():
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_asm_moe_tkw1",
|
op_name="rocm_aiter_asm_moe_tkw1",
|
||||||
op_func=rocm_aiter_asm_moe_tkw1_impl,
|
op_func=rocm_aiter_asm_moe_tkw1_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=rocm_aiter_asm_moe_tkw1_fake,
|
fake_impl=rocm_aiter_asm_moe_tkw1_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_fused_moe",
|
op_name="rocm_aiter_fused_moe",
|
||||||
op_func=rocm_aiter_fused_moe_impl,
|
op_func=rocm_aiter_fused_moe_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=rocm_aiter_fused_moe_fake,
|
fake_impl=rocm_aiter_fused_moe_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
@ -241,7 +237,6 @@ if current_platform.is_rocm():
|
|||||||
op_func=rocm_aiter_topk_softmax_impl,
|
op_func=rocm_aiter_topk_softmax_impl,
|
||||||
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
|
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
|
||||||
fake_impl=rocm_aiter_topk_softmax_fake,
|
fake_impl=rocm_aiter_topk_softmax_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
@ -249,7 +244,6 @@ if current_platform.is_rocm():
|
|||||||
op_func=rocm_aiter_biased_grouped_topk_impl,
|
op_func=rocm_aiter_biased_grouped_topk_impl,
|
||||||
mutates_args=["topk_weights", "topk_ids"],
|
mutates_args=["topk_weights", "topk_ids"],
|
||||||
fake_impl=rocm_aiter_biased_grouped_topk_fake,
|
fake_impl=rocm_aiter_biased_grouped_topk_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
@ -257,7 +251,6 @@ if current_platform.is_rocm():
|
|||||||
op_func=rocm_aiter_grouped_topk_impl,
|
op_func=rocm_aiter_grouped_topk_impl,
|
||||||
mutates_args=["topk_weights", "topk_ids"],
|
mutates_args=["topk_weights", "topk_ids"],
|
||||||
fake_impl=rocm_aiter_grouped_topk_fake,
|
fake_impl=rocm_aiter_grouped_topk_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -103,17 +103,13 @@ if current_platform.is_rocm():
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_rms_norm",
|
op_name="rocm_aiter_rms_norm",
|
||||||
op_func=rocm_aiter_rms_norm_impl,
|
op_func=rocm_aiter_rms_norm_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=rocm_aiter_rms_norm_fake,
|
fake_impl=rocm_aiter_rms_norm_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
|
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
|
||||||
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
|
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
|
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -22,6 +22,7 @@ from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
|||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||||
BlockQuantScaleParameter,
|
BlockQuantScaleParameter,
|
||||||
|
ModelWeightParameter,
|
||||||
PackedColumnParameter,
|
PackedColumnParameter,
|
||||||
PackedvLLMParameter,
|
PackedvLLMParameter,
|
||||||
PerTensorScaleParameter,
|
PerTensorScaleParameter,
|
||||||
@ -34,6 +35,7 @@ from vllm.utils import GiB_bytes
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
WEIGHT_LOADER_V2_SUPPORTED = [
|
WEIGHT_LOADER_V2_SUPPORTED = [
|
||||||
|
"UnquantizedLinearMethod",
|
||||||
"CompressedTensorsLinearMethod",
|
"CompressedTensorsLinearMethod",
|
||||||
"CompressedTensorsLinearTransformMethod",
|
"CompressedTensorsLinearTransformMethod",
|
||||||
"BitBLASLinearMethod",
|
"BitBLASLinearMethod",
|
||||||
@ -196,10 +198,14 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
|||||||
# The amount of memory allocated for the weights is
|
# The amount of memory allocated for the weights is
|
||||||
# sum(output_partition_sizes) * input_size_per_partition.
|
# sum(output_partition_sizes) * input_size_per_partition.
|
||||||
try:
|
try:
|
||||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
weight_loader = extra_weight_attrs.pop("weight_loader")
|
||||||
input_size_per_partition,
|
weight = ModelWeightParameter(data=torch.empty(
|
||||||
dtype=params_dtype),
|
sum(output_partition_sizes),
|
||||||
requires_grad=False)
|
input_size_per_partition,
|
||||||
|
dtype=params_dtype),
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=0,
|
||||||
|
weight_loader=weight_loader)
|
||||||
except torch.cuda.OutOfMemoryError as e:
|
except torch.cuda.OutOfMemoryError as e:
|
||||||
logger.error("Failed to create unquantized linear weights: %s", e)
|
logger.error("Failed to create unquantized linear weights: %s", e)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@ -212,7 +218,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
|||||||
"Failed to create unquantized linear weights. "
|
"Failed to create unquantized linear weights. "
|
||||||
"This may be caused by insufficient memory to allocate "
|
"This may be caused by insufficient memory to allocate "
|
||||||
"the weight.") from e
|
"the weight.") from e
|
||||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
|
||||||
layer.register_parameter("weight", weight)
|
layer.register_parameter("weight", weight)
|
||||||
set_weight_attrs(weight, extra_weight_attrs)
|
set_weight_attrs(weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
|||||||
@ -31,7 +31,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
|
|||||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
|
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
|
||||||
|
|
||||||
@ -401,5 +400,4 @@ direct_register_custom_op(
|
|||||||
op_func=linear_attention,
|
op_func=linear_attention,
|
||||||
mutates_args=["output"],
|
mutates_args=["output"],
|
||||||
fake_impl=linear_attention_fake,
|
fake_impl=linear_attention_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -27,7 +27,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
|||||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||||
selective_scan_fn, selective_state_update)
|
selective_scan_fn, selective_state_update)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
|
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
|
||||||
|
|
||||||
@ -464,5 +463,4 @@ direct_register_custom_op(
|
|||||||
op_func=mamba_mixer,
|
op_func=mamba_mixer,
|
||||||
mutates_args=["output"],
|
mutates_args=["output"],
|
||||||
fake_impl=mamba_mixer_fake,
|
fake_impl=mamba_mixer_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -34,7 +34,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
|||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
LoaderFunction, composed_weight_loader, sharded_weight_loader)
|
LoaderFunction, composed_weight_loader, sharded_weight_loader)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
|
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
|
||||||
|
|
||||||
@ -765,5 +764,4 @@ direct_register_custom_op(
|
|||||||
op_func=mamba_mixer2,
|
op_func=mamba_mixer2,
|
||||||
mutates_args=["output"],
|
mutates_args=["output"],
|
||||||
fake_impl=mamba_mixer2_fake,
|
fake_impl=mamba_mixer2_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -17,7 +17,6 @@ from .mamba_ssm import softplus
|
|||||||
|
|
||||||
@triton.autotune(
|
@triton.autotune(
|
||||||
configs=[
|
configs=[
|
||||||
triton.Config({'BLOCK_SIZE_H': 1}),
|
|
||||||
triton.Config({'BLOCK_SIZE_H': 2}),
|
triton.Config({'BLOCK_SIZE_H': 2}),
|
||||||
triton.Config({'BLOCK_SIZE_H': 4}),
|
triton.Config({'BLOCK_SIZE_H': 4}),
|
||||||
triton.Config({'BLOCK_SIZE_H': 8}),
|
triton.Config({'BLOCK_SIZE_H': 8}),
|
||||||
|
|||||||
@ -21,7 +21,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
|
|||||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||||
causal_conv1d_fn, causal_conv1d_update)
|
causal_conv1d_fn, causal_conv1d_update)
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
from vllm.v1.attention.backends.short_conv_attn import (
|
from vllm.v1.attention.backends.short_conv_attn import (
|
||||||
ShortConvAttentionMetadata)
|
ShortConvAttentionMetadata)
|
||||||
@ -251,5 +250,4 @@ direct_register_custom_op(
|
|||||||
op_func=short_conv,
|
op_func=short_conv,
|
||||||
mutates_args=["output"],
|
mutates_args=["output"],
|
||||||
fake_impl=short_conv_fake,
|
fake_impl=short_conv_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
|||||||
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
|
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
|
||||||
int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config,
|
int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config,
|
||||||
int8_w8a16_moe_quant_config, nvfp4_moe_quant_config)
|
int8_w8a16_moe_quant_config, nvfp4_moe_quant_config)
|
||||||
|
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
|
||||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||||
is_valid_flashinfer_cutlass_fused_moe)
|
is_valid_flashinfer_cutlass_fused_moe)
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
|
||||||
@ -47,7 +48,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
|
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import CpuArchEnum, current_platform
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||||
|
|
||||||
@ -63,7 +64,7 @@ __all__ = [
|
|||||||
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
|
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
|
||||||
"CompressedTensorsW8A8Int8MoEMethod",
|
"CompressedTensorsW8A8Int8MoEMethod",
|
||||||
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
|
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
|
||||||
"CompressedTensorsW4A4MoeMethod"
|
"CompressedTensorsW4A4MoeMethod", "CompressedTensorsW4A8Int8MoEMethod"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -139,6 +140,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
|||||||
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
|
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config,
|
return CompressedTensorsW8A8Int8MoEMethod(quant_config,
|
||||||
layer.moe_config)
|
layer.moe_config)
|
||||||
|
elif quant_config._is_dynamic_token_w4a8_int(weight_quant,
|
||||||
|
input_quant):
|
||||||
|
return CompressedTensorsW4A8Int8MoEMethod(quant_config,
|
||||||
|
layer.moe_config)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
|
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
|
||||||
@ -1769,3 +1774,301 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
quant_config=self.moe_quant_config,
|
quant_config=self.moe_quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||||
|
"""
|
||||||
|
CPU-only MoE method using dynamic 4-bit matmul kernels on Arm Platform
|
||||||
|
- Weights: int4 (stored as int8 values in [-8,7], packed to uint8 nibbles)
|
||||||
|
- Scales: Fp32 for Channelwise , bf16 for groupwise quantization
|
||||||
|
- Bias: Same data type as original weights
|
||||||
|
- Activations: FP32/Bf16 dynamic per-token (A8 Int),
|
||||||
|
quantized inside the kernel
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||||
|
moe: FusedMoEConfig):
|
||||||
|
super().__init__(moe)
|
||||||
|
self.has_bias = self.moe.has_bias
|
||||||
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
# Validate scheme: weights=W4 (channel or group),
|
||||||
|
# activations=dynamic TOKEN (A8)
|
||||||
|
wq = self.quant_config.target_scheme_map["Linear"].get("weights")
|
||||||
|
aq = self.quant_config.target_scheme_map["Linear"].get(
|
||||||
|
"input_activations")
|
||||||
|
|
||||||
|
# Must be dynamic per-token activations
|
||||||
|
if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic:
|
||||||
|
raise ValueError(
|
||||||
|
"W4A8-int MoE needs dynamic per-token activation quantization."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Weight can be channel-wise (group_size=None) or group-wise
|
||||||
|
self.group_size = wq.group_size if (wq.group_size is not None) else -1
|
||||||
|
if wq.num_bits != 4:
|
||||||
|
raise ValueError(
|
||||||
|
"This method only supports 4-bit weights (num_bits=4).")
|
||||||
|
|
||||||
|
# CPU only
|
||||||
|
if not current_platform.is_cpu():
|
||||||
|
raise ValueError("CompressedTensorsW4A8Int8MoEMethod is CPU-only.")
|
||||||
|
|
||||||
|
# Arm: check _dyn ops availability
|
||||||
|
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
|
||||||
|
try:
|
||||||
|
_ = torch.ops.aten._dyn_quant_matmul_4bit
|
||||||
|
_ = torch.ops.aten._dyn_quant_pack_4bit_weight
|
||||||
|
except AttributeError as err:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"""PyTorch {torch.__version__} lacks _dyn_quant_* 4bit ops;
|
||||||
|
install a newer build.""") from err
|
||||||
|
self.static_input_scales = False # always dynamic per token
|
||||||
|
|
||||||
|
# ---- parameter creation ----
|
||||||
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||||
|
hidden_size: int, intermediate_size_per_partition: int,
|
||||||
|
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||||
|
|
||||||
|
# Shapes per local rank (TP/EP):
|
||||||
|
# w13: [E, 2*I_local, H] int8 (int4 values in [-8,7])
|
||||||
|
# w2 : [E, H, I_local] int8
|
||||||
|
# Scales:
|
||||||
|
# channel-wise: group_size=-1 -> per-output-row, single scale per row
|
||||||
|
# group-wise : group_size=g ->
|
||||||
|
# per-output-row, (in_features/g) scales
|
||||||
|
|
||||||
|
E = num_experts
|
||||||
|
H = hidden_size
|
||||||
|
IN = intermediate_size_per_partition
|
||||||
|
g = self.group_size
|
||||||
|
|
||||||
|
# Per-row scale columns
|
||||||
|
def _n_scale_cols(in_features: int) -> int:
|
||||||
|
return 1 if g == -1 else (in_features // g)
|
||||||
|
|
||||||
|
# Register unpacked int4-as-int8 weights the loader will fill.
|
||||||
|
w13 = torch.nn.Parameter(torch.empty(E, 2 * IN, H, dtype=torch.int8),
|
||||||
|
requires_grad=False)
|
||||||
|
set_weight_attrs(w13, extra_weight_attrs)
|
||||||
|
layer.register_parameter("w13_weight", w13)
|
||||||
|
|
||||||
|
w2 = torch.nn.Parameter(torch.empty(E, H, IN, dtype=torch.int8),
|
||||||
|
requires_grad=False)
|
||||||
|
set_weight_attrs(w2, extra_weight_attrs)
|
||||||
|
layer.register_parameter("w2_weight", w2)
|
||||||
|
|
||||||
|
# Register scales
|
||||||
|
# KleidiAI groupwise kernels accepts float32 scales
|
||||||
|
# KleidiAI groupwise kernels accepts bfloat16 scales
|
||||||
|
scale_dtype = torch.float32 if g == -1 else torch.bfloat16
|
||||||
|
|
||||||
|
w13_s = torch.nn.Parameter(torch.ones(E,
|
||||||
|
2 * IN,
|
||||||
|
_n_scale_cols(H),
|
||||||
|
dtype=scale_dtype),
|
||||||
|
requires_grad=False)
|
||||||
|
set_weight_attrs(
|
||||||
|
w13_s, {
|
||||||
|
"quant_method": "channel" if g == -1 else "group",
|
||||||
|
**extra_weight_attrs
|
||||||
|
})
|
||||||
|
layer.register_parameter("w13_weight_scale", w13_s)
|
||||||
|
|
||||||
|
w2_s = torch.nn.Parameter(torch.ones(E,
|
||||||
|
H,
|
||||||
|
_n_scale_cols(IN),
|
||||||
|
dtype=scale_dtype),
|
||||||
|
requires_grad=False)
|
||||||
|
set_weight_attrs(
|
||||||
|
w2_s, {
|
||||||
|
"quant_method": "channel" if g == -1 else "group",
|
||||||
|
**extra_weight_attrs
|
||||||
|
})
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_s)
|
||||||
|
|
||||||
|
if self.has_bias:
|
||||||
|
w13_bias = torch.nn.Parameter(torch.zeros(E,
|
||||||
|
2 * IN,
|
||||||
|
dtype=params_dtype),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w13_bias", w13_bias)
|
||||||
|
set_weight_attrs(w13_bias, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
|
||||||
|
hidden_size,
|
||||||
|
dtype=params_dtype),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w2_bias", w2_bias)
|
||||||
|
set_weight_attrs(w2_bias, extra_weight_attrs)
|
||||||
|
|
||||||
|
# Placeholders for packed weights (will be replaced after packing)
|
||||||
|
layer.register_parameter(
|
||||||
|
"w13_weight_packed",
|
||||||
|
torch.nn.Parameter(torch.empty(0), requires_grad=False))
|
||||||
|
set_weight_attrs(layer.w13_weight_packed, extra_weight_attrs)
|
||||||
|
|
||||||
|
layer.register_parameter(
|
||||||
|
"w2_weight_packed",
|
||||||
|
torch.nn.Parameter(torch.empty(0), requires_grad=False))
|
||||||
|
set_weight_attrs(layer.w2_weight_packed, extra_weight_attrs)
|
||||||
|
|
||||||
|
# dims for 4 bit fused matmuls
|
||||||
|
layer.w13_in_features = H
|
||||||
|
layer.w13_out_features = 2 * IN
|
||||||
|
layer.w2_in_features = IN
|
||||||
|
layer.w2_out_features = H
|
||||||
|
layer.group_size = g
|
||||||
|
|
||||||
|
# post-load packing to dyn-4bit KleidiAI kernel's format
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
E = layer.w13_weight.shape[0]
|
||||||
|
H = layer.w13_in_features
|
||||||
|
I2 = layer.w13_out_features
|
||||||
|
IN = layer.w2_in_features
|
||||||
|
g = layer.group_size
|
||||||
|
|
||||||
|
def _pack_matrix(int4_as_int8_2d: torch.Tensor,
|
||||||
|
scales_2d: torch.Tensor,
|
||||||
|
bias_1d: Optional[torch.Tensor], in_features: int,
|
||||||
|
out_features: int) -> torch.Tensor:
|
||||||
|
# int4 values are stored as int8 in [-8,7].
|
||||||
|
# Shift to unsigned nibble and pack pairs along input-dim.
|
||||||
|
tmp = int4_as_int8_2d.add(8) # [out, in]
|
||||||
|
uint8_nibbles = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to(
|
||||||
|
torch.uint8) # [out, in//2]
|
||||||
|
|
||||||
|
# KleidiAI groupwise kernels accepts float32 scales
|
||||||
|
# KleidiAI groupwise kernels accepts bfloat16 scales
|
||||||
|
scale_dtype = torch.float32 if g == -1 else torch.bfloat16
|
||||||
|
scales = scales_2d.to(scale_dtype)
|
||||||
|
bias = None if bias_1d is None else bias_1d.to(torch.float32)
|
||||||
|
return torch.ops.aten._dyn_quant_pack_4bit_weight(
|
||||||
|
uint8_nibbles, scales, bias, g if g != -1 else in_features,
|
||||||
|
in_features, out_features)
|
||||||
|
|
||||||
|
# Pack per expert
|
||||||
|
w13_packed_list = []
|
||||||
|
w2_packed_list = []
|
||||||
|
|
||||||
|
has_w13_bias = hasattr(layer,
|
||||||
|
"w13_bias") and layer.w13_bias is not None
|
||||||
|
has_w2_bias = hasattr(layer, "w2_bias") and layer.w2_bias is not None
|
||||||
|
|
||||||
|
for e in range(E):
|
||||||
|
w13_packed_list.append(
|
||||||
|
_pack_matrix(
|
||||||
|
layer.w13_weight[e], # [2I, H]
|
||||||
|
layer.w13_weight_scale[e], # [2I, H/g or 1]
|
||||||
|
layer.w13_bias[e] if has_w13_bias else None, # [2I]
|
||||||
|
H,
|
||||||
|
I2))
|
||||||
|
w2_packed_list.append(
|
||||||
|
_pack_matrix(
|
||||||
|
# w2 shape is [H, IN]; we need [out, in] == [H, IN].
|
||||||
|
layer.w2_weight[e], # [H, IN]
|
||||||
|
layer.w2_weight_scale[e], # [H, IN/g or 1]
|
||||||
|
layer.w2_bias[e] if has_w2_bias else None, # [H]
|
||||||
|
IN,
|
||||||
|
layer.w2_out_features # in_features=IN, out_features=H
|
||||||
|
))
|
||||||
|
|
||||||
|
# each packed tensor has identical shape per expert; stack on dim 0
|
||||||
|
w13_packed = torch.stack(w13_packed_list, dim=0)
|
||||||
|
w2_packed = torch.stack(w2_packed_list, dim=0)
|
||||||
|
|
||||||
|
replace_parameter(layer, "w13_weight_packed",
|
||||||
|
torch.nn.Parameter(w13_packed, requires_grad=False))
|
||||||
|
replace_parameter(layer, "w2_weight_packed",
|
||||||
|
torch.nn.Parameter(w2_packed, requires_grad=False))
|
||||||
|
|
||||||
|
# free raw tensors/scales/bias now that they're packed into the payload.
|
||||||
|
replace_parameter(
|
||||||
|
layer, "w13_weight",
|
||||||
|
torch.nn.Parameter(torch.empty(0), requires_grad=False))
|
||||||
|
replace_parameter(
|
||||||
|
layer, "w2_weight",
|
||||||
|
torch.nn.Parameter(torch.empty(0), requires_grad=False))
|
||||||
|
replace_parameter(
|
||||||
|
layer, "w13_weight_scale",
|
||||||
|
torch.nn.Parameter(torch.empty(0), requires_grad=False))
|
||||||
|
replace_parameter(
|
||||||
|
layer, "w2_weight_scale",
|
||||||
|
torch.nn.Parameter(torch.empty(0), requires_grad=False))
|
||||||
|
if has_w13_bias:
|
||||||
|
replace_parameter(
|
||||||
|
layer, "w13_bias",
|
||||||
|
torch.nn.Parameter(torch.empty(0), requires_grad=False))
|
||||||
|
if has_w2_bias:
|
||||||
|
replace_parameter(
|
||||||
|
layer, "w2_bias",
|
||||||
|
torch.nn.Parameter(torch.empty(0), requires_grad=False))
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
# CPU dynamic 4-bit MoE path does not use modular kernels or
|
||||||
|
# fused_experts; quant config is not needed.
|
||||||
|
return None
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
scoring_func: str = "softmax",
|
||||||
|
routed_scaling_factor: float = 1.0,
|
||||||
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
|
enable_eplb: bool = False,
|
||||||
|
expert_load_view: Optional[torch.Tensor] = None,
|
||||||
|
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||||
|
logical_replica_count: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet."
|
||||||
|
assert activation in (
|
||||||
|
"silu", "swigluoai",
|
||||||
|
"swiglu"), "Only SiLU/SwiGLUGU/SwiGLUUG are supported."
|
||||||
|
assert expert_map is None, """expert_map/EP not implemented
|
||||||
|
for CPU dyn-4bit MoE."""
|
||||||
|
|
||||||
|
def _act_kind(s: str) -> int:
|
||||||
|
# 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU
|
||||||
|
if s == "swiglu":
|
||||||
|
return 0
|
||||||
|
if s == "swigluoai":
|
||||||
|
return 1
|
||||||
|
if s == "silu":
|
||||||
|
return 2
|
||||||
|
raise ValueError(f"Unknown activation '{s}'")
|
||||||
|
|
||||||
|
# Apply topk softmax on router output
|
||||||
|
topk_weights, topk_ids = select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
scoring_func=scoring_func,
|
||||||
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
return torch.ops._C.dynamic_4bit_int_moe(
|
||||||
|
x, topk_ids.to(torch.long), topk_weights, layer.w13_weight_packed,
|
||||||
|
layer.w2_weight_packed, layer.w2_out_features,
|
||||||
|
layer.w2_in_features, layer.w13_out_features, layer.group_size,
|
||||||
|
apply_router_weight_on_input, int(_act_kind(activation)))
|
||||||
@ -4,7 +4,6 @@ import logging
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.triton_utils import triton
|
from vllm.triton_utils import triton
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
from vllm.utils.deep_gemm import fp8_gemm_nt
|
from vllm.utils.deep_gemm import fp8_gemm_nt
|
||||||
@ -75,7 +74,5 @@ def w8a8_deepgemm_block_scaled_mm_fake(
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="w8a8_deepgemm_block_scaled_mm",
|
op_name="w8a8_deepgemm_block_scaled_mm",
|
||||||
op_func=w8a8_deepgemm_block_scaled_mm,
|
op_func=w8a8_deepgemm_block_scaled_mm,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
|
fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -161,7 +161,6 @@ try:
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="_fused_mul_mat_gguf",
|
op_name="_fused_mul_mat_gguf",
|
||||||
op_func=_fused_mul_mat_gguf,
|
op_func=_fused_mul_mat_gguf,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=_fused_mul_mat_gguf_fake,
|
fake_impl=_fused_mul_mat_gguf_fake,
|
||||||
)
|
)
|
||||||
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
|
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
|
||||||
@ -273,7 +272,6 @@ try:
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="_fused_moe_gguf",
|
op_name="_fused_moe_gguf",
|
||||||
op_func=_fused_moe_gguf,
|
op_func=_fused_moe_gguf,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=_fused_moe_gguf_fake,
|
fake_impl=_fused_moe_gguf_fake,
|
||||||
)
|
)
|
||||||
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
|
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
|
||||||
@ -319,7 +317,6 @@ try:
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="_apply_gguf_embedding",
|
op_name="_apply_gguf_embedding",
|
||||||
op_func=_apply_gguf_embedding,
|
op_func=_apply_gguf_embedding,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=_apply_gguf_embedding_fake,
|
fake_impl=_apply_gguf_embedding_fake,
|
||||||
)
|
)
|
||||||
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
|
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
|
||||||
|
|||||||
@ -51,9 +51,7 @@ if current_platform.is_rocm():
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_gemm_w8a8",
|
op_name="rocm_aiter_gemm_w8a8",
|
||||||
op_func=rocm_aiter_gemm_w8a8_impl,
|
op_func=rocm_aiter_gemm_w8a8_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=rocm_aiter_gemm_w8a8_fake,
|
fake_impl=rocm_aiter_gemm_w8a8_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -212,12 +212,15 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
intermediate_size_per_partition_after_pad = round_up(
|
intermediate_size_per_partition_after_pad = round_up(
|
||||||
intermediate_size_per_partition, 256)
|
intermediate_size_per_partition, 256)
|
||||||
hidden_size = round_up(hidden_size, 256)
|
hidden_size = round_up(hidden_size, 256)
|
||||||
elif current_platform.is_rocm() or (
|
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
||||||
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
|
||||||
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
|
|
||||||
intermediate_size_per_partition_after_pad = round_up(
|
intermediate_size_per_partition_after_pad = round_up(
|
||||||
intermediate_size_per_partition, 128)
|
intermediate_size_per_partition, 128)
|
||||||
hidden_size = round_up(hidden_size, 128)
|
hidden_size = round_up(hidden_size, 128)
|
||||||
|
elif current_platform.is_rocm():
|
||||||
|
intermediate_size_per_partition_after_pad = round_up(
|
||||||
|
intermediate_size_per_partition, 256)
|
||||||
|
hidden_size = round_up(hidden_size, 256)
|
||||||
else:
|
else:
|
||||||
intermediate_size_per_partition_after_pad = round_up(
|
intermediate_size_per_partition_after_pad = round_up(
|
||||||
intermediate_size_per_partition, 64)
|
intermediate_size_per_partition, 64)
|
||||||
|
|||||||
@ -91,9 +91,7 @@ if current_platform.is_rocm():
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_gemm_w8a8_blockscale",
|
op_name="rocm_aiter_gemm_w8a8_blockscale",
|
||||||
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
|
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
|
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
|
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
|
||||||
and current_platform.is_fp8_fnuz()):
|
and current_platform.is_fp8_fnuz()):
|
||||||
@ -132,13 +130,14 @@ def _w8a8_triton_block_scaled_mm_fake(
|
|||||||
device=qx.device)
|
device=qx.device)
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
# Note: the check can be removed when CPU torch > 2.7
|
||||||
"w8a8_triton_block_scaled_mm_func",
|
if not current_platform.is_cpu():
|
||||||
_w8a8_triton_block_scaled_mm_func,
|
direct_register_custom_op(
|
||||||
mutates_args=[],
|
"w8a8_triton_block_scaled_mm_func",
|
||||||
fake_impl=_w8a8_triton_block_scaled_mm_fake,
|
_w8a8_triton_block_scaled_mm_func,
|
||||||
dispatch_key="CUDA",
|
fake_impl=_w8a8_triton_block_scaled_mm_fake,
|
||||||
)
|
dispatch_key="CUDA",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO fix ROCm->Triton custom path:
|
# TODO fix ROCm->Triton custom path:
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from typing import Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -21,6 +21,10 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
|||||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||||
from triton_kernels.tensor_details import layout
|
from triton_kernels.tensor_details import layout
|
||||||
from triton_kernels.tensor_details.layout import StridedLayout
|
from triton_kernels.tensor_details.layout import StridedLayout
|
||||||
|
|
||||||
|
value_layout_opts: dict[str, Any] = {}
|
||||||
|
scale_layout_opts: dict[str, Any] = {}
|
||||||
|
|
||||||
if (current_platform.is_cuda()
|
if (current_platform.is_cuda()
|
||||||
and current_platform.is_device_capability(90)
|
and current_platform.is_device_capability(90)
|
||||||
and not is_torch_equal_or_newer("2.8.1")):
|
and not is_torch_equal_or_newer("2.8.1")):
|
||||||
@ -28,8 +32,15 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
|||||||
"Mxfp4 on hopper is running on torch < 2.8.1, "
|
"Mxfp4 on hopper is running on torch < 2.8.1, "
|
||||||
"this cause swizling to be disabled, which may "
|
"this cause swizling to be disabled, which may "
|
||||||
"cause performance degradation. Please upgrade to torch nightly")
|
"cause performance degradation. Please upgrade to torch nightly")
|
||||||
value_layout, value_layout_opts = StridedLayout, dict()
|
value_layout = StridedLayout
|
||||||
scale_layout, scale_layout_opts = StridedLayout, dict()
|
scale_layout = StridedLayout
|
||||||
|
elif current_platform.is_rocm():
|
||||||
|
from triton_kernels.tensor_details.layout import (GFX950MXScaleLayout,
|
||||||
|
StridedLayout)
|
||||||
|
|
||||||
|
from vllm.platforms.rocm import on_gfx950
|
||||||
|
value_layout = StridedLayout
|
||||||
|
scale_layout = GFX950MXScaleLayout if on_gfx950() else StridedLayout
|
||||||
else:
|
else:
|
||||||
value_layout, value_layout_opts = \
|
value_layout, value_layout_opts = \
|
||||||
layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||||
@ -113,7 +124,6 @@ try:
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="dequant_mxfp4",
|
op_name="dequant_mxfp4",
|
||||||
op_func=_dequant_mxfp4,
|
op_func=_dequant_mxfp4,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=_dequant_mxfp4_fake,
|
fake_impl=_dequant_mxfp4_fake,
|
||||||
)
|
)
|
||||||
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
|
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
|
||||||
@ -124,7 +134,6 @@ try:
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="quant_dequant_mxfp4",
|
op_name="quant_dequant_mxfp4",
|
||||||
op_func=_quant_dequant_mxfp4,
|
op_func=_quant_dequant_mxfp4,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=_quant_dequant_mxfp4_fake,
|
fake_impl=_quant_dequant_mxfp4_fake,
|
||||||
)
|
)
|
||||||
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
|
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
|
||||||
|
|||||||
@ -218,9 +218,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
|
op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
|
||||||
op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
|
op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
|
fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -147,5 +147,4 @@ direct_register_custom_op(
|
|||||||
op_func=_flashinfer_rotary_embedding,
|
op_func=_flashinfer_rotary_embedding,
|
||||||
mutates_args=["query", "key"], # These tensors are modified in-place
|
mutates_args=["query", "key"], # These tensors are modified in-place
|
||||||
fake_impl=_flashinfer_rotary_embedding_fake,
|
fake_impl=_flashinfer_rotary_embedding_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -136,9 +136,7 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_unquantized_gemm_impl",
|
op_name="rocm_unquantized_gemm_impl",
|
||||||
op_func=rocm_unquantized_gemm_impl,
|
op_func=rocm_unquantized_gemm_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=rocm_unquantized_gemm_impl_fake,
|
fake_impl=rocm_unquantized_gemm_impl_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
# Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
|
# Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from typing import Annotated, Literal, Optional, Union, cast
|
from typing import Annotated, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -347,12 +347,16 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
loader = AutoWeightsLoader(self)
|
loader = AutoWeightsLoader(self)
|
||||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||||
|
|
||||||
def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
|
def _image_pixels_to_features(
|
||||||
pixel_values: torch.Tensor,
|
self,
|
||||||
**kwargs) -> torch.Tensor:
|
vision_tower: SiglipVisionModel,
|
||||||
target_dtype = vision_tower.get_input_embeddings().weight.dtype
|
pixel_values: torch.Tensor,
|
||||||
image_features = vision_tower(pixel_values.to(dtype=target_dtype),
|
**kwargs,
|
||||||
**kwargs)
|
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||||
|
target_dtype: torch.dtype = \
|
||||||
|
vision_tower.get_input_embeddings().weight.dtype
|
||||||
|
image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
|
||||||
|
vision_tower(pixel_values.to(dtype=target_dtype), **kwargs)
|
||||||
|
|
||||||
def select_features(leaf: torch.Tensor):
|
def select_features(leaf: torch.Tensor):
|
||||||
return self._select_image_features(
|
return self._select_image_features(
|
||||||
@ -360,10 +364,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
strategy=self.config.vision_feature_select_strategy,
|
strategy=self.config.vision_feature_select_strategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
return cast(
|
return json_map_leaves(select_features, image_features)
|
||||||
Union[torch.Tensor, tuple[torch.Tensor, ...]],
|
|
||||||
json_map_leaves(select_features, image_features),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _select_image_features(self, image_features: torch.Tensor, *,
|
def _select_image_features(self, image_features: torch.Tensor, *,
|
||||||
strategy: str) -> torch.Tensor:
|
strategy: str) -> torch.Tensor:
|
||||||
|
|||||||
@ -245,19 +245,6 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
|
||||||
config = vllm_config.model_config
|
|
||||||
config.max_seq_len_to_capture = config.max_model_len
|
|
||||||
logger.info(
|
|
||||||
"Setting max_seq_len_to_capture to %d "
|
|
||||||
"to ensure that CUDA graph capture "
|
|
||||||
"covers sequences of length up to max_model_len.",
|
|
||||||
config.max_model_len)
|
|
||||||
|
|
||||||
|
|
||||||
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
|
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -426,7 +413,6 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
|||||||
"XLMRobertaModel": JinaRobertaModelConfig,
|
"XLMRobertaModel": JinaRobertaModelConfig,
|
||||||
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,
|
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,
|
||||||
"JambaForSequenceClassification": JambaForSequenceClassificationConfig,
|
"JambaForSequenceClassification": JambaForSequenceClassificationConfig,
|
||||||
"GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
|
|
||||||
"GptOssForCausalLM": GptOssForCausalLMConfig,
|
"GptOssForCausalLM": GptOssForCausalLMConfig,
|
||||||
"MambaForCausalLM": MambaModelConfig,
|
"MambaForCausalLM": MambaModelConfig,
|
||||||
"Mamba2ForCausalLM": MambaModelConfig,
|
"Mamba2ForCausalLM": MambaModelConfig,
|
||||||
|
|||||||
@ -56,7 +56,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
ParallelLMHead, VocabParallelEmbedding)
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
default_weight_loader, maybe_remap_kv_scale_name)
|
default_weight_loader, maybe_remap_kv_scale_name)
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import cdiv, direct_register_custom_op
|
from vllm.utils import cdiv, direct_register_custom_op
|
||||||
|
|
||||||
@ -141,9 +140,7 @@ def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="sequence_parallel_chunk",
|
op_name="sequence_parallel_chunk",
|
||||||
op_func=sequence_parallel_chunk,
|
op_func=sequence_parallel_chunk,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=sequence_parallel_chunk_fake,
|
fake_impl=sequence_parallel_chunk_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -26,6 +26,7 @@ from vllm.attention import Attention
|
|||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
|
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
|
||||||
GeluAndMul,
|
GeluAndMul,
|
||||||
@ -44,6 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
default_weight_loader, maybe_remap_kv_scale_name)
|
default_weight_loader, maybe_remap_kv_scale_name)
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
|
||||||
|
|
||||||
from .interfaces import SupportsQuant
|
from .interfaces import SupportsQuant
|
||||||
from .utils import (AutoWeightsLoader, extract_layer_index,
|
from .utils import (AutoWeightsLoader, extract_layer_index,
|
||||||
@ -51,6 +53,8 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
EPS = torch.tensor(torch.finfo().min)
|
||||||
|
|
||||||
|
|
||||||
class Gemma3nAltUp(nn.Module):
|
class Gemma3nAltUp(nn.Module):
|
||||||
"""Alternating updates (Altup)
|
"""Alternating updates (Altup)
|
||||||
@ -532,16 +536,29 @@ class Gemma3nDecoderLayer(nn.Module):
|
|||||||
return corrected_predictions
|
return corrected_predictions
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
# This enables torch.compile if --kv-sharing-fast-prefill passed
|
||||||
class Gemma3nTextModel(nn.Module, SupportsQuant):
|
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
||||||
|
kv_sharing_fast_prefill)
|
||||||
|
class Gemma3nSelfDecoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Includes altup embedding and self decoder layers
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
decoder_layers: list[Gemma3nDecoderLayer],
|
||||||
|
layer_idx_start: int,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.layer_idx_start = layer_idx_start
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
cache_config = vllm_config.cache_config
|
|
||||||
quant_config = vllm_config.quant_config
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
@ -594,32 +611,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
|||||||
prefix=f"{prefix}.altup_projections.{idx-1}",
|
prefix=f"{prefix}.altup_projections.{idx-1}",
|
||||||
) for idx in range(1, self.config.altup_num_inputs)
|
) for idx in range(1, self.config.altup_num_inputs)
|
||||||
])
|
])
|
||||||
self.altup_unembed_projections = nn.ModuleList([
|
|
||||||
ColumnParallelLinear(
|
|
||||||
config.hidden_size,
|
|
||||||
config.hidden_size,
|
|
||||||
bias=False,
|
|
||||||
gather_output=True,
|
|
||||||
return_bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
|
|
||||||
) for idx in range(1, self.config.altup_num_inputs)
|
|
||||||
])
|
|
||||||
|
|
||||||
# Transformer blocks.
|
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
|
||||||
config.num_hidden_layers,
|
|
||||||
lambda prefix: Gemma3nDecoderLayer(
|
|
||||||
config, cache_config, quant_config, prefix=prefix),
|
|
||||||
prefix=f"{prefix}.layers")
|
|
||||||
self.norm = RMSNorm(
|
|
||||||
config.hidden_size,
|
|
||||||
eps=config.rms_norm_eps,
|
|
||||||
)
|
|
||||||
self.eps = torch.tensor(torch.finfo().min)
|
|
||||||
|
|
||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.embed_tokens(input_ids) * self.embed_scale
|
|
||||||
|
|
||||||
def get_per_layer_input_embeddings(
|
def get_per_layer_input_embeddings(
|
||||||
self, input_ids: torch.Tensor) -> torch.Tensor:
|
self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
@ -633,20 +624,11 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
|||||||
return self.embed_tokens_per_layer(
|
return self.embed_tokens_per_layer(
|
||||||
per_layer_inputs_tokens) * self.embed_scale_per_layer
|
per_layer_inputs_tokens) * self.embed_scale_per_layer
|
||||||
|
|
||||||
def forward(
|
def get_per_layer_inputs(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.Tensor],
|
hidden_states_0: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
per_layer_inputs: Optional[torch.Tensor],
|
||||||
per_layer_inputs: torch.Tensor,
|
) -> torch.Tensor:
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
|
||||||
if inputs_embeds is not None:
|
|
||||||
hidden_states_0 = inputs_embeds
|
|
||||||
else:
|
|
||||||
hidden_states_0 = self.get_input_embeddings(input_ids)
|
|
||||||
|
|
||||||
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
|
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
|
||||||
per_layer_projection = per_layer_projection.reshape(
|
per_layer_projection = per_layer_projection.reshape(
|
||||||
*hidden_states_0.shape[:-1],
|
*hidden_states_0.shape[:-1],
|
||||||
@ -655,14 +637,18 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
|||||||
)
|
)
|
||||||
per_layer_projection = self.per_layer_projection_norm(
|
per_layer_projection = self.per_layer_projection_norm(
|
||||||
per_layer_projection)
|
per_layer_projection)
|
||||||
|
|
||||||
if per_layer_inputs is not None:
|
if per_layer_inputs is not None:
|
||||||
# Profiling run does not compute per_layer_inputs
|
# Profiling run does not compute per_layer_inputs
|
||||||
per_layer_inputs = per_layer_projection + per_layer_inputs
|
per_layer_inputs = per_layer_projection + per_layer_inputs
|
||||||
per_layer_inputs *= self.per_layer_input_scale
|
per_layer_inputs *= self.per_layer_input_scale
|
||||||
else:
|
else:
|
||||||
per_layer_inputs = per_layer_projection
|
per_layer_inputs = per_layer_projection
|
||||||
|
return per_layer_inputs
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.embed_tokens(input_ids) * self.embed_scale
|
||||||
|
|
||||||
|
def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
|
||||||
# Altup embed.
|
# Altup embed.
|
||||||
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
|
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
|
||||||
target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
|
target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
|
||||||
@ -673,11 +659,77 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
|||||||
dim=-1,
|
dim=-1,
|
||||||
keepdim=True)**0.5
|
keepdim=True)**0.5
|
||||||
hidden_states[i] *= target_magnitude / torch.maximum(
|
hidden_states[i] *= target_magnitude / torch.maximum(
|
||||||
new_magnitude, self.eps)
|
new_magnitude, EPS)
|
||||||
hidden_states = torch.stack(hidden_states, dim=0)
|
hidden_states = torch.stack(hidden_states, dim=-1)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
# Transformer blocks.
|
def forward(
|
||||||
for layer_idx, layer in enumerate(self.layers):
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
hidden_states_0 = inputs_embeds
|
||||||
|
else:
|
||||||
|
hidden_states_0 = self.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
|
adjusted_per_layer_inputs = self.get_per_layer_inputs(
|
||||||
|
hidden_states_0, per_layer_inputs)
|
||||||
|
hidden_states = self.altup_embed(hidden_states_0)
|
||||||
|
|
||||||
|
# [altnum_inputs, num_tokens, hidden_size]
|
||||||
|
hidden_states = hidden_states.permute(2, 0, 1)
|
||||||
|
|
||||||
|
for idx, layer in enumerate(self.decoder_layers):
|
||||||
|
layer_idx = idx + self.layer_idx_start
|
||||||
|
# [altup_num_inputs, num_tokens, hidden_size]
|
||||||
|
hidden_states = layer(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :],
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# [num_tokens, hidden_size, altnum_inputs]
|
||||||
|
hidden_states = hidden_states.permute(1, 2, 0)
|
||||||
|
|
||||||
|
return hidden_states, adjusted_per_layer_inputs
|
||||||
|
|
||||||
|
|
||||||
|
# This enables torch.compile if --kv-sharing-fast-prefill passed
|
||||||
|
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
||||||
|
kv_sharing_fast_prefill)
|
||||||
|
class Gemma3nCrossDecoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Cross-decoder layers
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
decoder_layers: list[Gemma3nDecoderLayer],
|
||||||
|
layer_idx_start: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.layer_idx_start = layer_idx_start
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
per_layer_inputs: torch.Tensor,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# [altnum_inputs, num_tokens, hidden_size]
|
||||||
|
hidden_states = hidden_states.permute(2, 0, 1)
|
||||||
|
for idx, layer in enumerate(self.decoder_layers):
|
||||||
|
layer_idx = idx + self.layer_idx_start
|
||||||
# [altup_num_inputs, num_tokens, hidden_size]
|
# [altup_num_inputs, num_tokens, hidden_size]
|
||||||
hidden_states = layer(
|
hidden_states = layer(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
@ -685,22 +737,249 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
|||||||
per_layer_input=per_layer_inputs[:, layer_idx, :],
|
per_layer_input=per_layer_inputs[:, layer_idx, :],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
# [num_tokens, hidden_size, altnum_inputs]
|
||||||
|
hidden_states = hidden_states.permute(1, 2, 0)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# This disables torch.compile if --kv-sharing-fast-prefill passed
|
||||||
|
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
|
||||||
|
cache_config.kv_sharing_fast_prefill)
|
||||||
|
class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
self.config = config
|
||||||
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
self.altup_unembed_projections = nn.ModuleList([
|
||||||
|
ColumnParallelLinear(
|
||||||
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=True,
|
||||||
|
return_bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
|
||||||
|
) for idx in range(1, self.config.altup_num_inputs)
|
||||||
|
])
|
||||||
|
|
||||||
|
# Allocate config.num_kv_shared_layers layers for self-decoder
|
||||||
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
|
config.num_hidden_layers,
|
||||||
|
lambda prefix: Gemma3nDecoderLayer(
|
||||||
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
|
prefix=f"{prefix}.layers")
|
||||||
|
|
||||||
|
first_kv_shared_layer_idx = (config.num_hidden_layers -
|
||||||
|
config.num_kv_shared_layers)
|
||||||
|
|
||||||
|
# NOTE(sarckk): importing this top level seems to cause issues
|
||||||
|
# during running of tests.
|
||||||
|
from vllm.compilation.backends import set_model_tag
|
||||||
|
|
||||||
|
# Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
|
||||||
|
with set_model_tag("self_decoder"):
|
||||||
|
self.self_decoder = Gemma3nSelfDecoder(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix=f"{prefix}.self_decoder",
|
||||||
|
decoder_layers=self.layers[:first_kv_shared_layer_idx],
|
||||||
|
layer_idx_start=0,
|
||||||
|
)
|
||||||
|
# Layer idx 20-30 are cross-decoder layers in YOCO
|
||||||
|
with set_model_tag("cross_decoder"):
|
||||||
|
self.cross_decoder = Gemma3nCrossDecoder(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix=f"{prefix}.cross_decoder",
|
||||||
|
decoder_layers=self.layers[first_kv_shared_layer_idx:],
|
||||||
|
layer_idx_start=first_kv_shared_layer_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm = RMSNorm(
|
||||||
|
config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
|
||||||
|
|
||||||
|
if self.fast_prefill_enabled:
|
||||||
|
# Allocate static buffers for CUDAGraph
|
||||||
|
# TODO(sarckk): Extract this functionality to interface
|
||||||
|
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||||
|
device = next(self.parameters()).device
|
||||||
|
self.positions = torch.zeros(max_num_tokens,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device)
|
||||||
|
self.hidden_states = torch.zeros(
|
||||||
|
(max_num_tokens, config.hidden_size,
|
||||||
|
self.config.altup_num_inputs),
|
||||||
|
dtype=self.embed_tokens.weight.dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.per_layer_inputs = torch.zeros(
|
||||||
|
(max_num_tokens, self.config.num_hidden_layers,
|
||||||
|
self.config.hidden_size_per_layer_input),
|
||||||
|
dtype=self.embed_tokens.weight.dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embed_tokens(self):
|
||||||
|
return self.self_decoder.embed_tokens
|
||||||
|
|
||||||
|
def get_per_layer_input_embeddings(
|
||||||
|
self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.self_decoder.get_per_layer_input_embeddings(input_ids)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.self_decoder.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
|
def fast_prefill_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
logits_indices_padded, num_logits_indices = None, None
|
||||||
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
|
|
||||||
|
# attn_metadata is None during dummy runs
|
||||||
|
if (self.fast_prefill_enabled and attn_metadata is not None):
|
||||||
|
assert isinstance(attn_metadata, dict)
|
||||||
|
# Last layer is a KV sharing layer
|
||||||
|
layer_attn_metadata = attn_metadata[
|
||||||
|
self.layers[-1].self_attn.attn.layer_name]
|
||||||
|
if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)):
|
||||||
|
logits_indices_padded = (
|
||||||
|
layer_attn_metadata.logits_indices_padded)
|
||||||
|
num_logits_indices = layer_attn_metadata.num_logits_indices
|
||||||
|
|
||||||
|
# Copy inputs for cudagraph
|
||||||
|
batch_size = positions.size(0)
|
||||||
|
self.positions[:batch_size].copy_(positions)
|
||||||
|
self_decoder_hidden_states, per_layer_inputs_adjusted = \
|
||||||
|
self.self_decoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=self.positions[:batch_size],
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
per_layer_inputs=per_layer_inputs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if logits_indices_padded is None:
|
||||||
|
logits_indices_padded = torch.arange(
|
||||||
|
positions.size(0),
|
||||||
|
dtype=positions.dtype,
|
||||||
|
device=positions.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE(sarckk): There is currently a bug caused by
|
||||||
|
# vLLM converting output of last piecewise CUDA graph
|
||||||
|
# to weakref, causing memory to be prematurely freed
|
||||||
|
# when there are multiple compilation units
|
||||||
|
# Keep .clone() until fix in
|
||||||
|
# https://github.com/vllm-project/vllm/pull/22282
|
||||||
|
hidden_states = self_decoder_hidden_states.clone()
|
||||||
|
|
||||||
|
# Copy inputs for cudagraph
|
||||||
|
num_padded_logits_indices = logits_indices_padded.size(0)
|
||||||
|
self.positions[:num_padded_logits_indices].copy_(
|
||||||
|
positions[logits_indices_padded])
|
||||||
|
self.hidden_states[:num_padded_logits_indices].copy_(
|
||||||
|
self_decoder_hidden_states[logits_indices_padded])
|
||||||
|
self.per_layer_inputs[:num_padded_logits_indices].copy_(
|
||||||
|
per_layer_inputs_adjusted[logits_indices_padded])
|
||||||
|
cross_decoder_hidden_states = self.cross_decoder(
|
||||||
|
positions=self.positions[:num_padded_logits_indices],
|
||||||
|
hidden_states=self.hidden_states[:num_padded_logits_indices],
|
||||||
|
per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices],
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_logits_indices is not None:
|
||||||
|
assert num_logits_indices > 0
|
||||||
|
# Merge cross-decoder and self-decoder hidden states
|
||||||
|
hidden_states[logits_indices_padded[:num_logits_indices]] = (
|
||||||
|
cross_decoder_hidden_states[:num_logits_indices])
|
||||||
|
else:
|
||||||
|
hidden_states = cross_decoder_hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def normal_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states, per_layer_inputs = self.self_decoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
per_layer_inputs=per_layer_inputs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = self.cross_decoder(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
per_layer_inputs=per_layer_inputs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def altup_unembed(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
# Altup unembed.
|
# Altup unembed.
|
||||||
target_magnitude = torch.mean(hidden_states[0]**2,
|
target_magnitude = torch.mean(hidden_states[..., 0]**2,
|
||||||
dim=-1,
|
dim=-1,
|
||||||
keepdim=True)**0.5
|
keepdim=True)**0.5
|
||||||
for i in range(1, self.config.altup_num_inputs):
|
for i in range(1, self.config.altup_num_inputs):
|
||||||
hidden_states[i] = self.altup_unembed_projections[i - 1](
|
hidden_states[..., i] = self.altup_unembed_projections[i - 1](
|
||||||
hidden_states[i])
|
hidden_states[..., i])
|
||||||
new_magnitude = torch.mean(hidden_states[i]**2,
|
new_magnitude = torch.mean(hidden_states[..., i]**2,
|
||||||
dim=-1,
|
dim=-1,
|
||||||
keepdim=True)**0.5
|
keepdim=True)**0.5
|
||||||
hidden_states[i] *= target_magnitude / torch.maximum(
|
hidden_states[..., i] *= target_magnitude / torch.maximum(
|
||||||
new_magnitude, self.eps)
|
new_magnitude, EPS)
|
||||||
# [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
|
# [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size]
|
||||||
hidden_states = torch.mean(hidden_states, dim=0)
|
hidden_states = torch.mean(hidden_states, dim=-1)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor],
|
||||||
|
positions: torch.Tensor,
|
||||||
|
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
if self.fast_prefill_enabled:
|
||||||
|
hidden_states = self.fast_prefill_forward(
|
||||||
|
input_ids,
|
||||||
|
positions,
|
||||||
|
inputs_embeds,
|
||||||
|
per_layer_inputs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = self.normal_forward(
|
||||||
|
input_ids,
|
||||||
|
positions,
|
||||||
|
inputs_embeds,
|
||||||
|
per_layer_inputs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = self.altup_unembed(hidden_states)
|
||||||
return self.norm(hidden_states)
|
return self.norm(hidden_states)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str,
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
@ -716,6 +995,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
|||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
loaded_params: set[str] = set()
|
loaded_params: set[str] = set()
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
|
# decoder layer weights, altup_unembed_projections and rmsnorm
|
||||||
|
# are initialized in text model, others are in self decoder
|
||||||
|
if (not name.startswith('layers')
|
||||||
|
and not name.startswith('altup_unembed_projections')
|
||||||
|
and not name.startswith('norm')):
|
||||||
|
name = f"self_decoder.{name}"
|
||||||
|
|
||||||
if (self.quant_config is not None and
|
if (self.quant_config is not None and
|
||||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||||
# Loading kv cache scales for compressed-tensors quantization
|
# Loading kv cache scales for compressed-tensors quantization
|
||||||
|
|||||||
@ -308,13 +308,11 @@ class GraniteModel(nn.Module):
|
|||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
else:
|
else:
|
||||||
hidden_states = self.get_input_embeddings(input_ids)
|
hidden_states = self.get_input_embeddings(input_ids)
|
||||||
residual = None
|
|
||||||
|
|
||||||
hidden_states *= self.config.embedding_multiplier
|
hidden_states *= self.config.embedding_multiplier
|
||||||
else:
|
else:
|
||||||
assert intermediate_tensors is not None
|
assert intermediate_tensors is not None
|
||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
|
||||||
|
|
||||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||||
hidden_states = layer(positions, hidden_states)
|
hidden_states = layer(positions, hidden_states)
|
||||||
@ -322,7 +320,6 @@ class GraniteModel(nn.Module):
|
|||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
return IntermediateTensors({
|
return IntermediateTensors({
|
||||||
"hidden_states": hidden_states,
|
"hidden_states": hidden_states,
|
||||||
"residual": residual
|
|
||||||
})
|
})
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
@ -475,10 +472,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
torch.zeros((batch_size, self.config.hidden_size),
|
torch.zeros((batch_size, self.config.hidden_size),
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device),
|
device=device),
|
||||||
"residual":
|
|
||||||
torch.zeros((batch_size, self.config.hidden_size),
|
|
||||||
dtype=dtype,
|
|
||||||
device=device),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str,
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
|||||||
@ -298,17 +298,14 @@ class GraniteMoeModel(nn.Module):
|
|||||||
else:
|
else:
|
||||||
hidden_states = self.get_input_embeddings(input_ids)
|
hidden_states = self.get_input_embeddings(input_ids)
|
||||||
hidden_states *= self.embedding_multiplier
|
hidden_states *= self.embedding_multiplier
|
||||||
residual = None
|
|
||||||
else:
|
else:
|
||||||
assert intermediate_tensors is not None
|
assert intermediate_tensors is not None
|
||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
|
||||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||||
hidden_states = layer(positions, hidden_states)
|
hidden_states = layer(positions, hidden_states)
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
return IntermediateTensors({
|
return IntermediateTensors({
|
||||||
"hidden_states": hidden_states,
|
"hidden_states": hidden_states,
|
||||||
"residual": residual
|
|
||||||
})
|
})
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@ -523,10 +520,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
torch.zeros((batch_size, self.config.hidden_size),
|
torch.zeros((batch_size, self.config.hidden_size),
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device),
|
device=device),
|
||||||
"residual":
|
|
||||||
torch.zeros((batch_size, self.config.hidden_size),
|
|
||||||
dtype=dtype,
|
|
||||||
device=device),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str,
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
|||||||
@ -195,17 +195,14 @@ class GraniteMoeSharedModel(nn.Module):
|
|||||||
else:
|
else:
|
||||||
hidden_states = self.get_input_embeddings(input_ids)
|
hidden_states = self.get_input_embeddings(input_ids)
|
||||||
hidden_states *= self.embedding_multiplier
|
hidden_states *= self.embedding_multiplier
|
||||||
residual = None
|
|
||||||
else:
|
else:
|
||||||
assert intermediate_tensors is not None
|
assert intermediate_tensors is not None
|
||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
|
||||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||||
hidden_states = layer(positions, hidden_states)
|
hidden_states = layer(positions, hidden_states)
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
return IntermediateTensors({
|
return IntermediateTensors({
|
||||||
"hidden_states": hidden_states,
|
"hidden_states": hidden_states,
|
||||||
"residual": residual
|
|
||||||
})
|
})
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@ -323,10 +320,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
torch.zeros((batch_size, self.config.hidden_size),
|
torch.zeros((batch_size, self.config.hidden_size),
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device),
|
device=device),
|
||||||
"residual":
|
|
||||||
torch.zeros((batch_size, self.config.hidden_size),
|
|
||||||
dtype=dtype,
|
|
||||||
device=device),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str,
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
|||||||
@ -29,7 +29,6 @@ from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
|
|||||||
from transformers.modeling_utils import no_init_weights
|
from transformers.modeling_utils import no_init_weights
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.inputs import InputProcessingContext
|
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
||||||
@ -37,8 +36,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
|||||||
MultiModalKwargsItems)
|
MultiModalKwargsItems)
|
||||||
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
|
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
|
||||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||||
BaseProcessingInfo, PromptReplacement,
|
BaseProcessingInfo,
|
||||||
PromptUpdate)
|
InputProcessingContext,
|
||||||
|
PromptReplacement, PromptUpdate)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
|
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
|
||||||
Union, cast)
|
Union)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -15,7 +15,6 @@ from transformers.models.llava import LlavaProcessor
|
|||||||
from transformers.models.pixtral import PixtralProcessor
|
from transformers.models.pixtral import PixtralProcessor
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.inputs import InputProcessingContext
|
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
@ -28,8 +27,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
|||||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||||
ImageSize, MultiModalDataItems)
|
ImageSize, MultiModalDataItems)
|
||||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||||
BaseProcessingInfo, PromptReplacement,
|
BaseProcessingInfo,
|
||||||
PromptUpdate, PromptUpdateDetails)
|
InputProcessingContext,
|
||||||
|
PromptReplacement, PromptUpdate,
|
||||||
|
PromptUpdateDetails)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils.jsontree import json_map_leaves
|
from vllm.utils.jsontree import json_map_leaves
|
||||||
@ -622,7 +623,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||||
# NOTE: we skip the step to select the vision feature layer since
|
# NOTE: we skip the step to select the vision feature layer since
|
||||||
# this is already done inside the vision tower
|
# this is already done inside the vision tower
|
||||||
image_features = vision_tower(pixel_values)
|
image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
|
||||||
|
vision_tower(pixel_values)
|
||||||
|
|
||||||
def select_features(leaf: torch.Tensor):
|
def select_features(leaf: torch.Tensor):
|
||||||
return self._select_image_features(
|
return self._select_image_features(
|
||||||
@ -630,10 +632,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
strategy=self.config.vision_feature_select_strategy,
|
strategy=self.config.vision_feature_select_strategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
return cast(
|
return json_map_leaves(select_features, image_features)
|
||||||
Union[torch.Tensor, tuple[torch.Tensor, ...]],
|
|
||||||
json_map_leaves(select_features, image_features),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _process_image_pixels(
|
def _process_image_pixels(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user