mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 16:17:05 +08:00
Merge branch 'main' into woosuk/model-runner-v2
This commit is contained in:
commit
ad2cf805ad
59
.buildkite/scripts/run-prime-rl-test.sh
Executable file
59
.buildkite/scripts/run-prime-rl-test.sh
Executable file
@ -0,0 +1,59 @@
|
||||
#!/bin/bash
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Setup script for Prime-RL integration tests
# This script prepares the environment for running Prime-RL tests with nightly vLLM
#
# Steps: clone Prime-RL at a fixed branch, drop its vllm pin so the vLLM
# already present in the environment is used, sync the remaining dependencies
# with uv, verify the imports, then run the GPU integration test.

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"
PRIME_RL_REPO="https://github.com/PrimeIntellect-ai/prime-rl.git"
PRIME_RL_DIR="${REPO_ROOT}/prime-rl"

echo "Setting up Prime-RL integration test environment..."

# Clean up any existing Prime-RL directory
if [ -d "${PRIME_RL_DIR}" ]; then
    echo "Removing existing Prime-RL directory..."
    rm -rf "${PRIME_RL_DIR}"
fi

# Install UV if not available
if ! command -v uv &> /dev/null; then
    echo "Installing UV package manager..."
    curl -LsSf https://astral.sh/uv/install.sh | sh
    # Quoted so a HOME containing spaces does not split the path.
    # shellcheck disable=SC1091
    source "${HOME}/.local/bin/env"
fi

# Clone Prime-RL repository at specific branch for reproducible tests
PRIME_RL_BRANCH="integ-vllm-main"
echo "Cloning Prime-RL repository at branch: ${PRIME_RL_BRANCH}..."
git clone --branch "${PRIME_RL_BRANCH}" --single-branch "${PRIME_RL_REPO}" "${PRIME_RL_DIR}"
cd "${PRIME_RL_DIR}"

echo "Setting up UV project environment..."
export UV_PROJECT_ENVIRONMENT=/usr/local
# -f: do not abort (set -e) when the link already exists, e.g. on a reused
# container or a re-run of this script.
ln -sf /usr/bin/python3 /usr/local/bin/python

# Remove vllm pin from pyproject.toml
echo "Removing vllm pin from pyproject.toml..."
sed -i '/vllm==/d' pyproject.toml

# Sync Prime-RL dependencies
echo "Installing Prime-RL dependencies..."
# Run as two separate commands: with `set -e`, a failure on the left-hand side
# of an `&&` list does NOT terminate the script, so `a && b` would silently
# continue after a failed first sync.
uv sync --inexact
uv sync --inexact --all-extras

# Verify installation
echo "Verifying installations..."
uv run python -c "import vllm; print(f'vLLM version: {vllm.__version__}')"
uv run python -c "import prime_rl; print('Prime-RL imported successfully')"

echo "Prime-RL integration test environment setup complete!"

echo "Running Prime-RL integration tests..."
export WANDB_MODE=offline # this makes this test not require a WANDB_API_KEY
uv run pytest -vs tests/integration/test_rl.py -m gpu

echo "Prime-RL integration tests completed!"
|
||||
@ -887,6 +887,8 @@ steps:
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- vllm/v1/engine/
|
||||
- vllm/v1/worker/
|
||||
- tests/v1/worker/test_worker_memory_snapshot.py
|
||||
commands:
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
@ -908,6 +910,7 @@ steps:
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
||||
- pytest -v -s models/multimodal/generation/test_maverick.py
|
||||
- pytest -v -s v1/worker/test_worker_memory_snapshot.py
|
||||
|
||||
- label: Plugin Tests (2 GPUs) # 40min
|
||||
timeout_in_minutes: 60
|
||||
@ -1042,3 +1045,15 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||
- pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py
|
||||
|
||||
##### RL Integration Tests #####
|
||||
- label: Prime-RL Integration Test # 15min
|
||||
timeout_in_minutes: 30
|
||||
optional: true
|
||||
num_gpus: 2
|
||||
working_dir: "/vllm-workspace"
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- .buildkite/scripts/run-prime-rl-test.sh
|
||||
commands:
|
||||
- bash .buildkite/scripts/run-prime-rl-test.sh
|
||||
|
||||
@ -449,7 +449,8 @@ async def benchmark(
|
||||
def prepare_extra_body(request) -> dict:
|
||||
extra_body = {}
|
||||
# Add the schema to the extra_body
|
||||
extra_body[request.structure_type] = request.schema
|
||||
extra_body["structured_outputs"] = {}
|
||||
extra_body["structured_outputs"][request.structure_type] = request.schema
|
||||
return extra_body
|
||||
|
||||
print("Starting initial single prompt test run...")
|
||||
|
||||
406
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
Normal file
406
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
Normal file
@ -0,0 +1,406 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Benchmark the performance of the cutlass_moe_fp8 kernel vs the triton_moe
|
||||
kernel. Both kernels take in fp8 quantized weights and 16-bit activations,
|
||||
but use different quantization strategies and backends.
|
||||
"""
|
||||
|
||||
import nvtx
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
# Benchmarkable MoE layer shapes, keyed by model name. Every shape is laid
# out as [num_experts, topk, hidden_size, intermediate_size].
WEIGHT_SHAPES_MOE = {
    "mixtral-8x7b": [[8, 2, 4096, 14336]],
    "deepseek-v2": [[160, 6, 5120, 12288]],
    "custom-small": [[8, 2, 2048, 7168]],
    "glm45-fp8": [[128, 8, 4096, 1408]],
    "Llama-4-Maverick-17B-128E-Instruct-FP8": [[128, 1, 5120, 8192]],
}

# CLI defaults.
DEFAULT_MODELS = ["mixtral-8x7b"]
# Powers of two from 4 up to 2048.
DEFAULT_BATCH_SIZES = [2**exp for exp in range(2, 12)]
DEFAULT_TP_SIZES = [1]

# Quantization-granularity sweeps (per-activation-token / per-output-channel).
PER_ACT_TOKEN_OPTS = [False, True]
PER_OUT_CH_OPTS = [False, True]

# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
def bench_run(
    results: list,
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
):
    """Time the Triton and CUTLASS FP8 MoE kernels for one problem size.

    Random fp16 activations and expert weights are generated on the CUDA
    device, the weights are quantized expert-by-expert with
    ``ops.scaled_fp8_quant``, and each kernel is captured into its own CUDA
    graph which is then replayed and timed with CUDA events.

    Args:
        results: Unused; kept so existing callers keep working.
        model: Unused inside this function; kept for interface compatibility.
        num_experts: Number of experts in the MoE layer.
        topk: Number of experts selected per token.
        per_act_token: Ignored. Per-tensor activation quantization is always
            forced (see the workaround note in the body).
        per_out_ch: Whether the weight-scale tensors use the per-output-channel
            shape. The scale values themselves still come from per-tensor
            quantization and are broadcast to that shape.
        mkn: ``(m, k, n)`` = (batch size, hidden size, intermediate size).

    Returns:
        dict with ``batch_size`` and the average per-call latency of each
        kernel in microseconds (``triton_time_us``, ``cutlass_time_us``).
    """
    (m, k, n) = mkn

    dtype = torch.half
    device = "cuda"

    # Number of kernel invocations captured into each CUDA graph (same pattern
    # as benchmark_moe.py). The timing code divides by this to get per-call
    # latency, so both uses share one name.
    calls_per_graph = 10

    # Create input activations
    a = torch.randn((m, k), device=device, dtype=dtype) / 10

    # Create weights
    w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10

    # Create FP8 quantized weights and scales for both kernels
    w1_fp8q = torch.empty((num_experts, 2 * n, k), device=device, dtype=FP8_DTYPE)
    w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=FP8_DTYPE)

    # Scale tensors: one entry per output channel, or a single entry per
    # expert for per-tensor quantization.
    w1_scale_rows, w2_scale_rows = (2 * n, k) if per_out_ch else (1, 1)
    w1_scale = torch.empty(
        (num_experts, w1_scale_rows, 1), device=device, dtype=torch.float32
    )
    w2_scale = torch.empty(
        (num_experts, w2_scale_rows, 1), device=device, dtype=torch.float32
    )

    # Quantize weights. True per-channel quantization is not implemented yet,
    # so both modes quantize per tensor; only the layout of the stored scale
    # differs.
    for expert in range(num_experts):
        w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert])
        w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert])
        if per_out_ch:
            # Expand scalar scales to the expected per-channel shape
            w1_scale[expert] = w1_scale_temp.expand(2 * n, 1)
            w2_scale[expert] = w2_scale_temp.expand(k, 1)
        else:
            # Store scalar scales in [1, 1] tensors
            w1_scale[expert, 0, 0] = w1_scale_temp
            w2_scale[expert, 0, 0] = w2_scale_temp

    # Create router scores and get topk
    score = torch.randn((m, num_experts), device=device, dtype=dtype)
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

    # WORKAROUND: CUTLASS MoE FP8 has issues with per-token quantization
    # Force per-tensor quantization for all cases to match working e2e setup
    a1_scale = torch.full((), 1e-2, device=device, dtype=torch.float32)
    a2_scale = torch.full((), 1e-2, device=device, dtype=torch.float32)
    per_act_token = False

    # Create stride tensors for CUTLASS
    ab_strides1 = torch.full((num_experts,), k, dtype=torch.int64, device=device)
    ab_strides2 = torch.full((num_experts,), n, dtype=torch.int64, device=device)
    c_strides1 = torch.full((num_experts,), 2 * n, dtype=torch.int64, device=device)
    c_strides2 = torch.full((num_experts,), k, dtype=torch.int64, device=device)

    # Pre-create quantization config to avoid creating it inside CUDA graph
    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        per_act_token_quant=per_act_token,
        per_out_ch_quant=per_out_ch,
    )

    # Create CUDA graphs for CUTLASS (match benchmark_moe.py pattern exactly).
    # The weights are passed in their original [E, 2N, K] / [E, K, N] layout;
    # no transpose is needed.
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        for _ in range(calls_per_graph):
            cutlass_moe_fp8(
                a=a,
                w1_q=w1_fp8q,
                w2_q=w2_fp8q,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                ab_strides1=ab_strides1,
                ab_strides2=ab_strides2,
                c_strides1=c_strides1,
                c_strides2=c_strides2,
                quant_config=quant_config,
                activation="silu",
                global_num_experts=num_experts,
            )
    torch.cuda.synchronize()

    # Create CUDA graphs for Triton (match benchmark_moe.py pattern exactly)
    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        for _ in range(calls_per_graph):
            fused_experts(
                a,
                w1_fp8q,
                w2_fp8q,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )
    torch.cuda.synchronize()

    def bench_cuda_graph(graph, num_warmup=5, num_iters=100):
        """Benchmark CUDA graph using events like benchmark_moe.py"""
        # Warmup
        for _ in range(num_warmup):
            graph.replay()
        torch.cuda.synchronize()

        # Timing
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        latencies = []
        for _ in range(num_iters):
            torch.cuda.synchronize()
            start_event.record()
            graph.replay()
            end_event.record()
            end_event.synchronize()
            latencies.append(start_event.elapsed_time(end_event))

        # Each replay runs `calls_per_graph` kernel calls; report per-call time.
        return sum(latencies) / (num_iters * calls_per_graph)

    # Benchmark parameters
    num_warmup = 5
    num_iters = 100

    # Benchmark only CUDA graphs (more reliable and faster)
    triton_graph_time = bench_cuda_graph(
        triton_graph, num_warmup=num_warmup, num_iters=num_iters
    )
    cutlass_graph_time = bench_cuda_graph(
        cutlass_graph, num_warmup=num_warmup, num_iters=num_iters
    )

    # Convert ms to us and return results
    return {
        "batch_size": m,
        "triton_time_us": triton_graph_time * 1000,
        "cutlass_time_us": cutlass_graph_time * 1000,
    }
|
||||
|
||||
|
||||
def main(args):
    """Run the benchmark sweep described by the parsed CLI arguments.

    Iterates over every selected model, tensor-parallel size, layer shape and
    quantization option, calls ``bench_run`` once per batch size, and prints a
    latency table for each configuration followed by the total number of
    completed benchmarks.

    Args:
        args: Parsed arguments providing ``models``, ``tp_sizes``,
            ``batch_sizes``, ``limit_k``, ``limit_n``, ``per_act_token_opts``
            and ``per_out_ch_opts``.
    """
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    all_results = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts, topk, size_k, size_n = layer
                # The intermediate dimension is split across TP ranks.
                size_n //= tp

                # An empty limit list means "no filtering".
                if args.limit_k and size_k not in args.limit_k:
                    continue

                if args.limit_n and size_n not in args.limit_n:
                    continue

                for per_act_token in args.per_act_token_opts:
                    for per_out_ch in args.per_out_ch_opts:
                        # Adjacent f-strings are concatenated, so the
                        # separator after the comma must be explicit.
                        print(
                            f"\n=== {model}, experts={num_experts}, topk={topk}, "
                            f"per_act={per_act_token}, per_out_ch={per_out_ch} ==="
                        )

                        config_results = []
                        for size_m in args.batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            result = bench_run(
                                [],  # Not used anymore
                                model,
                                num_experts,
                                topk,
                                per_act_token,
                                per_out_ch,
                                mkn,
                            )
                            if result:
                                config_results.append(result)

                        # Print results table for this configuration
                        if config_results:
                            print(
                                f"\n{'Batch Size':<12}"
                                f"{'Triton (us)':<15}"
                                f"{'CUTLASS (us)':<15}"
                            )
                            print("-" * 45)
                            for result in config_results:
                                print(
                                    f"{result['batch_size']:<12}"
                                    f"{result['triton_time_us']:<15.2f}"
                                    f"{result['cutlass_time_us']:<15.2f}"
                                )

                        all_results.extend(config_results)

    print(f"\nTotal benchmarks completed: {len(all_results)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The example below uses the exact flag names defined further down
    # (--models, --batch-sizes) so it can be copy-pasted as-is.
    parser = FlexibleArgumentParser(
        description="""Benchmark CUTLASS FP8 MOE vs Triton FP8 FUSED MOE
        across specified models/shapes/batches

        Example usage:
        python benchmark_cutlass_moe_fp8.py \
            --models "Llama-4-Maverick-17B-128E-Instruct-FP8" \
            --tp-sizes 8 \
            --batch-sizes 2 4 8 \
            --per-act-token-opts false \
            --per-out-ch-opts false

        """
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    # Empty list (the default) disables filtering on that dimension.
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    # Boolean options are parsed case-insensitively: "true" -> True, anything
    # else -> False. Defaults come from the module-level option lists.
    parser.add_argument(
        "--per-act-token-opts",
        nargs="+",
        type=lambda x: x.lower() == "true",
        default=PER_ACT_TOKEN_OPTS,
        help="Per-activation token quantization options (true/false)",
    )
    parser.add_argument(
        "--per-out-ch-opts",
        nargs="+",
        type=lambda x: x.lower() == "true",
        default=PER_OUT_CH_OPTS,
        help="Per-output channel quantization options (true/false)",
    )

    args = parser.parse_args()
    main(args)
|
||||
@ -258,7 +258,8 @@ set(VLLM_EXT_SRC
|
||||
"csrc/cpu/layernorm.cpp"
|
||||
"csrc/cpu/mla_decode.cpp"
|
||||
"csrc/cpu/pos_encoding.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp")
|
||||
"csrc/cpu/torch_bindings.cpp"
|
||||
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
|
||||
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
set(VLLM_EXT_SRC
|
||||
|
||||
@ -135,10 +135,10 @@ public:
|
||||
max_splits = min(16, max_splits);
|
||||
|
||||
// TODO: This avoids a hang when the batch size larger than 1 and
|
||||
// there is more than 4 kv_splits.
|
||||
// there is more than 1 kv_splits.
|
||||
// Discuss with NVIDIA how this can be fixed.
|
||||
if (B > 1) {
|
||||
max_splits = min(2, max_splits);
|
||||
max_splits = min(1, max_splits);
|
||||
}
|
||||
|
||||
// printf(" max_splits = %d\n", max_splits);
|
||||
|
||||
@ -88,8 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
|
||||
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
|
||||
|
||||
ops.def(
|
||||
"dynamic_4bit_int_moe("
|
||||
"Tensor x, Tensor topk_ids, Tensor topk_weights,"
|
||||
"Tensor w13_packed, Tensor w2_packed, int H, int I, int I2,"
|
||||
"int group_size, bool apply_router_weight_on_input, int activation_kind"
|
||||
") -> Tensor");
|
||||
|
||||
ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);
|
||||
|
||||
// PagedAttention V2.
|
||||
ops.def(
|
||||
"paged_attention_v2("
|
||||
|
||||
156
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
Normal file
156
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
Normal file
@ -0,0 +1,156 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
// _dyn_quant_matmul_4bit is only available on AArch64.
|
||||
#if defined(__aarch64__)
|
||||
#include <ATen/ops/_dyn_quant_matmul_4bit.h>
|
||||
#endif
|
||||
|
||||
// Thin wrapper around ATen's _dyn_quant_matmul_4bit: multiplies activations
// `a` by the packed 4-bit weight `packed_w`. On any architecture other than
// AArch64 the op is not available, so the call fails with a TORCH_CHECK error.
inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w,
                        int64_t group_size_eff, int64_t in_features,
                        int64_t out_features) {
#if !defined(__aarch64__)
  TORCH_CHECK(false,
              "dynamic 4-bit int MoE path requires AArch64 (ARM64); "
              "_dyn_quant_matmul_4bit is unavailable on this architecture");
  // Not reached at runtime; gives the function a return on this path.
  return {};
#else
  return at::_ops::_dyn_quant_matmul_4bit::call(a, packed_w, group_size_eff,
                                                in_features, out_features);
#endif
}
|
||||
|
||||
// Selects the activation applied between the W13 and W2 matmuls in
// dynamic_4bit_int_moe_cpu. The W13 output is split into a first half `g` and
// a second half `u`. The numeric values arrive through the op's
// `activation_kind` argument, so they must stay stable.
enum ActivationKind : int64_t {
  // act = SiLU(g) * u
  SwiGLU_Gu = 0,
  // GPT-OSS variant: g and u are clamped first, then
  // act = (u + 1) * g * sigmoid(alpha * g)
  SwiGLUOAI = 1,
  // "silu"; dispatched to the same SiLU(g) * u path as SwiGLU_Gu
  SiLU = 2,
};
|
||||
|
||||
// CPU mixture-of-experts forward pass using dynamically quantized 4-bit
// matmuls.
//
//   x             [T, *]  input activations, one row per token
//   topk_ids      [T, K]  expert index chosen for each (token, slot)
//   topk_weights  [T, K]  router weight for each (token, slot)
//   w13_packed / w2_packed  packed expert weights, expert index in dim 0
//   H, I, I2      feature sizes passed to the matmuls (I2 must equal 2 * I)
//   group_size    quantization group size; -1 selects the full input width
//   apply_router_weight_on_input  scale the inputs by the router weight
//                 instead of scaling the expert outputs
//   activation_kind  an ActivationKind value
//
// Tokens are bucketed by expert, each expert processes its bucket
// independently (in parallel), and the results are scatter-added back into a
// [T, H] output with the options (dtype/device) of `x`.
torch::Tensor dynamic_4bit_int_moe_cpu(
    torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
    torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
    int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
    int64_t activation_kind) {
  TORCH_CHECK(x.dim() == 2, "x must be 2D");
  TORCH_CHECK(topk_ids.dim() == 2 && topk_weights.dim() == 2,
              "topk tensors must be [T, K]");
  // The bucketing loops below walk T*K elements of both topk tensors through
  // raw pointers, so their shapes must agree with each other and with x;
  // otherwise those loops would read out of bounds.
  TORCH_CHECK(topk_ids.sizes() == topk_weights.sizes(),
              "topk_ids and topk_weights must have the same shape");
  TORCH_CHECK(topk_ids.size(0) == x.size(0),
              "topk tensors must have one row per token in x");
  TORCH_CHECK(
      w13_packed.size(0) == w2_packed.size(0),
      "w13_packed and w2_packed must have same number of experts in dim 0");
  TORCH_CHECK(I2 == 2 * I, "I2 must equal 2*I");

  const int64_t T = x.size(0);
  const int64_t K = topk_ids.size(1);
  const int64_t E = w13_packed.size(0);
  const int64_t N = T * K;

  auto x_c = x.contiguous();
  // Normalize ids to int64 so data_ptr<int64_t>() below is valid for any
  // integer id dtype (a no-op when the ids are already int64).
  auto ids_c = topk_ids.to(at::kLong).contiguous();
  auto gates_c = topk_weights.to(at::kFloat).contiguous();

  // bucketing tokens -> experts
  // Pass 1: count how many (token, slot) pairs each expert receives.
  c10::SmallVector<int64_t, 64> counts(
      E, 0);  // Small vector uses stack allocation
  {
    const auto* ids_ptr = ids_c.data_ptr<int64_t>();
    for (int64_t i = 0; i < N; ++i) {
      const int64_t e_id = ids_ptr[i];
      TORCH_CHECK(0 <= e_id && e_id < E, "expert id out of range");
      counts[e_id]++;
    }
  }

  // Exclusive prefix sum: offsets[e] is where expert e's bucket starts and
  // offsets[E] is the total number of routed pairs.
  c10::SmallVector<int64_t, 65> offsets(E + 1, 0);  // ( E +1 )
  for (int64_t e = 0; e < E; ++e) offsets[e + 1] = offsets[e] + counts[e];

  // Pass 2: write each pair's token index and gate into its expert's bucket.
  auto expert_tokens = at::empty({offsets[E]}, ids_c.options());
  auto expert_gates = at::empty({offsets[E]}, gates_c.options());
  {
    c10::SmallVector<int64_t, 64> cursor(E, 0);
    const auto* ids_ptr = ids_c.data_ptr<int64_t>();
    const auto* gts_ptr = gates_c.data_ptr<float>();
    auto* tok_ptr = expert_tokens.data_ptr<int64_t>();
    auto* gate_ptr = expert_gates.data_ptr<float>();

    for (int64_t t = 0; t < T; ++t) {
      const int64_t base = t * K;
      for (int64_t k = 0; k < K; ++k) {
        const int64_t idx = base + k;
        const int64_t e = ids_ptr[idx];  // range-checked in pass 1
        const int64_t p = offsets[e] + (cursor[e]++);
        tok_ptr[p] = t;
        gate_ptr[p] = gts_ptr[idx];
      }
    }
  }

  // group_size == -1 means one quantization group spanning all input features.
  const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
  const int64_t g_eff_2 = (group_size != -1) ? group_size : I;

  // Per-expert outputs filled in parallel; each iteration writes only its own
  // y_list[e] slot.
  std::vector<torch::Tensor> y_list(E);

  at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
    for (int64_t e = e_begin; e < e_end; ++e) {
      const int64_t te = counts[e];
      if (te == 0) {
        // Empty placeholder keeps the concatenation below aligned with
        // expert_tokens.
        y_list[e] = at::empty({0, H}, x_c.options());
        continue;
      }

      const int64_t start = offsets[e];

      auto sel_tokens =
          expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
      auto gates_e =
          expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);

      auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);

      if (apply_router_weight_on_input) {
        x_e = x_e.mul(gates_e.unsqueeze(1));
      }

      auto w13_e = w13_packed.select(/*dim=*/0, e);
      auto w2_e = w2_packed.select(/*dim=*/0, e);

      // W13
      auto y13 =
          mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2);

      // First half of the W13 output is the gate, second half is "up".
      auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
      auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);

      torch::Tensor act;
      if (activation_kind == ActivationKind::SwiGLUOAI) {  // SwiGLUOAI
        constexpr double kAlpha = 1.702;  // GPT-OSS default
        constexpr double kLimit = 7.0;    // GPT-OSS default
        auto gate_c = at::clamp_max(g_part, kLimit);
        auto up_c = at::clamp(u_part, -kLimit, kLimit);
        auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
        act = up_c.add(1.0).mul(glu);
      } else {  // SiLU , SwiGLU_GU, vLLM maps silu to SiluAndMul()
        act = at::silu(g_part).mul(u_part);
      }

      // W2
      auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);

      if (!apply_router_weight_on_input) {
        y = y.mul(gates_e.unsqueeze(1));
      }

      // Store per-expert result
      y_list[e] = y;
    }
  });

  // Concatenate all expert outputs to match expert_tokens order, then sum the
  // K expert contributions of each token back into its output row.
  auto Y_all = at::cat(y_list, /*dim=*/0);
  auto out = at::zeros({T, H}, x.options());
  out =
      at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);

  return out;
}
|
||||
@ -328,6 +328,12 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
|
||||
const std::optional<torch::Tensor>& has_initial_state,
|
||||
const torch::Tensor& ssm_states, int64_t pad_slot_id);
|
||||
|
||||
torch::Tensor dynamic_4bit_int_moe_cpu(
|
||||
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
|
||||
torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
|
||||
int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
|
||||
int64_t activation_kind);
|
||||
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||
torch::Tensor& rank_data, int64_t rank,
|
||||
|
||||
@ -23,9 +23,14 @@
|
||||
typedef __hip_bfloat162 __nv_bfloat162;
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
typedef __hip_bfloat16_raw __nv_bfloat16_raw;
|
||||
|
||||
#if defined(HIP_FP8_TYPE_OCP)
|
||||
typedef __hip_fp8_e4m3 __nv_fp8_e4m3;
|
||||
typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3;
|
||||
#else
|
||||
// ROCm 6.2 fallback: only *_fnuz types exist
|
||||
typedef __hip_fp8_e4m3_fnuz __nv_fp8_e4m3;
|
||||
typedef __hip_fp8x4_e4m3_fnuz __nv_fp8x4_e4m3;
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
@ -25,6 +25,12 @@
|
||||
#include "../attention/dtype_fp8.cuh"
|
||||
#include "../quantization/fp8/amd/quant_utils.cuh"
|
||||
|
||||
// ROCm 6.2 compatibility: map OCP fp8 types to FNUZ variants if OCP is absent
|
||||
#if !defined(HIP_FP8_TYPE_OCP)
|
||||
using __hip_fp8_e4m3 = __hip_fp8_e4m3_fnuz;
|
||||
using __hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz;
|
||||
#endif
|
||||
|
||||
#if defined(__HIPCC__) && \
|
||||
(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
|
||||
#define __HIP__GFX9__
|
||||
|
||||
@ -34,7 +34,7 @@ Now supports 5 types of connectors:
|
||||
For NixlConnector, you may also specify one or more NIXL_Backend values. For example:
|
||||
|
||||
```bash
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}'
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backends":["UCX", "GDS"]}}'
|
||||
```
|
||||
|
||||
- **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker):
|
||||
|
||||
@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok
|
||||
|
||||
1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip.
|
||||
|
||||
2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`. Noted, you may also specify one or multiple NIXL_Backend. Such as: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}'`
|
||||
2. **Configure Both Instances**: Add this flag to both prefill and decode instances: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'`. Note that you may also specify one or more NIXL_Backend values, for example: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backends":["UCX", "GDS"]}}'`
|
||||
|
||||
3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ outlines_core == 0.2.11
|
||||
# required for outlines backend disk cache
|
||||
diskcache == 5.6.3
|
||||
lark == 1.2.2
|
||||
xgrammar == 0.1.24; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
|
||||
xgrammar == 0.1.25; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
|
||||
typing_extensions >= 4.10
|
||||
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
|
||||
partial-json-parser # used for parsing partial JSON outputs
|
||||
|
||||
@ -1079,7 +1079,7 @@ def dummy_llava_path():
|
||||
local_dir=_dummy_llava_path,
|
||||
ignore_patterns=[
|
||||
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
||||
"*.msgpack"
|
||||
"*.msgpack", "*.safetensors"
|
||||
])
|
||||
assert os.path.exists(json_path)
|
||||
with open(json_path) as f:
|
||||
@ -1098,7 +1098,7 @@ def dummy_gemma2_embedding_path():
|
||||
local_dir=_dummy_gemma2_embedding_path,
|
||||
ignore_patterns=[
|
||||
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
||||
"*.msgpack"
|
||||
"*.msgpack", "*.safetensors"
|
||||
])
|
||||
assert os.path.exists(json_path)
|
||||
with open(json_path) as f:
|
||||
|
||||
@ -382,7 +382,6 @@ def test_tp_language_generation(
|
||||
test_options: PPTestOptions,
|
||||
num_gpus_available,
|
||||
):
|
||||
pytest.skip("Skipping the test until V1 passes it.")
|
||||
_compare_tp(model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
@ -410,7 +409,6 @@ def test_tp_language_embedding(
|
||||
test_options: PPTestOptions,
|
||||
num_gpus_available,
|
||||
):
|
||||
pytest.skip("Skipping the test until V1 passes it.")
|
||||
_compare_tp(model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
@ -438,7 +436,6 @@ def test_tp_multimodal_generation(
|
||||
test_options: PPTestOptions,
|
||||
num_gpus_available,
|
||||
):
|
||||
pytest.skip("Skipping the test until V1 passes it.")
|
||||
_compare_tp(model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
|
||||
@ -523,6 +523,7 @@ async def test_function_calling(client: OpenAI, model_name: str):
|
||||
input="What's the weather like in Paris today?",
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
extra_body={"request_id": "test_function_calling_non_resp"},
|
||||
)
|
||||
assert response is not None
|
||||
assert response.status == "completed"
|
||||
|
||||
@ -194,6 +194,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
|
||||
assert tc.function is not None and tc.function.name == "get_current_weather"
|
||||
args1 = tc.function.arguments
|
||||
assert args1 is not None and len(args1) > 0
|
||||
assert not first_msg.content
|
||||
|
||||
messages.append({"role": "assistant", "content": args1})
|
||||
messages.append({
|
||||
|
||||
@ -85,8 +85,7 @@ def test_env(
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, None, block_size,
|
||||
False)
|
||||
backend = get_attn_backend(16, torch.float16, None, block_size)
|
||||
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||
|
||||
elif device == "hip":
|
||||
@ -106,7 +105,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
assert f"The selected backend, {name}" in str(
|
||||
exc_info.value)
|
||||
@ -117,7 +115,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
assert f"The selected backend, {name}" in str(
|
||||
exc_info.value)
|
||||
@ -127,7 +124,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = f"{name}_VLLM_V1"
|
||||
assert backend.get_name() == expected
|
||||
@ -136,7 +132,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "TRITON_ATTN_VLLM_V1"
|
||||
assert backend.get_name() == expected
|
||||
@ -164,7 +159,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "CUTLASS_MLA_VLLM_V1"
|
||||
assert backend.get_name() == expected
|
||||
@ -179,7 +173,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "FLASHINFER_MLA"
|
||||
assert backend.get_name() == expected
|
||||
@ -199,7 +192,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = f"{name}_VLLM_V1"
|
||||
assert backend.get_name() == expected
|
||||
@ -208,7 +200,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "FLASH_ATTN_MLA"
|
||||
assert backend.get_name() == expected
|
||||
@ -218,7 +209,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "TRITON_MLA_VLLM_V1"
|
||||
assert backend.get_name() == expected
|
||||
@ -227,7 +217,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "FLASHINFER_VLLM_V1"
|
||||
assert backend.get_name() == expected
|
||||
@ -236,7 +225,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
expected = "FLASH_ATTN_VLLM_V1"
|
||||
assert backend.get_name() == expected
|
||||
@ -245,7 +233,6 @@ def test_env(
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
False,
|
||||
use_mla=use_mla)
|
||||
assert backend.get_name() == "FLEX_ATTENTION", (
|
||||
"Should fallback to FlexAttention if head size is "
|
||||
@ -264,13 +251,13 @@ def test_fp32_fallback(
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16, False)
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CudaPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16, False)
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "FLEX_ATTENTION"
|
||||
|
||||
|
||||
@ -286,29 +273,29 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(torch.cuda,
|
||||
"get_device_capability",
|
||||
lambda _=None: (7, 5))
|
||||
backend = get_attn_backend(16, torch.float16, None, 16, False)
|
||||
backend = get_attn_backend(16, torch.float16, None, 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Reset the monkeypatch for subsequent tests
|
||||
monkeypatch.undo()
|
||||
|
||||
# Unsupported data type
|
||||
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
|
||||
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported kv cache data type
|
||||
backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
|
||||
backend = get_attn_backend(16, torch.float16, "fp8", 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported block size
|
||||
backend = get_attn_backend(16, torch.float16, None, 8, False)
|
||||
backend = get_attn_backend(16, torch.float16, None, 8)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# flash-attn is not installed
|
||||
import sys
|
||||
original_module = sys.modules.get('vllm_flash_attn')
|
||||
monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None)
|
||||
backend = get_attn_backend(16, torch.float16, None, 16, False)
|
||||
backend = get_attn_backend(16, torch.float16, None, 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Restore the original module if it existed
|
||||
@ -319,11 +306,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False)
|
||||
|
||||
# Unsupported head size
|
||||
backend = get_attn_backend(17, torch.float16, None, 16, False)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Attention-free models should bypass env and use PlaceholderAttention
|
||||
backend = get_attn_backend(16, torch.float16, None, 16, True)
|
||||
backend = get_attn_backend(17, torch.float16, None, 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
|
||||
@ -336,5 +319,5 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
# Should raise ValueError for invalid backend
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_attn_backend(32, torch.float16, None, 16, False)
|
||||
get_attn_backend(32, torch.float16, None, 16)
|
||||
assert "Invalid value 'INVALID'" in str(exc_info.value)
|
||||
|
||||
@ -12,11 +12,11 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from PIL import Image
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
|
||||
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
|
||||
from vllm.multimodal.inputs import MultiModalInputs
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
InputProcessingContext)
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||
cached_tokenizer_from_config,
|
||||
encode_tokens)
|
||||
|
||||
@ -18,10 +18,10 @@ from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import (cleanup_dist_env_and_memory,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
InputProcessingContext)
|
||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
@ -42,7 +42,6 @@ def test_oot_registration_text_generation(
|
||||
assert rest == ""
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is skipped because it failed on V1.")
|
||||
@create_new_process_for_each_test()
|
||||
def test_oot_registration_embedding(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
@ -63,7 +62,6 @@ def test_oot_registration_embedding(
|
||||
image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test is skipped because it failed on V1.")
|
||||
@create_new_process_for_each_test()
|
||||
def test_oot_registration_multimodal(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@ -11,8 +11,9 @@ import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig, ModelDType, RunnerOption
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
|
||||
from vllm.multimodal.processing import InputProcessingContext
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
|
||||
from .registry import HF_EXAMPLE_MODELS
|
||||
|
||||
@ -264,7 +265,7 @@ def build_model_context(
|
||||
limit_mm_per_prompt: Optional[dict[str, int]] = None,
|
||||
mm_processor_cache_gb: int = 0,
|
||||
):
|
||||
"""Creates an InputContext for a given model.
|
||||
"""Creates an InputProcessingContext for a given model.
|
||||
|
||||
Args:
|
||||
model_id: ID of the model being considered.
|
||||
@ -273,7 +274,7 @@ def build_model_context(
|
||||
limit_mm_per_prompt: Multimodal limits.
|
||||
|
||||
Returns:
|
||||
InputContext for the model being considered.
|
||||
InputProcessingContext for the model being considered.
|
||||
"""
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@ -298,7 +299,11 @@ def build_model_context(
|
||||
enforce_eager=model_info.enforce_eager,
|
||||
**model_config_kwargs,
|
||||
)
|
||||
return InputContext(model_config)
|
||||
|
||||
return InputProcessingContext(
|
||||
model_config,
|
||||
tokenizer=cached_tokenizer_from_config(model_config),
|
||||
)
|
||||
|
||||
|
||||
def check_embeddings_close(
|
||||
|
||||
@ -8,11 +8,11 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
|
||||
from vllm.multimodal.processing import (InputProcessingContext,
|
||||
PlaceholderFeaturesInfo,
|
||||
PromptIndexTargets, PromptInsertion,
|
||||
PromptReplacement, apply_text_matches,
|
||||
apply_token_matches,
|
||||
|
||||
392
tests/reasoning/test_base_thinking_reasoning_parser.py
Normal file
392
tests/reasoning/test_base_thinking_reasoning_parser.py
Normal file
@ -0,0 +1,392 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
|
||||
|
||||
|
||||
# Create a concrete test implementation of BaseThinkingReasoningParser
|
||||
class TestThinkingReasoningParser(BaseThinkingReasoningParser):
    """Concrete parser for the tests, delimited by ``<test:think>`` tags."""

    # The name matches pytest's default ``Test*`` class pattern, but this is
    # a helper that the tests instantiate with a tokenizer, not a test case.
    # Opt out of collection so pytest does not try to treat it as one.
    __test__ = False

    @property
    def start_token(self) -> str:
        """Marker that opens a reasoning section."""
        return "<test:think>"

    @property
    def end_token(self) -> str:
        """Marker that closes a reasoning section."""
        return "</test:think>"
|
||||
|
||||
|
||||
class TestThinkingReasoningParserAlt(BaseThinkingReasoningParser):
    """Second concrete parser for the tests, using ``<alt:...>`` markers."""

    # The name matches pytest's default ``Test*`` class pattern, but this is
    # a helper that the tests instantiate with a tokenizer, not a test case.
    # Opt out of collection so pytest does not try to treat it as one.
    __test__ = False

    @property
    def start_token(self) -> str:
        """Marker that opens a reasoning section."""
        return "<alt:start>"

    @property
    def end_token(self) -> str:
        """Marker that closes a reasoning section."""
        return "<alt:end>"
|
||||
|
||||
|
||||
# Use a test model
|
||||
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def test_tokenizer():
    """Return the reference tokenizer extended with the custom markers.

    Only markers that are absent from the tokenizer's vocabulary are
    registered.
    """
    tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
    vocab = tokenizer.get_vocab()
    markers = ("<test:think>", "</test:think>", "<alt:start>", "<alt:end>")

    missing = []
    for marker in markers:
        if marker not in vocab:
            missing.append(marker)

    # Skip the call entirely when every marker is already known.
    if missing:
        tokenizer.add_tokens(missing)
    return tokenizer
|
||||
|
||||
|
||||
class TestBaseThinkingReasoningParserInit:
    """Construction-time behaviour of BaseThinkingReasoningParser subclasses."""

    def test_successful_initialization(self, test_tokenizer):
        """Markers present in the vocabulary yield a parser with both ids."""
        parser = TestThinkingReasoningParser(test_tokenizer)

        assert parser.start_token == "<test:think>"
        assert parser.end_token == "</test:think>"
        for token_id in (parser.start_token_id, parser.end_token_id):
            assert token_id is not None

    def test_initialization_with_missing_tokenizer(self):
        """Passing ``None`` as the tokenizer raises ValueError."""
        expected_message = "model tokenizer must be passed"
        with pytest.raises(ValueError, match=expected_message):
            TestThinkingReasoningParser(None)

    def test_initialization_with_missing_tokens(self, test_tokenizer):
        """Markers unknown to the tokenizer raise RuntimeError."""

        class MissingTokenParser(BaseThinkingReasoningParser):
            """Parser whose markers were never added to the vocabulary."""

            @property
            def start_token(self) -> str:
                return "<missing:start>"

            @property
            def end_token(self) -> str:
                return "<missing:end>"

        expected_message = "could not locate think start/end tokens"
        with pytest.raises(RuntimeError, match=expected_message):
            MissingTokenParser(test_tokenizer)

    def test_initialization_with_empty_tokens(self, test_tokenizer):
        """Empty marker strings raise ValueError."""

        class EmptyTokenParser(BaseThinkingReasoningParser):
            """Parser that declares empty start and end markers."""

            @property
            def start_token(self) -> str:
                return ""

            @property
            def end_token(self) -> str:
                return ""

        expected_message = "start_token and end_token must be defined"
        with pytest.raises(ValueError, match=expected_message):
            EmptyTokenParser(test_tokenizer)
|
||||
|
||||
|
||||
class TestBaseThinkingReasoningParserMethods:
    """Token-id level helpers of BaseThinkingReasoningParser."""

    def test_is_reasoning_end(self, test_tokenizer):
        """is_reasoning_end is True only when the end-token id is present."""
        parser = TestThinkingReasoningParser(test_tokenizer)
        end_token_id = parser.end_token_id

        # End token present.
        assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True

        # End token absent.
        assert parser.is_reasoning_end([1, 2, 3, 4]) is False

        # No ids at all.
        assert parser.is_reasoning_end([]) is False

    def test_extract_content_ids(self, test_tokenizer):
        """extract_content_ids returns the ids that follow the end token."""
        parser = TestThinkingReasoningParser(test_tokenizer)
        end_token_id = parser.end_token_id

        # Each case is (input ids, expected content ids). The "end token is
        # the last id" case is listed once; it used to be asserted twice.
        cases = [
            # End token in the middle: everything after it is content.
            ([1, 2, end_token_id, 4, 5], [4, 5]),
            # End token is the last id: nothing follows it.
            ([1, 2, 3, end_token_id], []),
            # No end token at all.
            ([1, 2, 3, 4], []),
        ]
        for input_ids, expected in cases:
            assert parser.extract_content_ids(input_ids) == expected
|
||||
|
||||
|
||||
class TestBaseThinkingReasoningParserExtraction:
    """Non-streaming extraction through extract_reasoning_content."""

    @staticmethod
    def _extract(tokenizer, model_output):
        """Run extract_reasoning_content with a fresh parser and request."""
        parser = TestThinkingReasoningParser(tokenizer)
        request = ChatCompletionRequest(messages=[], model="test-model")
        return parser.extract_reasoning_content(model_output, request)

    def test_extract_reasoning_content_with_both_tokens(self, test_tokenizer):
        """Text between the markers is reasoning; text after is content."""
        text = "<test:think>This is reasoning</test:think>This is content"
        reasoning, content = self._extract(test_tokenizer, text)

        assert reasoning == "This is reasoning"
        assert content == "This is content"

    def test_extract_reasoning_content_only_end_token(self, test_tokenizer):
        """Without a start marker, text before the end marker is reasoning."""
        text = "This is reasoning</test:think>This is content"
        reasoning, content = self._extract(test_tokenizer, text)

        assert reasoning == "This is reasoning"
        assert content == "This is content"

    def test_extract_reasoning_content_no_end_token(self, test_tokenizer):
        """Without an end marker, the whole output is reasoning."""
        text = "This is just content"
        reasoning, content = self._extract(test_tokenizer, text)

        assert reasoning == "This is just content"
        assert content is None

    def test_extract_reasoning_content_empty_output(self, test_tokenizer):
        """Empty output gives empty reasoning and no content."""
        reasoning, content = self._extract(test_tokenizer, "")

        assert reasoning == ""
        assert content is None

    def test_extract_reasoning_content_only_tokens(self, test_tokenizer):
        """Markers with nothing around them give empty reasoning."""
        text = "<test:think></test:think>"
        reasoning, content = self._extract(test_tokenizer, text)

        assert reasoning == ""
        assert content is None
|
||||
|
||||
|
||||
class TestBaseThinkingReasoningParserStreaming:
    """Extraction through run_reasoning_extraction, mainly streamed."""

    @staticmethod
    def _stream(tokenizer, deltas):
        """Feed *deltas* to a fresh test parser with streaming enabled."""
        parser = TestThinkingReasoningParser(tokenizer)
        return run_reasoning_extraction(parser, deltas, streaming=True)

    @pytest.mark.parametrize("streaming", [True, False])
    def test_simple_reasoning_extraction(self, test_tokenizer, streaming):
        """The same split is expected with and without streaming."""
        parser = TestThinkingReasoningParser(test_tokenizer)
        chunks = ["<test:think>", "Some ", "reasoning ", "content"]
        chunks += ["</test:think>", "Final ", "answer"]

        reasoning, content = run_reasoning_extraction(parser,
                                                      chunks,
                                                      streaming=streaming)

        assert reasoning == "Some reasoning content"
        assert content == "Final answer"

    def test_streaming_with_incremental_deltas(self, test_tokenizer):
        """Several small deltas are reassembled into reasoning and content."""
        pieces = ["<test:think>", "Some ", "reasoning ", "content"]
        pieces += ["</test:think>", "Final ", "answer"]

        reasoning, content = self._stream(test_tokenizer, pieces)

        assert reasoning == "Some reasoning content"
        assert content == "Final answer"

    def test_streaming_with_start_token(self, test_tokenizer):
        """A stream that opens with the start marker splits at the end one."""
        pieces = ["<test:think>", "Some ", "reasoning", "</test:think>"]
        pieces.append("Answer")

        reasoning, content = self._stream(test_tokenizer, pieces)

        assert reasoning == "Some reasoning"
        assert content == "Answer"

    def test_streaming_no_end_token(self, test_tokenizer):
        """With no end marker, everything streamed stays reasoning."""
        pieces = ["<test:think>", "Some ", "reasoning ", "without ", "end"]

        reasoning, content = self._stream(test_tokenizer, pieces)

        assert reasoning == "Some reasoning without end"
        assert content is None

    def test_streaming_only_end_token(self, test_tokenizer):
        """Stream start marker, reasoning, end marker and one content delta."""
        pieces = ["<test:think>", "Reasoning ", "content", "</test:think>"]
        pieces.append("Final")

        reasoning, content = self._stream(test_tokenizer, pieces)

        assert reasoning == "Reasoning content"
        assert content == "Final"
|
||||
|
||||
|
||||
class TestBaseThinkingReasoningParserMultipleImplementations:
    """Two subclasses with different markers used side by side."""

    def test_different_token_implementations(self, test_tokenizer):
        """Each parser splits on its own end marker and has distinct ids."""
        default_parser = TestThinkingReasoningParser(test_tokenizer)
        alt_parser = TestThinkingReasoningParserAlt(test_tokenizer)

        # (parser, model output, expected reasoning, expected content)
        scenarios = (
            (default_parser, "Reasoning1</test:think>Content1", "Reasoning1",
             "Content1"),
            (alt_parser, "Reasoning2<alt:end>Content2", "Reasoning2",
             "Content2"),
        )
        for parser, text, want_reasoning, want_content in scenarios:
            reasoning, content = run_reasoning_extraction(parser, [text])
            assert reasoning == want_reasoning
            assert content == want_content

        # The two implementations must not share markers or marker ids.
        distinct_attrs = ("start_token", "end_token", "start_token_id",
                          "end_token_id")
        for attr in distinct_attrs:
            assert getattr(default_parser, attr) != getattr(alt_parser, attr)
|
||||
|
||||
|
||||
class TestBaseThinkingReasoningParserEdgeCases:
    """Unusual marker layouts in the model output."""

    @staticmethod
    def _run(tokenizer, text):
        """Extract from *text*, passed as a single chunk, with a fresh parser."""
        parser = TestThinkingReasoningParser(tokenizer)
        return run_reasoning_extraction(parser, [text])

    def test_multiple_end_tokens(self, test_tokenizer):
        """Only the first end marker splits reasoning from content."""
        text = "First</test:think>Middle</test:think>Last"
        reasoning, content = self._run(test_tokenizer, text)

        # The second end marker stays inside the content verbatim.
        assert reasoning == "First"
        assert content == "Middle</test:think>Last"

    def test_nested_tokens(self, test_tokenizer):
        """A second start marker inside the reasoning is kept as text."""
        text = "<test:think>Outer<test:think>Inner</test:think>Content"
        reasoning, content = self._run(test_tokenizer, text)

        # Reasoning begins after the first start marker.
        assert reasoning == "Outer<test:think>Inner"
        assert content == "Content"

    def test_malformed_tokens(self, test_tokenizer):
        """Look-alike tags that are not the exact markers are not split on."""
        text = "<test:thinking>Not a real token</test:thinking>Content"
        reasoning, content = self._run(test_tokenizer, text)

        # Nothing matches exactly, so the whole output comes back as reasoning.
        assert reasoning == text
        assert content is None
|
||||
237
tests/reasoning/test_seedoss_reasoning_parser.py
Normal file
237
tests/reasoning/test_seedoss_reasoning_parser.py
Normal file
@ -0,0 +1,237 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
parser_name = "seed_oss"
|
||||
start_token = "<seed:think>"
|
||||
end_token = "</seed:think>"
|
||||
|
||||
# Tokenizer source for the tests; the SeedOSS think tokens are added by the fixture below when missing
|
||||
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def seedoss_tokenizer():
    """Return the reference tokenizer with the SeedOSS think tokens available.

    Each of ``start_token`` and ``end_token`` is registered if it is not
    already in the tokenizer's vocabulary.
    """
    tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
    vocab = tokenizer.get_vocab()
    # Check every token on its own: testing only the start token would leave
    # the end token unregistered whenever just the start token already exists.
    missing = [
        token for token in (start_token, end_token) if token not in vocab
    ]
    if missing:
        tokenizer.add_tokens(missing)
    return tokenizer
|
||||
|
||||
|
||||
def _expected_case(output: str, reasoning_content: str, content: Any,
                   is_reasoning_end: bool) -> dict[str, Any]:
    """Bundle a model output with the values expected from parsing it."""
    return {
        "output": output,
        "reasoning_content": reasoning_content,
        "content": content,
        "is_reasoning_end": is_reasoning_end,
    }


# Reasoning closed by the end token and followed by content.
SIMPLE_REASONING: dict[str, Any] = _expected_case(
    "This is a reasoning section</seed:think>This is the rest",
    "This is a reasoning section",
    "This is the rest",
    True,
)
# Reasoning closed by the end token with nothing after it.
COMPLETE_REASONING: dict[str, Any] = _expected_case(
    "This is a reasoning section</seed:think>",
    "This is a reasoning section",
    None,
    True,
)
# No end token: the whole output is expected back as reasoning.
NO_CONTENT: dict[str, Any] = _expected_case(
    "This is content",
    "This is content",
    None,
    False,
)
NO_REASONING_STREAMING: dict[str, Any] = _expected_case(
    "This is a reasoning section",
    "This is a reasoning section",
    None,
    False,
)
# Newlines on both sides of the end token.
MULTIPLE_LINES: dict[str, Any] = _expected_case(
    "This\nThat</seed:think>This is the rest\nThat",
    "This\nThat",
    "This is the rest\nThat",
    True,
)
# Output that carries the start token as well as the end token.
WITH_START_TOKEN: dict[str, Any] = _expected_case(
    "<seed:think>This is a reasoning section</seed:think>This is the rest",
    "This is a reasoning section",
    "This is the rest",
    True,
)
ONLY_END_TOKEN: dict[str, Any] = _expected_case(
    "Some reasoning</seed:think>This is the rest",
    "Some reasoning",
    "This is the rest",
    True,
)
# Neither token appears anywhere in the output.
NO_TOKENS: dict[str, Any] = _expected_case(
    "This is just content without any reasoning tokens",
    "This is just content without any reasoning tokens",
    None,
    False,
)
|
||||
|
||||
|
||||
def test_seedoss_reasoning_parser_creation(seedoss_tokenizer):
    """The parser registered under ``parser_name`` builds with both markers."""
    parser = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)

    assert isinstance(parser, ReasoningParser)
    assert (parser.start_token, parser.end_token) == (start_token, end_token)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
def test_simple_reasoning(seedoss_tokenizer, streaming):
    """Reasoning before the end token and content after it are split."""
    case = SIMPLE_REASONING
    parser = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)
    model_output = [cast(str, case["output"])]

    reasoning, content = run_reasoning_extraction(parser,
                                                  model_output,
                                                  streaming=streaming)

    assert reasoning == case["reasoning_content"]
    assert content == case["content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
def test_complete_reasoning(seedoss_tokenizer, streaming):
    """Output that stops at the end token yields reasoning and no content."""
    case = COMPLETE_REASONING
    parser = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)
    model_output = [cast(str, case["output"])]

    reasoning, content = run_reasoning_extraction(parser,
                                                  model_output,
                                                  streaming=streaming)

    assert reasoning == case["reasoning_content"]
    assert content == case["content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
def test_no_content(seedoss_tokenizer, streaming):
    """Without an end token the whole output is expected as reasoning."""
    case = NO_CONTENT
    parser = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)
    model_output = [cast(str, case["output"])]

    reasoning, content = run_reasoning_extraction(parser,
                                                  model_output,
                                                  streaming=streaming)

    assert reasoning == case["reasoning_content"]
    assert content == case["content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
def test_multiple_lines(seedoss_tokenizer, streaming):
    """Newlines in the reasoning and in the content are preserved."""
    case = MULTIPLE_LINES
    parser = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)
    model_output = [cast(str, case["output"])]

    reasoning, content = run_reasoning_extraction(parser,
                                                  model_output,
                                                  streaming=streaming)

    assert reasoning == case["reasoning_content"]
    assert content == case["content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
def test_with_start_token(seedoss_tokenizer, streaming):
    """Output wrapped in both start and end tokens is split correctly."""
    case = WITH_START_TOKEN
    parser = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)
    model_output = [cast(str, case["output"])]

    reasoning, content = run_reasoning_extraction(parser,
                                                  model_output,
                                                  streaming=streaming)

    assert reasoning == case["reasoning_content"]
    assert content == case["content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
def test_only_end_token(seedoss_tokenizer, streaming):
    """Output with only the end token (SeedOSS typical behavior) is split."""
    case = ONLY_END_TOKEN
    parser = ReasoningParserManager.get_reasoning_parser(parser_name)(
        seedoss_tokenizer)
    model_output = [cast(str, case["output"])]

    reasoning, content = run_reasoning_extraction(parser,
                                                  model_output,
                                                  streaming=streaming)

    assert reasoning == case["reasoning_content"]
    assert content == case["content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tokens(seedoss_tokenizer, streaming):
    """Reasoning extraction on output with no reasoning tokens at all.

    The expected reasoning/content split comes from ``NO_TOKENS``.
    """
    expected = NO_TOKENS
    reasoning_parser = ReasoningParserManager.get_reasoning_parser(
        parser_name)(seedoss_tokenizer)
    model_output = [cast(str, expected["output"])]

    reasoning, content = run_reasoning_extraction(reasoning_parser,
                                                  model_output,
                                                  streaming=streaming)

    assert reasoning == expected["reasoning_content"]
    assert content == expected["content"]
|
||||
|
||||
|
||||
def test_is_reasoning_end(seedoss_tokenizer):
    """Check that ``is_reasoning_end`` reports the end token's presence."""
    reasoning_parser = ReasoningParserManager.get_reasoning_parser(
        parser_name)(seedoss_tokenizer)
    end_id = reasoning_parser.end_token_id

    # Token ids that include the end token.
    ids_with_end = [1, 2, end_id, 4]
    assert reasoning_parser.is_reasoning_end(ids_with_end) is True

    # Token ids that do not include it.
    ids_without_end = [1, 2, 3, 4]
    assert reasoning_parser.is_reasoning_end(ids_without_end) is False
|
||||
|
||||
|
||||
def test_extract_content_ids(seedoss_tokenizer):
    """Check ``extract_content_ids`` for three placements of the end token."""
    reasoning_parser = ReasoningParserManager.get_reasoning_parser(
        parser_name)(seedoss_tokenizer)
    end_id = reasoning_parser.end_token_id

    # (input token ids, expected content ids)
    cases = [
        # End token in the middle: the ids after it are the content.
        ([1, 2, end_id, 4, 5], [4, 5]),
        # End token at the end: no content follows.
        ([1, 2, 3, end_id], []),
        # No end token: no content.
        ([1, 2, 3, 4], []),
    ]

    for input_ids, expected_ids in cases:
        content_ids = reasoning_parser.extract_content_ids(input_ids)
        assert content_ids == expected_ids
|
||||
|
||||
|
||||
def test_streaming_delta_processing(seedoss_tokenizer):
    """Feed the parser small deltas in streaming mode and check the split."""
    reasoning_parser = ReasoningParserManager.get_reasoning_parser(
        parser_name)(seedoss_tokenizer)

    # Deltas before "</seed:think>" are expected back as reasoning, the
    # deltas after it as content.
    reasoning_deltas = ["Some ", "reasoning ", "content"]
    content_deltas = ["Final ", "answer"]
    deltas = reasoning_deltas + ["</seed:think>"] + content_deltas

    reasoning, content = run_reasoning_extraction(reasoning_parser,
                                                  deltas,
                                                  streaming=True)

    assert reasoning == "Some reasoning content"
    assert content == "Final answer"
|
||||
@ -1,7 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from dataclasses import MISSING, Field, asdict, dataclass, field
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -388,3 +390,108 @@ def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len,
|
||||
else:
|
||||
actual_max_len = model_config.get_and_verify_max_len(max_model_len)
|
||||
assert actual_max_len == expected_max_len
|
||||
|
||||
|
||||
class MockConfig:
    """Bare stand-in used to test maybe_pull_model_tokenizer_for_runai.

    Holds only the ``model``, ``tokenizer`` and ``model_weights``
    attributes; ``model_weights`` always starts as ``None``.
    """

    def __init__(self, model: str, tokenizer: str):
        self.model, self.tokenizer, self.model_weights = (model, tokenizer,
                                                          None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("s3_url", [
    "s3://example-bucket-1/model/",
    "s3://example-bucket-2/model/",
])
@patch('vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files')
def test_s3_url_model_tokenizer_paths(mock_pull_files, s3_url):
    """Check that an S3 URL maps to the same local directories every time.

    ``maybe_pull_model_tokenizer_for_runai`` is run on two fresh
    ``MockConfig`` objects built from the same URL. Both must end up
    pointing at existing local directories, and at the same ones.
    """
    # Stub out pull_files so nothing is actually downloaded.
    mock_pull_files.return_value = None

    # First run on a fresh mock.
    config1 = MockConfig(model=s3_url, tokenizer=s3_url)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url, s3_url)

    # Model and tokenizer must now name existing directories ...
    assert os.path.exists(config1.model), (
        f"Model directory does not exist: {config1.model}")
    assert os.path.isdir(config1.model), (
        f"Model path is not a directory: {config1.model}")
    assert os.path.exists(config1.tokenizer), (
        f"Tokenizer directory does not exist: {config1.tokenizer}")
    assert os.path.isdir(config1.tokenizer), (
        f"Tokenizer path is not a directory: {config1.tokenizer}")

    # ... and must no longer be the original S3 URL.
    assert config1.model != s3_url, (
        "Model path should be converted to local directory")
    assert config1.tokenizer != s3_url, (
        "Tokenizer path should be converted to local directory")

    # Remember where the first run landed.
    created_model_dir = config1.model
    created_tokenizer_dir = config1.tokenizer

    # Second run on another fresh mock with the same S3 URL.
    config2 = MockConfig(model=s3_url, tokenizer=s3_url)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url, s3_url)

    assert os.path.exists(config2.model), (
        f"Model directory does not exist: {config2.model}")
    assert os.path.isdir(config2.model), (
        f"Model path is not a directory: {config2.model}")
    assert os.path.exists(config2.tokenizer), (
        f"Tokenizer directory does not exist: {config2.tokenizer}")
    assert os.path.isdir(config2.tokenizer), (
        f"Tokenizer path is not a directory: {config2.tokenizer}")

    # The second run must resolve to the same paths as the first.
    assert config2.model == created_model_dir, (
        f"Model paths are not deterministic. "
        f"Original: {created_model_dir}, New: {config2.model}")
    assert config2.tokenizer == created_tokenizer_dir, (
        f"Tokenizer paths are not deterministic. "
        f"Original: {created_tokenizer_dir}, New: {config2.tokenizer}")
|
||||
|
||||
|
||||
@patch('vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files')
def test_s3_url_different_models_create_different_directories(mock_pull_files):
    """Check that two distinct S3 URLs resolve to distinct local directories.

    Each URL is run through ``maybe_pull_model_tokenizer_for_runai`` on its
    own ``MockConfig``; the resulting model and tokenizer paths must differ
    between the two configs and must all be existing directories.
    """
    # Stub out pull_files so nothing is actually downloaded.
    mock_pull_files.return_value = None

    s3_url1 = "s3://example-bucket-1/model/"
    s3_url2 = "s3://example-bucket-2/model/"

    # One mock per URL, each processed independently.
    config1 = MockConfig(model=s3_url1, tokenizer=s3_url1)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url1, s3_url1)

    config2 = MockConfig(model=s3_url2, tokenizer=s3_url2)
    ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url2, s3_url2)

    # Different URLs must not collapse onto the same directory.
    assert config1.model != config2.model, (
        f"Different S3 URLs should create different model directories. "
        f"URL1 model: {config1.model}, URL2 model: {config2.model}")
    assert config1.tokenizer != config2.tokenizer, (
        f"Different S3 URLs should create different tokenizer directories. "
        f"URL1 tokenizer: {config1.tokenizer}, "
        f"URL2 tokenizer: {config2.tokenizer}")

    # Every resulting path must exist and be a directory.
    resolved_dirs = (
        config1.model,
        config1.tokenizer,
        config2.model,
        config2.tokenizer,
    )
    for local_dir in resolved_dirs:
        assert os.path.exists(local_dir) and os.path.isdir(local_dir)
|
||||
|
||||
54
tests/tool_use/test_deepseekv31_tool_parser.py
Normal file
54
tests/tool_use/test_deepseekv31_tool_parser.py
Normal file
@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.tool_parsers import DeepSeekV31ToolParser
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
MODEL = "deepseek-ai/DeepSeek-V3.1"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def deepseekv31_tokenizer():
    """Module-scoped fixture returning the tokenizer named by ``MODEL``."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
def parser(deepseekv31_tokenizer):
    """Fixture building a ``DeepSeekV31ToolParser`` on the shared tokenizer."""
    tool_parser = DeepSeekV31ToolParser(deepseekv31_tokenizer)
    return tool_parser
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_tool(parser):
    """A single tool call after plain text is parsed into one call.

    The text before the tool-calls section is expected back as content.
    """
    pieces = [
        "normal text",
        "<|tool▁calls▁begin|>",
        "<|tool▁call▁begin|>foo<|tool▁sep|>{\"x\":1}<|tool▁call▁end|>",
        "<|tool▁calls▁end|>",
    ]
    model_output = "".join(pieces)

    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called
    assert len(result.tool_calls) == 1
    only_call = result.tool_calls[0].function
    assert only_call.name == "foo"
    assert only_call.arguments == "{\"x\":1}"
    assert result.content == "normal text"
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_multiple_tools(parser):
    """Two tool calls in one section are parsed in order.

    Only the text before the tool-calls section is expected as content.
    """
    pieces = [
        "some prefix text",
        "<|tool▁calls▁begin|>",
        "<|tool▁call▁begin|>foo<|tool▁sep|>{\"x\":1}<|tool▁call▁end|>",
        "<|tool▁call▁begin|>bar<|tool▁sep|>{\"y\":2}<|tool▁call▁end|>",
        "<|tool▁calls▁end|>",
        " some suffix text",
    ]
    model_output = "".join(pieces)

    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called
    assert len(result.tool_calls) == 2

    first_call = result.tool_calls[0].function
    assert first_call.name == "foo"
    assert first_call.arguments == "{\"x\":1}"

    second_call = result.tool_calls[1].function
    assert second_call.name == "bar"
    assert second_call.arguments == "{\"y\":2}"

    # prefix is content
    assert result.content == "some prefix text"
|
||||
@ -70,7 +70,12 @@ def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
|
||||
assert extracted_info.content == "This is a test"
|
||||
|
||||
|
||||
def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding):
|
||||
@pytest.mark.parametrize("tool_args", [
|
||||
'{"location": "Tokyo"}',
|
||||
'{\n"location": "Tokyo"\n}',
|
||||
])
|
||||
def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding,
|
||||
tool_args):
|
||||
convo = Conversation.from_messages([
|
||||
Message.from_role_and_content(Role.USER,
|
||||
"What is the weather in Tokyo?"),
|
||||
@ -80,7 +85,7 @@ def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding):
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||
tool_args).with_channel("commentary").with_recipient(
|
||||
"functions.get_current_weather").with_content_type("json"),
|
||||
])
|
||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||
@ -121,6 +126,17 @@ def test_extract_tool_calls_multiple_tools(
|
||||
Role.ASSISTANT,
|
||||
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||
"functions.get_user_location").with_content_type("json"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, '{"location": "Tokyo"}').with_channel(
|
||||
"commentary").with_recipient("functions.no_content_type"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, "foo").with_channel(
|
||||
"commentary").with_recipient("functions.not_json_no_content_type"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, '{}').with_channel("commentary").with_recipient(
|
||||
"functions.empty_args").with_content_type("json"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, '').with_channel("commentary").with_recipient(
|
||||
"functions.no_args").with_content_type("json"),
|
||||
])
|
||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||
convo,
|
||||
@ -141,7 +157,63 @@ def test_extract_tool_calls_multiple_tools(
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_user_location",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="no_content_type",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="not_json_no_content_type",
|
||||
arguments="foo",
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="empty_args",
|
||||
arguments=json.dumps({}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="no_args",
|
||||
arguments="",
|
||||
))
|
||||
]
|
||||
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||
assert extracted_info.content is None
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_content(
    openai_tool_parser,
    harmony_encoding,
):
    """A tool call followed by a final-channel message yields both.

    The rendered conversation holds one ``get_current_weather`` call and a
    message on the "final" channel; the parser is expected to return the
    tool call and that final text as content.
    """
    final_content = "This tool call will get the weather."

    user_msg = Message.from_role_and_content(
        Role.USER, "What is the weather in Tokyo based on where I'm at?")
    analysis_msg = Message.from_role_and_content(
        Role.ASSISTANT,
        'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.',  # noqa: E501
    ).with_channel("analysis")
    tool_call_msg = Message.from_role_and_content(
        Role.ASSISTANT,
        '{"location": "Tokyo"}').with_channel("commentary").with_recipient(
            "functions.get_current_weather").with_content_type("json")
    final_msg = Message.from_role_and_content(
        Role.ASSISTANT, final_content).with_channel("final")

    convo = Conversation.from_messages(
        [user_msg, analysis_msg, tool_call_msg, final_msg])
    token_ids = harmony_encoding.render_conversation_for_completion(
        convo,
        Role.ASSISTANT,
    )

    # The parser is driven by token ids here; the text argument is empty.
    extracted_info = openai_tool_parser.extract_tool_calls(
        "",
        request=None,
        token_ids=token_ids,
    )

    assert extracted_info.tools_called
    weather_call = ToolCall(function=FunctionCall(
        name="get_current_weather",
        arguments=json.dumps({"location": "Tokyo"}),
    ))
    expected_tool_calls = [weather_call]
    assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
    assert extracted_info.content == final_content
|
||||
|
||||
@ -31,7 +31,6 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch):
|
||||
def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
|
||||
return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
|
||||
max_model_len=256,
|
||||
max_seq_len_to_capture=256,
|
||||
max_num_seqs=8,
|
||||
tensor_parallel_size=tp,
|
||||
enable_lora=True,
|
||||
|
||||
@ -9,8 +9,9 @@ from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm.v1.attention.backends.utils import (UBatchSlice,
|
||||
_make_metadata_with_slice,
|
||||
slice_query_start_locs,
|
||||
split_attn_metadata)
|
||||
from vllm.v1.worker.ubatch_utils import create_ubatch_slices
|
||||
split_attn_metadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.worker.ubatch_splitting import create_ubatch_slices
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -158,6 +159,112 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
|
||||
assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
|
||||
|
||||
|
||||
def apply_split_decodes_and_prefills(query_lens: list[int],
                                     decode_threshold: int,
                                     require_uniform: bool):
    """Build CPU attention metadata for a batch and split it.

    Request ``i`` gets sequence length ``10 * (i + 1)`` and the given query
    length. The metadata is built with block size 16 and passed to
    ``split_decodes_and_prefills``, whose result is returned unchanged.
    """
    num_reqs = len(query_lens)
    seq_lens = [10 * idx for idx in range(1, num_reqs + 1)]
    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens)

    cpu = torch.device("cpu")
    common_metadata = create_common_attn_metadata(batch_spec,
                                                  block_size=16,
                                                  device=cpu)

    return split_decodes_and_prefills(common_metadata,
                                      decode_threshold=decode_threshold,
                                      require_uniform=require_uniform)
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_nonuniform_all_ones():
    """Non-uniform mode, threshold 1: three length-1 queries, all decodes."""
    lens = [1, 1, 1]
    split = apply_split_decodes_and_prefills(lens,
                                             decode_threshold=1,
                                             require_uniform=False)
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split

    assert num_decodes == 3
    assert num_prefills == 0
    assert num_decode_tokens == 3
    assert num_prefill_tokens == 0
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
    """Non-uniform mode, threshold 3: every query length is at most 3."""
    lens = [1, 2, 1, 3, 2, 1, 2]
    split = apply_split_decodes_and_prefills(lens,
                                             decode_threshold=3,
                                             require_uniform=False)
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split

    # All seven requests are expected to count as decodes.
    assert num_decodes == 7
    assert num_prefills == 0
    assert num_decode_tokens == sum(lens)
    assert num_prefill_tokens == 0
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_nonuniform_all_prefills():
    """Non-uniform mode, threshold 3: every query length is above 3."""
    lens = [4, 5, 6, 7]
    split = apply_split_decodes_and_prefills(lens,
                                             decode_threshold=3,
                                             require_uniform=False)
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split

    # All four requests are expected to count as prefills.
    assert num_decodes == 0
    assert num_prefills == 4
    assert num_decode_tokens == 0
    assert num_prefill_tokens == sum(lens)
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_nonuniform_mixed_batch():
    """Non-uniform mode, threshold 4: four short queries then four long."""
    lens = [2, 1, 3, 4, 5, 6, 7, 8]
    split = apply_split_decodes_and_prefills(lens,
                                             decode_threshold=4,
                                             require_uniform=False)
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split

    # 2, 1, 3, 4 are all <= 4
    assert num_decodes == 4
    # 5, 6, 7, 8 are all > 4
    assert num_prefills == 4
    # 2 + 1 + 3 + 4
    assert num_decode_tokens == 10
    # 5 + 6 + 7 + 8
    assert num_prefill_tokens == 26
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_uniform_all_ones():
    """Uniform mode, threshold 1: three length-1 queries, all decodes."""
    lens = [1, 1, 1]
    split = apply_split_decodes_and_prefills(lens,
                                             decode_threshold=1,
                                             require_uniform=True)
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split

    assert num_decodes == 3
    assert num_prefills == 0
    assert num_decode_tokens == 3
    assert num_prefill_tokens == 0
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_uniform_all_short_decodes():
    """Uniform mode, threshold 3: only the two leading 2s count as decodes.

    Every query length is within the threshold, but the expectations
    below count only the first two requests as decodes and the remaining
    five as prefills.
    """
    lens = [2, 2, 1, 3, 2, 1, 2]
    split = apply_split_decodes_and_prefills(lens,
                                             decode_threshold=3,
                                             require_uniform=True)
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split

    assert num_decodes == 2
    assert num_prefills == 5
    assert num_decode_tokens == 4
    # 1 + 3 + 2 + 1 + 2: everything after the two leading 2s.
    assert num_prefill_tokens == sum(lens[2:])
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_uniform_all_prefills():
    """Uniform mode, threshold 3: every query length is above 3."""
    lens = [4, 5, 6, 7]
    split = apply_split_decodes_and_prefills(lens,
                                             decode_threshold=3,
                                             require_uniform=True)
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split

    # All four requests are expected to count as prefills.
    assert num_decodes == 0
    assert num_prefills == 4
    assert num_decode_tokens == 0
    assert num_prefill_tokens == sum(lens)
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
    """Uniform mode, threshold 4: the leading run of 2s is the decode set.

    The expectations below count the three leading length-2 requests as
    decodes and everything from the length-4 request onward as prefills.
    """
    query_lens = [2, 2, 2, 4, 5, 6, 7, 8]
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        apply_split_decodes_and_prefills(query_lens, 4, True))
    assert num_decodes == 3  # 2, 2, 2 are all <= 4 and uniform
    # 4 is within the threshold but differs from the leading length 2
    # (non-uniform); 5, 6, 7, 8 are all > 4
    assert num_prefills == 5
    assert num_decode_tokens == 6  # 2 + 2 + 2
    assert num_prefill_tokens == 30  # 4 + 5 + 6 + 7 + 8
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
    """Uniform mode, threshold 4: a leading 2 followed directly by a 1.

    The expectations below count only the first request as a decode.
    """
    lens = [2, 1, 2, 4, 5, 6, 7, 8]
    split = apply_split_decodes_and_prefills(lens,
                                             decode_threshold=4,
                                             require_uniform=True)
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split

    # only the first 2 is taken as decode
    assert num_decodes == 1
    # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform
    assert num_prefills == 7
    # only the first 2
    assert num_decode_tokens == 2
    # rest of the tokens
    assert num_prefill_tokens == (sum(lens) - 2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
|
||||
[
|
||||
|
||||
@ -14,10 +14,11 @@ from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||
MultiModalKwargsItem, PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256, sha256_cbor
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
get_block_hash, get_group_id,
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
KVCacheBlock, get_block_hash,
|
||||
get_group_id,
|
||||
get_request_block_hasher,
|
||||
hash_block_tokens, init_none_hash,
|
||||
make_block_hash_with_group_id)
|
||||
@ -138,7 +139,7 @@ def test_prefill(hash_fn):
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
# Check full block metadata
|
||||
parent_block_hash = None
|
||||
@ -171,7 +172,7 @@ def test_prefill(hash_fn):
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([5], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([5], )
|
||||
for block in computed_blocks.blocks[0]:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
@ -207,7 +208,7 @@ def test_prefill(hash_fn):
|
||||
blocks = manager.allocate_slots(req2, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([6], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([6], )
|
||||
|
||||
# Although we only have 6 free blocks, we have 8 blocks in
|
||||
# the free block queue due to lazy removal.
|
||||
@ -227,7 +228,9 @@ def test_prefill(hash_fn):
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
# This block ID order also checks the eviction order.
|
||||
assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([
|
||||
7, 8, 9, 10, 4, 5, 6, 3, 2, 1
|
||||
], )
|
||||
|
||||
assert free_block_queue.num_free_blocks == 0
|
||||
assert (free_block_queue.fake_free_list_head.next_free_block
|
||||
@ -261,8 +264,9 @@ def test_prefill_hybrid_model():
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
|
||||
8], [9, 10, 11, 12])
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], [
|
||||
5, 6, 7, 8
|
||||
], [9, 10, 11, 12])
|
||||
|
||||
# Check full block metadata
|
||||
parent_block_hash = None
|
||||
@ -298,7 +302,7 @@ def test_prefill_hybrid_model():
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([13], [14], [15])
|
||||
assert blocks is not None and blocks.get_block_ids() == ([13], [14], [15])
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
for block in block_per_group:
|
||||
if block != manager.block_pool.null_block:
|
||||
@ -309,14 +313,15 @@ def test_prefill_hybrid_model():
|
||||
manager.free(req1)
|
||||
|
||||
cached_block_hash_to_block_bak = copy.copy(
|
||||
manager.block_pool.cached_block_hash_to_block)
|
||||
manager.block_pool.cached_block_hash_to_block._cache)
|
||||
|
||||
def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes],
|
||||
def test_partial_request_hit(request_id: str,
|
||||
hash_to_evict: list[BlockHashWithGroupId],
|
||||
expect_hit_length: int):
|
||||
req = make_request(request_id, common_token_ids + unique_token_ids,
|
||||
block_size, sha256)
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block.pop(
|
||||
manager.block_pool.cached_block_hash_to_block._cache.pop(
|
||||
hash_with_group_id)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert len(req.block_hashes) == 3
|
||||
@ -324,7 +329,7 @@ def test_prefill_hybrid_model():
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
assert len(block_per_group) == num_computed_tokens // block_size
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block[
|
||||
manager.block_pool.cached_block_hash_to_block._cache[
|
||||
hash_with_group_id] = cached_block_hash_to_block_bak[
|
||||
hash_with_group_id]
|
||||
manager.free(req)
|
||||
@ -362,7 +367,8 @@ def test_prefill_hybrid_model():
|
||||
# total cache miss.
|
||||
# The cache hit length of full attention is 1 * block_size.
|
||||
# The cache hit length of sliding window is 2 * block_size.
|
||||
# Then it is cache miss as the two type of layers have different hit length.
|
||||
# Then it is cache miss as the two type of layers
|
||||
# have different hit length.
|
||||
test_partial_request_hit("8", [
|
||||
make_block_hash_with_group_id(block_hashes[2], 0),
|
||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||
@ -406,7 +412,7 @@ def test_prefill_plp():
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
|
||||
|
||||
# Check full block metadata
|
||||
@ -441,7 +447,7 @@ def test_prefill_plp():
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([5], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([5], )
|
||||
for block in computed_blocks.blocks[0]:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
@ -478,6 +484,7 @@ def test_prefill_plp():
|
||||
blocks = manager.allocate_slots(req2, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks is not None
|
||||
block_ids = blocks.get_block_ids()
|
||||
# Duplicate cached blocks have different ids but same hashes vs request #0
|
||||
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
|
||||
@ -513,7 +520,7 @@ def test_decode():
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
# Append slots without allocating a new block.
|
||||
req0.num_computed_tokens = 55
|
||||
@ -558,7 +565,8 @@ def test_evict():
|
||||
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial
|
||||
# 5 full + 1 partial
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 6
|
||||
|
||||
# 3 blocks.
|
||||
req1 = make_request("1", list(range(last_token_id,
|
||||
@ -570,7 +578,7 @@ def test_evict():
|
||||
blocks = manager.allocate_slots(req1, 3 * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 3 # 3 full blocks
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 3 # 3 full blocks
|
||||
last_token_id += 3 * 16
|
||||
|
||||
# 10 - (6 + 3) == 1
|
||||
@ -592,7 +600,7 @@ def test_evict():
|
||||
blocks = manager.allocate_slots(req2, 3,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([10], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([10], )
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 7
|
||||
|
||||
|
||||
@ -617,7 +625,7 @@ def test_hash_block_correct_reuse():
|
||||
blocks = manager.allocate_slots(req, num_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
|
||||
# Deallocate the block.
|
||||
manager.free(req)
|
||||
@ -631,7 +639,7 @@ def test_hash_block_correct_reuse():
|
||||
blocks = manager.allocate_slots(req, num_tokens - 1,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
|
||||
assert manager.block_pool.blocks[blocks.blocks[0]
|
||||
[0].block_id].block_hash is None
|
||||
@ -658,7 +666,7 @@ def test_computed_blocks_not_evicted():
|
||||
blocks = manager.allocate_slots(req0, num_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
assert blocks.blocks[0][0].block_id == 1
|
||||
|
||||
# Allocate another block.
|
||||
@ -670,7 +678,7 @@ def test_computed_blocks_not_evicted():
|
||||
blocks = manager.allocate_slots(req1, num_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
assert blocks.blocks[0][0].block_id == 2
|
||||
|
||||
# Free the blocks.
|
||||
@ -688,7 +696,7 @@ def test_computed_blocks_not_evicted():
|
||||
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
assert blocks.blocks[0][0].block_id == 2
|
||||
|
||||
|
||||
@ -712,7 +720,7 @@ def test_basic_prefix_caching_disabled():
|
||||
blocks = manager.allocate_slots(req1, 10,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 3
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 3
|
||||
|
||||
# Free the blocks.
|
||||
manager.free(req1)
|
||||
@ -726,7 +734,7 @@ def test_basic_prefix_caching_disabled():
|
||||
blocks = manager.allocate_slots(req2, 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 4
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 4
|
||||
|
||||
# New requests should not have any blocks.
|
||||
req3 = make_request("3", list(range(4)), block_size, sha256)
|
||||
@ -773,7 +781,8 @@ def test_cache_blocks(hash_fn):
|
||||
assert len(block_pool.cached_block_hash_to_block) == 2
|
||||
assert all([block.block_hash is not None for block in blocks])
|
||||
|
||||
# Test that blocks that don't start from the beginning are cached correctly.
|
||||
# Test that blocks that don't start from the beginning are cached
|
||||
# correctly.
|
||||
blocks += [KVCacheBlock(block_id=2)]
|
||||
block_pool.cache_full_blocks(
|
||||
request=req,
|
||||
@ -1101,7 +1110,7 @@ def test_reset_prefix_cache():
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids, block_size, sha256)
|
||||
blocks = manager.allocate_slots(req0, 55)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
unique_token_ids = [4] * 7
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
@ -1112,7 +1121,7 @@ def test_reset_prefix_cache():
|
||||
blocks = manager.allocate_slots(req1, 7,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([5], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([5], )
|
||||
|
||||
# Failed to reset prefix cache because some blocks are not freed yet.
|
||||
assert not manager.reset_prefix_cache()
|
||||
@ -1168,49 +1177,41 @@ def test_maybe_evict_cached_block():
|
||||
# Manually add all blocks to cached_blocks
|
||||
for block, block_hash in zip(pool.blocks, block_hashes):
|
||||
block.block_hash = block_hash
|
||||
pool.cached_block_hash_to_block[block_hash][block.block_id] = block
|
||||
pool.cached_block_hash_to_block.insert(block_hash, block)
|
||||
|
||||
block0, block1, block2, block3 = pool.blocks
|
||||
assert pool.cached_block_hash_to_block == {
|
||||
assert pool.cached_block_hash_to_block._cache == {
|
||||
block_hash0: {
|
||||
block0.block_id: block0,
|
||||
block3.block_id: block3
|
||||
block3.block_id: block3,
|
||||
},
|
||||
block_hash1: {
|
||||
block1.block_id: block1
|
||||
},
|
||||
block_hash2: {
|
||||
block2.block_id: block2
|
||||
}
|
||||
block_hash1: block1,
|
||||
block_hash2: block2,
|
||||
}
|
||||
# Evict block1
|
||||
pool._maybe_evict_cached_block(block1)
|
||||
assert pool.cached_block_hash_to_block == {
|
||||
assert pool.cached_block_hash_to_block._cache == {
|
||||
block_hash0: {
|
||||
block0.block_id: block0,
|
||||
block3.block_id: block3
|
||||
},
|
||||
block_hash2: {
|
||||
block2.block_id: block2
|
||||
}
|
||||
block_hash2: block2,
|
||||
}
|
||||
# Evict block0: block_hash0 entry should NOT be removed, as block3
|
||||
# also use the same hash
|
||||
pool._maybe_evict_cached_block(block0)
|
||||
assert pool.cached_block_hash_to_block == {
|
||||
assert pool.cached_block_hash_to_block._cache == {
|
||||
block_hash0: {
|
||||
block3.block_id: block3
|
||||
},
|
||||
block_hash2: {
|
||||
block2.block_id: block2
|
||||
}
|
||||
block_hash2: block2,
|
||||
}
|
||||
# Evict block2
|
||||
pool._maybe_evict_cached_block(block2)
|
||||
assert pool.cached_block_hash_to_block == {block_hash0: {3: block3}}
|
||||
assert pool.cached_block_hash_to_block._cache == {block_hash0: {3: block3}}
|
||||
# Evict block3
|
||||
pool._maybe_evict_cached_block(block3)
|
||||
assert pool.cached_block_hash_to_block == {}
|
||||
assert pool.cached_block_hash_to_block._cache == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
|
||||
@ -1374,7 +1375,7 @@ def test_eagle_with_sliding_window():
|
||||
# Evict the first block in the request
|
||||
assert manager.block_pool.get_cached_block(
|
||||
block_hash_first_block, kv_cache_group_ids=[0]) is not None
|
||||
manager.block_pool.cached_block_hash_to_block.pop(
|
||||
manager.block_pool.cached_block_hash_to_block._cache.pop(
|
||||
make_block_hash_with_group_id(block_hash_first_block, 0))
|
||||
|
||||
# New request
|
||||
@ -1386,3 +1387,78 @@ def test_eagle_with_sliding_window():
|
||||
# there will be no matched prefix.
|
||||
assert len(computed_blocks.blocks[0]) == 0
|
||||
assert num_tokens == 0
|
||||
|
||||
|
||||
def test_block_lookup_cache_single_block_per_key():
    """Check insert / get_one_block / pop with at most one block per key."""
    cache = BlockHashToBlockMap()
    keys = [BlockHashWithGroupId(h) for h in (b"hash0", b"hash1", b"hash2")]
    key0, key1, _ = keys
    block0, block1 = KVCacheBlock(0), KVCacheBlock(1)

    def expect_lookups(*expected):
        # expected[i] is what get_one_block must return for keys[i].
        for key, block in zip(keys, expected):
            assert cache.get_one_block(key) is block

    # Nothing has been inserted yet.
    expect_lookups(None, None, None)

    # key0 inserted
    cache.insert(key0, block0)
    expect_lookups(block0, None, None)

    # key1 inserted
    cache.insert(key1, block1)
    expect_lookups(block0, block1, None)

    # No block popped due to block_id mismatch
    assert cache.pop(key0, 100) is None
    expect_lookups(block0, block1, None)

    # Block popped with (key0, block ID 0)
    assert cache.pop(key0, 0) is block0
    expect_lookups(None, block1, None)

    # No block popped due to block_id mismatch
    assert cache.pop(key0, 1) is None
    expect_lookups(None, block1, None)

    # Block popped with (key1, block ID 1)
    assert cache.pop(key1, 1) is block1
    expect_lookups(None, None, None)
|
||||
|
||||
|
||||
def test_block_lookup_cache_multi_blocks_per_key():
    """Check lookup and pop order when several blocks share one key."""
    cache = BlockHashToBlockMap()
    key0 = BlockHashWithGroupId(b"hash0")
    key1 = BlockHashWithGroupId(b"hash1")
    # Each key is paired with the blocks inserted under it, in insertion
    # order.
    groups = [
        (key0, [KVCacheBlock(0), KVCacheBlock(1)]),
        (key1, [KVCacheBlock(10), KVCacheBlock(11)]),
    ]

    for key, _ in groups:
        assert cache.get_one_block(key) is None

    for key, blocks in groups:
        for block in blocks:
            cache.insert(key, block)

    for key, blocks in groups:
        # Blocks are expected back in the order they were inserted.
        for block in blocks:
            assert cache.get_one_block(key) is block
            assert cache.pop(key, block.block_id) is block
        # The key is now empty; popping an ID that was never inserted
        # (2 for key0, 12 for key1) yields nothing.
        assert cache.get_one_block(key) is None
        assert cache.pop(key, blocks[-1].block_id + 1) is None
|
||||
|
||||
@ -47,16 +47,15 @@ def test_chunked_local_attention_possible_cached_prefix():
|
||||
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
|
||||
]
|
||||
|
||||
block_pool.cached_block_hash_to_block.clear()
|
||||
block_pool.cached_block_hash_to_block._cache.clear()
|
||||
|
||||
# Mock the block pool with the cached blocks
|
||||
for i, (block_hash,
|
||||
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
||||
if is_cached:
|
||||
block_pool.cached_block_hash_to_block[
|
||||
make_block_hash_with_group_id(block_hash, 0)] = {
|
||||
i: block_pool.blocks[i + 10],
|
||||
}
|
||||
block_pool.cached_block_hash_to_block.insert(
|
||||
make_block_hash_with_group_id(block_hash, 0),
|
||||
block_pool.blocks[i + 10])
|
||||
|
||||
computed_blocks = manager.find_longest_cache_hit(
|
||||
block_hashes=block_hash_list,
|
||||
@ -112,16 +111,15 @@ def test_sliding_window_possible_cached_prefix():
|
||||
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
|
||||
]
|
||||
|
||||
block_pool.cached_block_hash_to_block.clear()
|
||||
block_pool.cached_block_hash_to_block._cache.clear()
|
||||
|
||||
# Mock the block pool with the cached blocks
|
||||
for i, (block_hash,
|
||||
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
||||
if is_cached:
|
||||
block_pool.cached_block_hash_to_block[
|
||||
make_block_hash_with_group_id(block_hash, 0)] = {
|
||||
i: block_pool.blocks[i + 10],
|
||||
}
|
||||
block_pool.cached_block_hash_to_block.insert(
|
||||
make_block_hash_with_group_id(block_hash, 0),
|
||||
block_pool.blocks[i + 10])
|
||||
|
||||
computed_blocks = manager.find_longest_cache_hit(
|
||||
block_hashes=block_hash_list,
|
||||
|
||||
@ -81,16 +81,6 @@ class CarDescription(BaseModel):
|
||||
car_type: CarType
|
||||
|
||||
|
||||
def _load_json(s: str, backend: str) -> str:
|
||||
if backend != "xgrammar":
|
||||
return json.loads(s)
|
||||
|
||||
# xgrammar specific workarounds
|
||||
# https://github.com/mlc-ai/xgrammar/issues/286
|
||||
s = re.sub(r'[\x00-\x1F\x7F-\xFF]', '', s)
|
||||
return json.loads(s)
|
||||
|
||||
|
||||
def test_guided_decoding_deprecated():
|
||||
with pytest.warns(DeprecationWarning,
|
||||
match="GuidedDecodingParams is deprecated.*"):
|
||||
@ -177,7 +167,12 @@ def test_structured_output(
|
||||
if backend != 'lm-format-enforcer':
|
||||
assert "\n" not in generated_text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
output_json = json.loads(generated_text)
|
||||
try:
|
||||
output_json = json.loads(generated_text)
|
||||
except json.JSONDecodeError as e:
|
||||
pytest.fail(
|
||||
f"Invalid JSON from backend={backend}: {generated_text!r}\n"
|
||||
f"Schema: {sample_json_schema}\nError: {e}")
|
||||
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
||||
|
||||
#
|
||||
@ -425,7 +420,12 @@ def test_structured_output(
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
output_json = json.loads(generated_text)
|
||||
try:
|
||||
output_json = json.loads(generated_text)
|
||||
except json.JSONDecodeError as e:
|
||||
pytest.fail(
|
||||
f"Invalid JSON from backend={backend}: {generated_text!r}\n"
|
||||
f"Schema: {json_schema}\nError: {e}")
|
||||
jsonschema.validate(instance=output_json, schema=json_schema)
|
||||
|
||||
#
|
||||
@ -468,7 +468,12 @@ def test_structured_output(
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
output_json = json.loads(generated_text)
|
||||
try:
|
||||
output_json = json.loads(generated_text)
|
||||
except json.JSONDecodeError as e:
|
||||
pytest.fail(
|
||||
f"Invalid JSON from backend={backend}: {generated_text!r}\n"
|
||||
f"Schema: {json_schema}\nError: {e}")
|
||||
jsonschema.validate(instance=output_json, schema=json_schema)
|
||||
|
||||
if backend not in ["outlines", "lm-format-enforcer"]:
|
||||
|
||||
@ -48,13 +48,9 @@ def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
|
||||
|
||||
prompts = [
|
||||
"A robot may not injure a human being",
|
||||
"It is only with the heart that one can see rightly;",
|
||||
"The greatest glory in living lies not in never falling,",
|
||||
]
|
||||
answers = [
|
||||
"or, being injured, not kill, except in",
|
||||
"without the heart, one can only see wrongly.",
|
||||
"but in rising every time we fall. - Nelson"
|
||||
"or kill a human being",
|
||||
]
|
||||
|
||||
with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:
|
||||
|
||||
174
tests/v1/worker/test_worker_memory_snapshot.py
Normal file
174
tests/v1/worker/test_worker_memory_snapshot.py
Normal file
@ -0,0 +1,174 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import tempfile
|
||||
from multiprocessing import Queue
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils import MemorySnapshot
|
||||
from vllm.v1.worker.gpu_worker import (Worker,
|
||||
init_worker_distributed_environment)
|
||||
|
||||
# Process-wide queue on which tracked operations are reported; it stays None
# until worker_process() installs the queue it was given.
_QUEUE: Optional[Queue] = None


def track_operation(operation: str, rank: int):
    """Report an ``(operation, rank)`` pair on the global queue.

    Does nothing when no queue has been installed.
    """
    if _QUEUE is None:
        return
    _QUEUE.put((operation, rank))
|
||||
|
||||
|
||||
def make_operation_tracker(operation_name: str, original_func):
    """Build a stand-in for ``original_func`` that records each call.

    Args:
        operation_name: Label reported through ``track_operation``.
        original_func: Callable that still performs the real work.

    Returns:
        A function that first reports ``operation_name`` together with the
        current rank and then forwards all arguments to ``original_func``,
        returning its result.
    """

    def tracked_call(*args, **kwargs):
        # The rank is read from the RANK environment variable at call time;
        # -1 is reported when it is unset.
        track_operation(operation_name, int(os.environ.get("RANK", "-1")))
        return original_func(*args, **kwargs)

    return tracked_call
|
||||
|
||||
|
||||
def worker_process(rank: int, world_size: int, distributed_init_method: str,
                   queue: Queue, error_queue: Queue):
    """Worker process that initializes a GPU worker with proper tracking.

    Args:
        rank: Rank of this process; also used as the local rank.
        world_size: Total number of worker processes. It is exported as
            WORLD_SIZE and used as the tensor-parallel size.
        distributed_init_method: Init method string handed to the worker.
        queue: Queue that receives ``(operation, rank)`` tuples and, on
            completion, a final ``("success", rank)`` entry.
        error_queue: Queue that receives ``(rank, message, exception type
            name)`` if anything in this process raises.

    Raises:
        Exception: Any failure is reported on ``error_queue`` and re-raised.
    """
    global _QUEUE
    _QUEUE = queue

    try:
        # Set environment variables
        os.environ["RANK"] = str(rank)
        os.environ["LOCAL_RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)

        # Create vLLM config with small model. The tensor-parallel size
        # follows world_size so the config always matches the number of
        # processes the caller actually spawned.
        vllm_config = EngineArgs(model="facebook/opt-125m",
                                 tensor_parallel_size=world_size,
                                 load_format="dummy").create_engine_config()

        # Create worker
        worker = Worker(
            vllm_config=vllm_config,
            local_rank=rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
        )

        # Get original functions before patching
        original_init_worker = init_worker_distributed_environment
        original_memory_snapshot_init = MemorySnapshot.__init__
        original_all_reduce = torch.distributed.all_reduce

        # Apply minimal patches to track operation order. Each patched
        # callable still invokes the original after recording the call.
        init_patch = patch(
            'vllm.v1.worker.gpu_worker.init_worker_distributed_environment',
            side_effect=make_operation_tracker("init_distributed",
                                               original_init_worker))
        memory_patch = patch.object(
            MemorySnapshot, '__init__',
            make_operation_tracker("memory_snapshot",
                                   original_memory_snapshot_init))
        all_reduce_patch = patch('torch.distributed.all_reduce',
                                 side_effect=make_operation_tracker(
                                     "nccl_all_reduce", original_all_reduce))

        with init_patch, memory_patch, all_reduce_patch:

            # Initialize device (this is where we test the order)
            worker.init_device()

            # Load model to ensure everything works
            worker.load_model()

        # Signal success
        queue.put(("success", rank))

    except Exception as e:
        error_queue.put((rank, str(e), type(e).__name__))
        raise
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs for tensor parallelism")
def test_init_distributed_is_called_before_memory_snapshot():
    """Test that distributed env is setup before memory snapshot.

    This test makes sure during worker initialization, the initial memory
    snapshot is taken after distributed env is setup to include all the buffers
    allocated by distributed env.

    One process per rank is spawned. Each reports the operations it performs
    on a shared queue, and the per-rank order of those reports is checked.
    """
    world_size = 2

    # Create a temporary file for distributed init. Only its path is needed
    # here; the file is removed in the ``finally`` block below.
    with tempfile.NamedTemporaryFile(delete=False) as f:
        init_file = f.name
    distributed_init_method = f"file://{init_file}"

    # Create queues for inter-process communication
    ctx = mp.get_context("spawn")
    operation_queue = ctx.Queue()
    error_queue = ctx.Queue()

    processes = []
    try:
        # Start worker processes
        for rank in range(world_size):
            p = ctx.Process(target=worker_process,
                            args=(rank, world_size, distributed_init_method,
                                  operation_queue, error_queue))
            p.start()
            processes.append(p)

        # Wait for all processes to complete
        for p in processes:
            p.join(timeout=60)  # 60 second timeout

        # Check for errors
        errors = []
        while not error_queue.empty():
            rank, error_msg, error_type = error_queue.get()
            errors.append(f"Rank {rank}: {error_type}: {error_msg}")

        if errors:
            pytest.fail("Worker processes failed:\n" + "\n".join(errors))

        # Collect all operations from the queue
        operations = []
        while not operation_queue.empty():
            operations.append(operation_queue.get())

        print(f"Collected operations: {operations}")

        # Check operations for each rank
        for rank in range(world_size):
            rank_ops = [op for op, r in operations if r == rank]
            print(f"\nRank {rank} operations: {rank_ops}")

            # Raises ValueError if the operation is not found
            init_distributed = rank_ops.index("init_distributed")
            nccl_all_reduce = rank_ops.index("nccl_all_reduce")
            memory_snapshot = rank_ops.index("memory_snapshot")

            # Verify order: init_distributed, then the first all-reduce,
            # then the memory snapshot.
            assert init_distributed < nccl_all_reduce < memory_snapshot, (
                f"Rank {rank}: init_distributed (index {init_distributed}) "
                f"must happen before nccl_all_reduce (index {nccl_all_reduce}) "
                f"and memory_snapshot (index {memory_snapshot})")
    finally:
        # Clean up even when the test fails. A worker still alive after the
        # join timeout is terminated so it cannot outlive the test.
        for p in processes:
            if p.is_alive():
                p.terminate()
                p.join()
        # Tolerate the file already being gone so that cleanup never masks
        # the real test failure.
        try:
            os.unlink(init_file)
        except FileNotFoundError:
            pass
|
||||
@ -1,314 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.utils import async_tensor_h2d
|
||||
|
||||
# Placeholder attention backend for models like Mamba and pooling models that
|
||||
# lack attention.
|
||||
|
||||
|
||||
class PlaceholderAttentionBackend(AttentionBackend):
    """Stand-in backend for models that lack attention.

    The accessors return the placeholder implementation, metadata and
    builder classes; the KV-cache operations do nothing.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "NO_ATTENTION"

    @staticmethod
    def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
        """Return the attention implementation class."""
        return PlaceholderAttentionImpl

    @staticmethod
    def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return PlaceholderAttentionMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
        """Return the metadata class."""
        return PlaceholderAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return a fixed five-dimensional shape of ones.

        All arguments are ignored.
        """
        return (1, ) * 5

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Do nothing."""
        return None

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Do nothing."""
        return None
|
||||
|
||||
|
||||
@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
    """Metadata for a batch that holds prefill and decode requests together.

    ``prefill_metadata`` and ``decode_metadata`` derive (and cache) a view
    restricted to the prefill or decode part of the batch.
    """
    # (batch_size,). Per-sequence length, i.e. computed tokens + new tokens.
    # None if it is a decoding.
    seq_lens: Optional[List[int]]
    # The same lengths, stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Longest sequence among the prefill requests; 0 when the batch holds
    # decoding requests only.
    max_prefill_seq_len: int
    # Longest sequence among the decode requests; 0 when the batch holds
    # prefill requests only.
    max_decode_seq_len: int
    # (batch_size,) Context lengths (tokens that are computed so far).
    context_lens_tensor: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Maximum query length in the batch.
    max_query_len: Optional[int]

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int]

    # (batch_size + 1,). Cumulative subquery lengths of the sequences in the
    # batch, used to index into subquery. E.g., subquery lengths [4, 6] give
    # [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). Cumulative sequence lengths of the sequences in the
    # batch, used to index into sequence. E.g., sequence lengths [4, 6] give
    # [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # Placeholder.
    block_tables: Optional[torch.Tensor] = None

    # Results of the two properties below, cached after first access.
    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Metadata for the prefill requests, or None if there are none."""
        num_prefills = self.num_prefills
        if num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is not None:
            return cached

        def leading(values, count):
            # Keep the first ``count`` entries; fields left as None stay None.
            return None if values is None else values[:count]

        # The *_start_loc fields carry one more entry than there are requests.
        query_start_loc = leading(self.query_start_loc, num_prefills + 1)
        seq_lens = leading(self.seq_lens, num_prefills)
        seq_lens_tensor = leading(self.seq_lens_tensor, num_prefills)
        seq_start_loc = leading(self.seq_start_loc, num_prefills + 1)
        context_lens_tensor = leading(self.context_lens_tensor, num_prefills)

        cached = PlaceholderAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            # Placeholder
            slot_mapping=torch.empty(0),
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_query_len=0,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            # Placeholder
            block_tables=torch.empty(0),
            use_cuda_graph=False,
        )
        self._cached_prefill_metadata = cached
        return cached

    @property
    def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Metadata for the decode requests, or None if there are none."""
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is not None:
            return cached
        assert self.seq_lens_tensor is not None

        # Decode entries are the ones that follow the prefill entries.
        first_decode = self.num_prefills

        seq_lens_tensor = self.seq_lens_tensor
        if seq_lens_tensor is not None:
            seq_lens_tensor = seq_lens_tensor[first_decode:]

        if self.query_start_loc is None:
            query_start_loc = None
        else:
            # Shift the offsets so the first decode request starts at 0.
            query_start_loc = (self.query_start_loc[first_decode:] -
                               self.query_start_loc[first_decode])

        if self.seq_start_loc is None:
            seq_start_loc = None
        else:
            seq_start_loc = self.seq_start_loc[first_decode:]

        cached = PlaceholderAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            # Placeholder
            slot_mapping=torch.empty(0),
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=None,
            # Placeholder
            block_tables=torch.empty(0),
            use_cuda_graph=self.use_cuda_graph,
        )
        self._cached_decode_metadata = cached
        return cached
|
||||
|
||||
|
||||
class PlaceholderAttentionMetadataBuilder(
        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
    """Builds ``PlaceholderAttentionMetadata`` from an input builder."""

    def __init__(self, input_builder):
        """Remember the input builder and the runner it belongs to."""
        self.runner = input_builder.runner
        self.input_builder = input_builder

    def prepare(self):
        """Reset the per-batch counters and length lists."""
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.curr_seq_lens: List[int] = []

    def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool):
        """Fold one sequence group into the running batch statistics.

        Every sequence contributes its context length. Prompt sequences also
        update the prefill counters, all others the decode counters.
        ``chunked_prefill_enabled`` is accepted but not used.
        """
        is_prompt = inter_data.is_prompt

        # seq_ids and curr_sliding_window_blocks are not read, but they stay
        # in the zip because zip stops at the shortest input.
        per_sequence = zip(
            inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
            inter_data.orig_seq_lens, inter_data.seq_lens,
            inter_data.query_lens, inter_data.context_lens,
            inter_data.curr_sliding_window_blocks)
        for (_seq_id, token_len, seq_len, curr_seq_len, query_len,
             context_len, _sliding_window_block) in per_sequence:
            self.context_lens.append(context_len)

            if not is_prompt:
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)
                continue

            self.num_prefills += 1
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        # Some input builders such as ModelInputForCPUBuilder do not have
        # the "inter_data_list" attribute, so fall back to an empty sequence.
        input_builder = self.input_builder
        for inter_data in getattr(input_builder, "inter_data_list", ()):
            self._add_seq_group(inter_data,
                                input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        # Queries that follow the prefills belong to decode requests.
        max_decode_query_len = max(query_lens[self.num_prefills:], default=1)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        if use_captured_graph:
            num_decode_tokens = batch_size - self.num_prefill_tokens
        else:
            num_decode_tokens = self.num_decode_tokens
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        assert device is not None

        def to_device(values, dtype):
            # Shared host-to-device transfer settings for all four tensors.
            return async_tensor_h2d(values, dtype, device,
                                    self.runner.pin_memory)

        context_lens_tensor = to_device(self.context_lens, torch.int)
        seq_lens_tensor = to_device(seq_lens, torch.int)
        query_start_loc_tensor = to_device(query_start_loc, torch.int32)
        seq_start_loc_tensor = to_device(seq_start_loc, torch.int32)

        return PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            # Placeholder
            slot_mapping=torch.empty(0),
            enable_kv_scales_calculation=True,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            # Placeholder
            block_tables=torch.empty(0),
            use_cuda_graph=use_captured_graph,
        )
|
||||
|
||||
|
||||
class PlaceholderAttentionImpl(AttentionImpl):
    """Attention implementation stub: accepts any arguments, computes nothing."""

    def __init__(self, *args, **kwargs) -> None:
        # All constructor arguments are accepted and discarded.
        pass

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError
|
||||
@ -304,7 +304,7 @@ class CommonAttentionState(AttentionState):
|
||||
max_query_len=1,
|
||||
max_decode_query_len=1,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.runner.max_seq_len_to_capture,
|
||||
max_decode_seq_len=self.runner.max_model_len,
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
@ -390,7 +390,7 @@ class CommonAttentionState(AttentionState):
|
||||
dtype=torch.int).cuda()
|
||||
attn_metadata.encoder_seq_lens_tensor = torch.full(
|
||||
(batch_size, ), 1, dtype=torch.int).cuda()
|
||||
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
|
||||
attn_metadata.max_encoder_seq_len = self.runner.max_model_len
|
||||
attn_metadata.num_encoder_tokens = 0
|
||||
|
||||
def _add_additional_input_buffers_for_enc_dec_model(
|
||||
|
||||
@ -115,12 +115,10 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
is_attention_free = cache_config.is_attention_free
|
||||
calculate_kv_scales = cache_config.calculate_kv_scales
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
is_attention_free = False
|
||||
calculate_kv_scales = False
|
||||
if num_kv_heads is None:
|
||||
num_kv_heads = num_heads
|
||||
@ -185,7 +183,6 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
is_attention_free,
|
||||
use_mla=use_mla,
|
||||
has_sink=self.has_sink)
|
||||
else:
|
||||
@ -578,9 +575,7 @@ def unified_attention_fake(
|
||||
direct_register_custom_op(
|
||||
op_name="unified_attention",
|
||||
op_func=unified_attention,
|
||||
mutates_args=[],
|
||||
fake_impl=unified_attention_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
tags=tag_cudagraph_unsafe,
|
||||
)
|
||||
|
||||
@ -631,6 +626,5 @@ direct_register_custom_op(
|
||||
op_func=unified_attention_with_output,
|
||||
mutates_args=["output", "output_block_scale"],
|
||||
fake_impl=unified_attention_with_output_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
tags=tag_cudagraph_unsafe,
|
||||
)
|
||||
|
||||
@ -2,10 +2,9 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
||||
@ -142,7 +142,6 @@ def get_attn_backend(
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool = False,
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
@ -156,7 +155,6 @@ def get_attn_backend(
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
block_size=block_size,
|
||||
is_attention_free=is_attention_free,
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
has_sink=has_sink,
|
||||
@ -169,17 +167,10 @@ def _cached_get_attn_backend(
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
# If there are no attention layers (e.g. we are running Mamba),
|
||||
# use the placeholder NO_ATTENTION
|
||||
if is_attention_free:
|
||||
from vllm.attention.backends.placeholder_attn import (
|
||||
PlaceholderAttentionBackend)
|
||||
return PlaceholderAttentionBackend
|
||||
|
||||
# Check whether a particular choice of backend was
|
||||
# previously forced.
|
||||
|
||||
@ -547,7 +547,6 @@ if flashinfer_comm is not None:
|
||||
"scale_out",
|
||||
],
|
||||
fake_impl=call_trtllm_fused_allreduce_norm_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
flashinfer_trtllm_fused_allreduce_norm = (
|
||||
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default)
|
||||
|
||||
@ -551,8 +551,9 @@ def set_inductor_config(config, runtime_shape):
|
||||
if isinstance(runtime_shape, int):
|
||||
# for a specific batchsize, tuning triton kernel parameters
|
||||
# can be beneficial
|
||||
config["max_autotune"] = True
|
||||
config["coordinate_descent_tuning"] = True
|
||||
config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
|
||||
config["coordinate_descent_tuning"] = (
|
||||
envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING)
|
||||
|
||||
|
||||
class EagerAdaptor(CompilerInterface):
|
||||
|
||||
@ -8,6 +8,7 @@ from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
@ -300,13 +301,13 @@ def _support_torch_compile(
|
||||
logger.debug(
|
||||
"enable_cpp_symbolic_shape_guards config not available")
|
||||
|
||||
with patch.object(InliningInstructionTranslator, 'inline_call',
|
||||
patched_inline_call), torch._dynamo.config.patch(
|
||||
**dynamo_config_patches
|
||||
), maybe_use_cudagraph_partition_wrapper(
|
||||
self.vllm_config):
|
||||
with patch.object(
|
||||
InliningInstructionTranslator, "inline_call",
|
||||
patched_inline_call), torch._dynamo.config.patch(
|
||||
**dynamo_config_patches
|
||||
), maybe_use_cudagraph_partition_wrapper(
|
||||
self.vllm_config), _torch27_patch_tensor_subclasses():
|
||||
output = self.compiled_callable(*args, **kwargs)
|
||||
|
||||
return output
|
||||
|
||||
# usually, capturing the model once is enough, and then we can
|
||||
@ -367,3 +368,33 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
|
||||
if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and compilation_config.use_inductor_graph_partition):
|
||||
torch._inductor.utils.set_customized_partition_wrappers(None)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _torch27_patch_tensor_subclasses():
|
||||
"""
|
||||
Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.) when
|
||||
using torch 2.7.0. This enables using weight_loader_v2 and the use of
|
||||
`BasevLLMParameters` without having to replace them with regular tensors
|
||||
before `torch.compile`-time.
|
||||
"""
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ModelWeightParameter,
|
||||
RowvLLMParameter,
|
||||
_ColumnvLLMParameter)
|
||||
|
||||
def return_false(*args, **kwargs):
|
||||
return False
|
||||
|
||||
if version.parse("2.7") <= version.parse(
|
||||
torch.__version__) < version.parse("2.8"):
|
||||
yield
|
||||
return
|
||||
|
||||
with (torch._dynamo.config.patch("traceable_tensor_subclasses", [
|
||||
BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter,
|
||||
RowvLLMParameter
|
||||
]),
|
||||
patch("torch._dynamo.variables.torch.can_dispatch_torch_function",
|
||||
return_false)):
|
||||
yield
|
||||
|
||||
@ -27,6 +27,7 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
|
||||
PrefixCachingHashAlgo)
|
||||
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
||||
CUDAGraphMode, PassConfig)
|
||||
from vllm.config.device import Device, DeviceConfig
|
||||
from vllm.config.kv_events import KVEventsConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
@ -38,11 +39,13 @@ from vllm.config.model import (ConvertOption, HfOverrides, LogprobsMode,
|
||||
try_match_architecture_defaults)
|
||||
from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
|
||||
MultiModalConfig)
|
||||
from vllm.config.observability import DetailedTraceModules, ObservabilityConfig
|
||||
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
||||
ParallelConfig)
|
||||
from vllm.config.pooler import PoolerConfig
|
||||
from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.config.speech_to_text import SpeechToTextConfig
|
||||
from vllm.config.structured_outputs import StructuredOutputsConfig
|
||||
from vllm.config.utils import ConfigType, config, get_attr_docs, is_init_field
|
||||
from vllm.logger import init_logger
|
||||
@ -81,158 +84,6 @@ class SupportsMetricsInfo(Protocol):
|
||||
...
|
||||
|
||||
|
||||
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class DeviceConfig:
|
||||
"""Configuration for the device to use for vLLM execution."""
|
||||
|
||||
device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto"
|
||||
"""Device type for vLLM execution.
|
||||
This parameter is deprecated and will be
|
||||
removed in a future release.
|
||||
It will now be set automatically based
|
||||
on the current platform."""
|
||||
device_type: str = field(init=False)
|
||||
"""Device type from the current platform. This is set in
|
||||
`__post_init__`."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# the device/platform information will be summarized
|
||||
# by torch/vllm automatically.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self):
|
||||
if self.device == "auto":
|
||||
# Automated device type detection
|
||||
from vllm.platforms import current_platform
|
||||
self.device_type = current_platform.device_type
|
||||
if not self.device_type:
|
||||
raise RuntimeError(
|
||||
"Failed to infer device type, please set "
|
||||
"the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
|
||||
"to turn on verbose logging to help debug the issue.")
|
||||
else:
|
||||
# Device type is assigned explicitly
|
||||
if isinstance(self.device, str):
|
||||
self.device_type = self.device
|
||||
elif isinstance(self.device, torch.device):
|
||||
self.device_type = self.device.type
|
||||
|
||||
# Some device types require processing inputs on CPU
|
||||
if self.device_type in ["tpu"]:
|
||||
self.device = None
|
||||
else:
|
||||
# Set device with device type
|
||||
self.device = torch.device(self.device_type)
|
||||
|
||||
|
||||
DetailedTraceModules = Literal["model", "worker", "all"]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class ObservabilityConfig:
|
||||
"""Configuration for observability - metrics and tracing."""
|
||||
|
||||
show_hidden_metrics_for_version: Optional[str] = None
|
||||
"""Enable deprecated Prometheus metrics that have been hidden since the
|
||||
specified version. For example, if a previously deprecated metric has been
|
||||
hidden since the v0.7.0 release, you use
|
||||
`--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
|
||||
you migrate to new metrics. The metric is likely to be removed completely
|
||||
in an upcoming release."""
|
||||
|
||||
@cached_property
|
||||
def show_hidden_metrics(self) -> bool:
|
||||
"""Check if the hidden metrics should be shown."""
|
||||
if self.show_hidden_metrics_for_version is None:
|
||||
return False
|
||||
return version._prev_minor_version_was(
|
||||
self.show_hidden_metrics_for_version)
|
||||
|
||||
otlp_traces_endpoint: Optional[str] = None
|
||||
"""Target URL to which OpenTelemetry traces will be sent."""
|
||||
|
||||
collect_detailed_traces: Optional[list[DetailedTraceModules]] = None
|
||||
"""It makes sense to set this only if `--otlp-traces-endpoint` is set. If
|
||||
set, it will collect detailed traces for the specified modules. This
|
||||
involves use of possibly costly and/or blocking operations and hence might
|
||||
have a performance impact.
|
||||
|
||||
Note that collecting detailed timing information for each request can be
|
||||
expensive."""
|
||||
|
||||
@cached_property
|
||||
def collect_model_forward_time(self) -> bool:
|
||||
"""Whether to collect model forward time for the request."""
|
||||
return (self.collect_detailed_traces is not None
|
||||
and ("model" in self.collect_detailed_traces
|
||||
or "all" in self.collect_detailed_traces))
|
||||
|
||||
@cached_property
|
||||
def collect_model_execute_time(self) -> bool:
|
||||
"""Whether to collect model execute time for the request."""
|
||||
return (self.collect_detailed_traces is not None
|
||||
and ("worker" in self.collect_detailed_traces
|
||||
or "all" in self.collect_detailed_traces))
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self):
|
||||
if (self.collect_detailed_traces is not None
|
||||
and len(self.collect_detailed_traces) == 1
|
||||
and "," in self.collect_detailed_traces[0]):
|
||||
self._parse_collect_detailed_traces()
|
||||
|
||||
from vllm.tracing import is_otel_available, otel_import_error_traceback
|
||||
if not is_otel_available() and self.otlp_traces_endpoint is not None:
|
||||
raise ValueError(
|
||||
"OpenTelemetry is not available. Unable to configure "
|
||||
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
|
||||
f"installed. Original error:\n{otel_import_error_traceback}")
|
||||
|
||||
def _parse_collect_detailed_traces(self):
|
||||
assert isinstance(self.collect_detailed_traces, list)
|
||||
self.collect_detailed_traces = cast(
|
||||
list[DetailedTraceModules],
|
||||
self.collect_detailed_traces[0].split(","))
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class VllmConfig:
|
||||
@ -1009,37 +860,6 @@ def get_layers_from_vllm_config(
|
||||
}
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class SpeechToTextConfig:
|
||||
"""Configuration for speech-to-text models."""
|
||||
|
||||
sample_rate: float = 16_000
|
||||
"""Sample rate (Hz) to resample input audio to. Most speech models expect
|
||||
16kHz audio input. The input audio will be automatically resampled to this
|
||||
rate before processing."""
|
||||
|
||||
max_audio_clip_s: int = 30
|
||||
"""Maximum duration in seconds for a single audio clip without chunking.
|
||||
Audio longer than this will be split into smaller chunks if
|
||||
`allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
|
||||
|
||||
overlap_chunk_second: int = 1
|
||||
"""Overlap duration in seconds between consecutive audio chunks when
|
||||
splitting long audio. This helps maintain context across chunk boundaries
|
||||
and improves transcription quality at split points."""
|
||||
|
||||
min_energy_split_window_size: Optional[int] = 1600
|
||||
"""Window size in samples for finding low-energy (quiet) regions to split
|
||||
audio chunks. The algorithm looks for the quietest moment within this
|
||||
window to minimize cutting through speech. Default 1600 samples ≈ 100ms
|
||||
at 16kHz. If None, no chunking will be done."""
|
||||
|
||||
@property
|
||||
def allow_audio_chunking(self) -> bool:
|
||||
return self.min_energy_split_window_size is not None
|
||||
|
||||
|
||||
def update_config(config: DataclassInstanceT,
|
||||
overrides: dict[str, Any]) -> DataclassInstanceT:
|
||||
processed_overrides = {}
|
||||
|
||||
74
vllm/config/device.py
Normal file
74
vllm/config/device.py
Normal file
@ -0,0 +1,74 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import ConfigDict, SkipValidation
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class DeviceConfig:
|
||||
"""Configuration for the device to use for vLLM execution."""
|
||||
|
||||
device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto"
|
||||
"""Device type for vLLM execution.
|
||||
This parameter is deprecated and will be
|
||||
removed in a future release.
|
||||
It will now be set automatically based
|
||||
on the current platform."""
|
||||
device_type: str = field(init=False)
|
||||
"""Device type from the current platform. This is set in
|
||||
`__post_init__`."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# the device/platform information will be summarized
|
||||
# by torch/vllm automatically.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self):
|
||||
if self.device == "auto":
|
||||
# Automated device type detection
|
||||
from vllm.platforms import current_platform
|
||||
self.device_type = current_platform.device_type
|
||||
if not self.device_type:
|
||||
raise RuntimeError(
|
||||
"Failed to infer device type, please set "
|
||||
"the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
|
||||
"to turn on verbose logging to help debug the issue.")
|
||||
else:
|
||||
# Device type is assigned explicitly
|
||||
if isinstance(self.device, str):
|
||||
self.device_type = self.device
|
||||
elif isinstance(self.device, torch.device):
|
||||
self.device_type = self.device.type
|
||||
|
||||
# Some device types require processing inputs on CPU
|
||||
if self.device_type in ["tpu"]:
|
||||
self.device = None
|
||||
else:
|
||||
# Set device with device type
|
||||
self.device = torch.device(self.device_type)
|
||||
@ -177,11 +177,6 @@ class ModelConfig:
|
||||
graph and always execute the model in eager mode. If False, we will use
|
||||
CUDA graph and eager execution in hybrid for maximal performance and
|
||||
flexibility."""
|
||||
max_seq_len_to_capture: int = 8192
|
||||
"""Maximum sequence len covered by CUDA graphs. When a sequence has context
|
||||
length larger than this, we fall back to eager mode. Additionally for
|
||||
encoder-decoder models, if the sequence length of the encoder input is
|
||||
larger than this, we fall back to the eager mode."""
|
||||
max_logprobs: int = 20
|
||||
"""Maximum number of log probabilities to return when `logprobs` is
|
||||
specified in `SamplingParams`. The default value comes from the default for the
|
||||
@ -699,11 +694,12 @@ class ModelConfig:
|
||||
model: Model name or path
|
||||
tokenizer: Tokenizer name or path
|
||||
"""
|
||||
|
||||
if not (is_runai_obj_uri(model) or is_runai_obj_uri(tokenizer)):
|
||||
return
|
||||
|
||||
if is_runai_obj_uri(model):
|
||||
object_storage_model = ObjectStorageModel()
|
||||
object_storage_model = ObjectStorageModel(url=model)
|
||||
object_storage_model.pull_files(
|
||||
model, allow_pattern=["*.model", "*.py", "*.json"])
|
||||
self.model_weights = model
|
||||
@ -722,7 +718,7 @@ class ModelConfig:
|
||||
|
||||
# Only download tokenizer if needed and not already handled
|
||||
if is_runai_obj_uri(tokenizer):
|
||||
object_storage_tokenizer = ObjectStorageModel()
|
||||
object_storage_tokenizer = ObjectStorageModel(url=tokenizer)
|
||||
object_storage_tokenizer.pull_files(model,
|
||||
ignore_pattern=[
|
||||
"*.pt", "*.safetensors",
|
||||
@ -1023,21 +1019,8 @@ class ModelConfig:
|
||||
current_platform.verify_quantization(self.quantization)
|
||||
|
||||
def _verify_cuda_graph(self) -> None:
|
||||
# The `max_seq_len_to_capture` was incorrectly
|
||||
# based on the encoder's input length (448)
|
||||
# but not the decoder's larger input length (1500).
|
||||
# This change ensures the CUDA Graph captures the correct,
|
||||
# larger sequence length, allowing it to work as intended.
|
||||
effective_max_seq_len = self.max_model_len
|
||||
if self.is_encoder_decoder:
|
||||
effective_max_seq_len = max(
|
||||
effective_max_seq_len,
|
||||
getattr(self.hf_config, "max_source_positions", 0))
|
||||
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
|
||||
effective_max_seq_len)
|
||||
# CUDAGraph capture not supported for encoder-decoder models on ROCm
|
||||
unsupported_rocm = self.is_encoder_decoder
|
||||
|
||||
if (unsupported_rocm and not self.enforce_eager
|
||||
and current_platform.is_rocm()):
|
||||
logger.warning(
|
||||
|
||||
99
vllm/config/observability.py
Normal file
99
vllm/config/observability.py
Normal file
@ -0,0 +1,99 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from functools import cached_property
|
||||
from typing import Any, Literal, Optional, cast
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm import version
|
||||
from vllm.config.utils import config
|
||||
|
||||
DetailedTraceModules = Literal["model", "worker", "all"]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class ObservabilityConfig:
|
||||
"""Configuration for observability - metrics and tracing."""
|
||||
|
||||
show_hidden_metrics_for_version: Optional[str] = None
|
||||
"""Enable deprecated Prometheus metrics that have been hidden since the
|
||||
specified version. For example, if a previously deprecated metric has been
|
||||
hidden since the v0.7.0 release, you use
|
||||
`--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
|
||||
you migrate to new metrics. The metric is likely to be removed completely
|
||||
in an upcoming release."""
|
||||
|
||||
@cached_property
|
||||
def show_hidden_metrics(self) -> bool:
|
||||
"""Check if the hidden metrics should be shown."""
|
||||
if self.show_hidden_metrics_for_version is None:
|
||||
return False
|
||||
return version._prev_minor_version_was(
|
||||
self.show_hidden_metrics_for_version)
|
||||
|
||||
otlp_traces_endpoint: Optional[str] = None
|
||||
"""Target URL to which OpenTelemetry traces will be sent."""
|
||||
|
||||
collect_detailed_traces: Optional[list[DetailedTraceModules]] = None
|
||||
"""It makes sense to set this only if `--otlp-traces-endpoint` is set. If
|
||||
set, it will collect detailed traces for the specified modules. This
|
||||
involves use of possibly costly and/or blocking operations and hence might
|
||||
have a performance impact.
|
||||
|
||||
Note that collecting detailed timing information for each request can be
|
||||
expensive."""
|
||||
|
||||
@cached_property
|
||||
def collect_model_forward_time(self) -> bool:
|
||||
"""Whether to collect model forward time for the request."""
|
||||
return (self.collect_detailed_traces is not None
|
||||
and ("model" in self.collect_detailed_traces
|
||||
or "all" in self.collect_detailed_traces))
|
||||
|
||||
@cached_property
|
||||
def collect_model_execute_time(self) -> bool:
|
||||
"""Whether to collect model execute time for the request."""
|
||||
return (self.collect_detailed_traces is not None
|
||||
and ("worker" in self.collect_detailed_traces
|
||||
or "all" in self.collect_detailed_traces))
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self):
|
||||
if (self.collect_detailed_traces is not None
|
||||
and len(self.collect_detailed_traces) == 1
|
||||
and "," in self.collect_detailed_traces[0]):
|
||||
self._parse_collect_detailed_traces()
|
||||
|
||||
from vllm.tracing import is_otel_available, otel_import_error_traceback
|
||||
if not is_otel_available() and self.otlp_traces_endpoint is not None:
|
||||
raise ValueError(
|
||||
"OpenTelemetry is not available. Unable to configure "
|
||||
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
|
||||
f"installed. Original error:\n{otel_import_error_traceback}")
|
||||
|
||||
def _parse_collect_detailed_traces(self):
|
||||
assert isinstance(self.collect_detailed_traces, list)
|
||||
self.collect_detailed_traces = cast(
|
||||
list[DetailedTraceModules],
|
||||
self.collect_detailed_traces[0].split(","))
|
||||
@ -285,8 +285,6 @@ class SpeculativeConfig:
|
||||
max_model_len,
|
||||
quantization=self.quantization,
|
||||
enforce_eager=self.target_model_config.enforce_eager,
|
||||
max_seq_len_to_capture=self.target_model_config.
|
||||
max_seq_len_to_capture,
|
||||
max_logprobs=self.target_model_config.max_logprobs,
|
||||
hf_overrides=SpeculativeConfig.hf_config_override,
|
||||
)
|
||||
|
||||
39
vllm/config/speech_to_text.py
Normal file
39
vllm/config/speech_to_text.py
Normal file
@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class SpeechToTextConfig:
|
||||
"""Configuration for speech-to-text models."""
|
||||
|
||||
sample_rate: float = 16_000
|
||||
"""Sample rate (Hz) to resample input audio to. Most speech models expect
|
||||
16kHz audio input. The input audio will be automatically resampled to this
|
||||
rate before processing."""
|
||||
|
||||
max_audio_clip_s: int = 30
|
||||
"""Maximum duration in seconds for a single audio clip without chunking.
|
||||
Audio longer than this will be split into smaller chunks if
|
||||
`allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
|
||||
|
||||
overlap_chunk_second: int = 1
|
||||
"""Overlap duration in seconds between consecutive audio chunks when
|
||||
splitting long audio. This helps maintain context across chunk boundaries
|
||||
and improves transcription quality at split points."""
|
||||
|
||||
min_energy_split_window_size: Optional[int] = 1600
|
||||
"""Window size in samples for finding low-energy (quiet) regions to split
|
||||
audio chunks. The algorithm looks for the quietest moment within this
|
||||
window to minimize cutting through speech. Default 1600 samples ≈ 100ms
|
||||
at 16kHz. If None, no chunking will be done."""
|
||||
|
||||
@property
|
||||
def allow_audio_chunking(self) -> bool:
|
||||
return self.min_energy_split_window_size is not None
|
||||
@ -46,7 +46,6 @@ def register_nccl_symmetric_ops(pynccl_comm):
|
||||
direct_register_custom_op(
|
||||
op_name="all_reduce_symmetric_with_copy",
|
||||
op_func=all_reduce_symmetric_with_copy_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=all_reduce_symmetric_with_copy_fake,
|
||||
)
|
||||
|
||||
|
||||
@ -392,7 +392,8 @@ class MessageQueue:
|
||||
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
|
||||
logger.debug(
|
||||
("No available shared memory broadcast block found"
|
||||
" in %s second."),
|
||||
" in %s seconds. This typically happens when some"
|
||||
" processes are hanging."),
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL,
|
||||
)
|
||||
n_warning += 1
|
||||
@ -455,7 +456,8 @@ class MessageQueue:
|
||||
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
|
||||
logger.debug(
|
||||
("No available shared memory broadcast block found"
|
||||
" in %s second."),
|
||||
" in %s seconds. This typically happens when some"
|
||||
" processes are hanging."),
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL,
|
||||
)
|
||||
n_warning += 1
|
||||
|
||||
@ -574,7 +574,6 @@ class NixlConnectorWorker:
|
||||
self.model_config.dtype,
|
||||
self.cache_config.cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
use_mla=self.use_mla)
|
||||
self.backend_name = backend.get_name()
|
||||
attn_backend = backend_name_to_enum(self.backend_name)
|
||||
|
||||
@ -149,29 +149,22 @@ def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
|
||||
|
||||
|
||||
if supports_custom_op():
|
||||
from vllm.platforms import current_platform
|
||||
direct_register_custom_op(
|
||||
op_name="all_reduce",
|
||||
op_func=all_reduce,
|
||||
mutates_args=[],
|
||||
fake_impl=all_reduce_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="reduce_scatter",
|
||||
op_func=reduce_scatter,
|
||||
mutates_args=[],
|
||||
fake_impl=reduce_scatter_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="all_gather",
|
||||
op_func=all_gather,
|
||||
mutates_args=[],
|
||||
fake_impl=all_gather_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -373,7 +373,6 @@ class EngineArgs:
|
||||
tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
|
||||
quantization: Optional[QuantizationMethods] = ModelConfig.quantization
|
||||
enforce_eager: bool = ModelConfig.enforce_eager
|
||||
max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
|
||||
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
|
||||
limit_mm_per_prompt: dict[str, int] = \
|
||||
get_field(MultiModalConfig, "limit_per_prompt")
|
||||
@ -545,8 +544,6 @@ class EngineArgs:
|
||||
**model_kwargs["quantization"])
|
||||
model_group.add_argument("--enforce-eager",
|
||||
**model_kwargs["enforce_eager"])
|
||||
model_group.add_argument("--max-seq-len-to-capture",
|
||||
**model_kwargs["max_seq_len_to_capture"])
|
||||
model_group.add_argument("--max-logprobs",
|
||||
**model_kwargs["max_logprobs"])
|
||||
model_group.add_argument("--logprobs-mode",
|
||||
@ -1008,7 +1005,6 @@ class EngineArgs:
|
||||
max_model_len=self.max_model_len,
|
||||
quantization=self.quantization,
|
||||
enforce_eager=self.enforce_eager,
|
||||
max_seq_len_to_capture=self.max_seq_len_to_capture,
|
||||
max_logprobs=self.max_logprobs,
|
||||
logprobs_mode=self.logprobs_mode,
|
||||
disable_sliding_window=self.disable_sliding_window,
|
||||
|
||||
@ -130,11 +130,6 @@ class LLM:
|
||||
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||
disable CUDA graph and always execute the model in eager mode.
|
||||
If False, we will use CUDA graph and eager execution in hybrid.
|
||||
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
|
||||
When a sequence has context length larger than this, we fall back
|
||||
to eager mode. Additionally for encoder-decoder models, if the
|
||||
sequence length of the encoder input is larger than this, we fall
|
||||
back to the eager mode.
|
||||
disable_custom_all_reduce: See
|
||||
[ParallelConfig][vllm.config.ParallelConfig].
|
||||
hf_token: The token to use as HTTP bearer authorization for remote files
|
||||
@ -184,7 +179,6 @@ class LLM:
|
||||
swap_space: float = 4,
|
||||
cpu_offload_gb: float = 0,
|
||||
enforce_eager: bool = False,
|
||||
max_seq_len_to_capture: int = 8192,
|
||||
disable_custom_all_reduce: bool = False,
|
||||
hf_token: Optional[Union[bool, str]] = None,
|
||||
hf_overrides: Optional[HfOverrides] = None,
|
||||
@ -281,7 +275,6 @@ class LLM:
|
||||
swap_space=swap_space,
|
||||
cpu_offload_gb=cpu_offload_gb,
|
||||
enforce_eager=enforce_eager,
|
||||
max_seq_len_to_capture=max_seq_len_to_capture,
|
||||
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||
hf_token=hf_token,
|
||||
hf_overrides=hf_overrides,
|
||||
|
||||
@ -1186,6 +1186,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
logprobs = None
|
||||
|
||||
if self.use_harmony:
|
||||
reasoning_content, content, _ = parse_chat_output(token_ids)
|
||||
if not request.include_reasoning:
|
||||
reasoning_content = None
|
||||
|
||||
if self.tool_parser is not None:
|
||||
tool_parser = self.tool_parser(tokenizer)
|
||||
# NOTE: We use token_ids for openai tool parser
|
||||
@ -1194,10 +1198,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
request=request,
|
||||
token_ids=token_ids, # type: ignore
|
||||
)
|
||||
reasoning_content, content = None, tool_call_info.content
|
||||
if request.include_reasoning:
|
||||
reasoning_content, content, _ = parse_chat_output(
|
||||
token_ids)
|
||||
content = tool_call_info.content
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
reasoning_content=reasoning_content,
|
||||
@ -1205,10 +1206,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
tool_calls=tool_call_info.tool_calls,
|
||||
)
|
||||
else:
|
||||
reasoning_content, content, _ = parse_chat_output(
|
||||
token_ids)
|
||||
if not request.include_reasoning:
|
||||
reasoning_content = None
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
reasoning_content=reasoning_content,
|
||||
|
||||
@ -235,8 +235,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
# Handle the previous response ID.
|
||||
prev_response_id = request.previous_response_id
|
||||
if prev_response_id is not None:
|
||||
if not prev_response_id.startswith("resp_"):
|
||||
return self._make_invalid_id_error(prev_response_id)
|
||||
async with self.response_store_lock:
|
||||
prev_response = self.response_store.get(prev_response_id)
|
||||
if prev_response is None:
|
||||
@ -924,9 +922,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
stream: Optional[bool],
|
||||
) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[
|
||||
StreamingResponsesResponse, None]]:
|
||||
if not response_id.startswith("resp_"):
|
||||
return self._make_invalid_id_error(response_id)
|
||||
|
||||
async with self.response_store_lock:
|
||||
response = self.response_store.get(response_id)
|
||||
|
||||
@ -944,9 +939,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
self,
|
||||
response_id: str,
|
||||
) -> Union[ErrorResponse, ResponsesResponse]:
|
||||
if not response_id.startswith("resp_"):
|
||||
return self._make_invalid_id_error(response_id)
|
||||
|
||||
async with self.response_store_lock:
|
||||
response = self.response_store.get(response_id)
|
||||
if response is None:
|
||||
@ -972,13 +964,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
response_id)
|
||||
return response
|
||||
|
||||
def _make_invalid_id_error(self, response_id: str) -> ErrorResponse:
|
||||
return self.create_error_response(
|
||||
err_type="invalid_request_error",
|
||||
message=(f"Invalid 'response_id': '{response_id}'. "
|
||||
"Expected an ID that begins with 'resp'."),
|
||||
)
|
||||
|
||||
def _make_not_found_error(self, response_id: str) -> ErrorResponse:
|
||||
return self.create_error_response(
|
||||
err_type="invalid_request_error",
|
||||
|
||||
@ -39,7 +39,7 @@ class DeepSeekV31ToolParser(ToolParser):
|
||||
self.tool_call_end_token: str = "<|tool▁call▁end|>"
|
||||
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<|tool▁call▁begin|>(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)<|tool▁call▁end|>"
|
||||
r"<|tool▁call▁begin|>(?P<function_name>.*?)<|tool▁sep|>(?P<function_arguments>.*?)<|tool▁call▁end|>"
|
||||
)
|
||||
|
||||
self.stream_tool_call_portion_regex = re.compile(
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@ -12,10 +13,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("openai")
|
||||
class OpenAIToolParser(ToolParser):
|
||||
@ -40,17 +44,33 @@ class OpenAIToolParser(ToolParser):
|
||||
|
||||
if len(parser.messages) > 0:
|
||||
for msg in parser.messages:
|
||||
if len(msg.content) < 1:
|
||||
continue
|
||||
msg_text = msg.content[0].text
|
||||
if msg.recipient and msg.recipient.startswith("functions."):
|
||||
# If no content-type is given assume JSON, as that's the
|
||||
# most common case with gpt-oss models.
|
||||
if not msg.content_type or "json" in msg.content_type:
|
||||
# load and dump the JSON text to check validity and
|
||||
# remove any extra newlines or other odd formatting
|
||||
try:
|
||||
tool_args = json.dumps(json.loads(msg_text))
|
||||
except json.JSONDecodeError:
|
||||
logger.exception(
|
||||
"Error decoding JSON tool call from response.")
|
||||
tool_args = msg_text
|
||||
else:
|
||||
tool_args = msg_text
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=msg.recipient.split("functions.")[1],
|
||||
arguments=msg.content[0].text,
|
||||
arguments=tool_args,
|
||||
),
|
||||
))
|
||||
elif msg.channel == "final":
|
||||
final_content = msg.content[0].text
|
||||
final_content = msg_text
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=len(tool_calls) > 0,
|
||||
|
||||
24
vllm/envs.py
24
vllm/envs.py
@ -119,7 +119,7 @@ if TYPE_CHECKING:
|
||||
VLLM_SERVER_DEV_MODE: bool = False
|
||||
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
|
||||
VLLM_MLA_DISABLE: bool = False
|
||||
VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16
|
||||
VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 32
|
||||
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
|
||||
VLLM_RAY_BUNDLE_INDICES: str = ""
|
||||
VLLM_CUDART_SO_PATH: Optional[str] = None
|
||||
@ -187,12 +187,15 @@ if TYPE_CHECKING:
|
||||
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
|
||||
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
||||
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
||||
VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False
|
||||
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
|
||||
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
|
||||
VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024
|
||||
VLLM_DBO_COMM_SMS: int = 20
|
||||
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
|
||||
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
|
||||
VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE: bool = True
|
||||
VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True
|
||||
VLLM_USE_NCCL_SYMM_MEM: bool = False
|
||||
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
|
||||
|
||||
@ -1014,7 +1017,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# max number splits for cuda graph decode
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
|
||||
lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
|
||||
"16")),
|
||||
"32")),
|
||||
|
||||
# Number of GPUs per worker in Ray, if it is set to be a fraction,
|
||||
# it allows ray to schedule multiple actors on a single GPU,
|
||||
@ -1385,6 +1388,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_CUSTOM_SCOPES_FOR_PROFILING":
|
||||
lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
|
||||
|
||||
# Add optional nvtx scopes for profiling, disable to avoid overheads
|
||||
"VLLM_NVTX_SCOPES_FOR_PROFILING":
|
||||
lambda: bool(int(os.getenv("VLLM_NVTX_SCOPES_FOR_PROFILING", "0"))),
|
||||
|
||||
# Represent block hashes in KV cache events as 64-bit integers instead of
|
||||
# raw bytes. Defaults to True for backward compatibility.
|
||||
"VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES":
|
||||
@ -1413,6 +1420,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"code_interpreter",
|
||||
"web_search_preview"]),
|
||||
|
||||
# Enable max_autotune & coordinate_descent_tuning in inductor_config
|
||||
# to compile static shapes passed from compile_sizes in compilation_config
|
||||
# If set to 1, enable max_autotune; By default, this is enabled (1)
|
||||
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE":
|
||||
lambda: bool(int(os.getenv("VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", "1"))),
|
||||
# If set to 1, enable coordinate_descent_tuning;
|
||||
# By default, this is enabled (1)
|
||||
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING":
|
||||
lambda: bool(int(os.getenv("VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
|
||||
"1"))),
|
||||
|
||||
# Flag to enable NCCL symmetric memory allocation and registration
|
||||
"VLLM_USE_NCCL_SYMM_MEM":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))),
|
||||
@ -1513,6 +1531,8 @@ def compute_hash() -> str:
|
||||
"VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
|
||||
"VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
|
||||
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
|
||||
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
|
||||
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
|
||||
]
|
||||
for key in environment_variables_to_hash:
|
||||
# if this goes out of sync with environment_variables,
|
||||
|
||||
@ -7,7 +7,6 @@ from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
|
||||
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
|
||||
build_explicit_enc_dec_prompt, embeds_inputs,
|
||||
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
|
||||
from .registry import InputContext, InputProcessingContext
|
||||
|
||||
__all__ = [
|
||||
"DataPrompt",
|
||||
@ -28,6 +27,4 @@ __all__ = [
|
||||
"build_explicit_enc_dec_prompt",
|
||||
"to_enc_dec_tuple_list",
|
||||
"zip_enc_dec_prompts",
|
||||
"InputContext",
|
||||
"InputProcessingContext",
|
||||
]
|
||||
|
||||
@ -1,186 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import torch
|
||||
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||
from vllm.utils import get_allowed_kwarg_only_overrides
|
||||
from vllm.utils.jsontree import JSONTree, json_map_leaves
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
else:
|
||||
ModelConfig = Any
|
||||
AnyTokenizer = Any
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
|
||||
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class InputContext:
    """
    Immutable holder for the model configuration, exposing helpers that
    input-handling code can use to look up model-specific settings.
    """

    model_config: ModelConfig
    """The configuration of the model."""

    def get_hf_config(
        self,
        typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig,
        /,
    ) -> _C:
        """
        Return the model's HuggingFace configuration
        (`transformers.PretrainedConfig`) after checking that it is an
        instance of `typ`.

        Raises:
            TypeError: If the configuration is not of the specified type.
        """
        config = self.model_config.hf_config
        if isinstance(config, typ):
            return config

        raise TypeError("Invalid type of HuggingFace config. "
                        f"Expected type: {typ}, but "
                        f"found type: {type(config)}")

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Return the HuggingFace image processor configuration stored on the
        model config.
        """
        return self.model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Return the multimodal config of the model.

        Raises:
            RuntimeError: If the model config carries no multimodal config,
                i.e. the model is not a multimodal model.
        """
        config = self.model_config.multimodal_config
        if config is not None:
            return config

        raise RuntimeError("Not a multimodal model")

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Return the HuggingFace processor (`transformers.ProcessorMixin`)
        of the model, requesting it as type `typ`.

        The lookup is delegated to `cached_processor_from_config`; any
        extra keyword arguments are forwarded to it unchanged.

        Raises:
            TypeError: If the processor is not of the specified type.
        """
        model_config = self.model_config
        return cached_processor_from_config(
            model_config,
            processor_cls=typ,
            **kwargs,
        )

    def init_processor(
        self,
        typ: type[_T],
        /,
        **kwargs: object,
    ) -> _T:
        """
        Instantiate a HuggingFace-like processor class.

        The keyword arguments configured on the model
        (`mm_processor_kwargs`) act as defaults; explicitly passed
        `kwargs` take precedence on key collisions.
        """
        configured = (
            self.model_config.get_multimodal_config().mm_processor_kwargs)
        defaults = {} if configured is None else configured

        # Later entries win, so call-site kwargs override configured ones.
        return typ(**{**defaults, **kwargs})
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class InputProcessingContext(InputContext):
    """
    An `InputContext` that additionally carries the tokenizer, so that
    HuggingFace processors can be obtained and invoked on prompt data.
    """

    tokenizer: AnyTokenizer
    """The tokenizer used to tokenize the inputs."""

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Same as the parent implementation, but always passes this
        context's tokenizer along as the `tokenizer` keyword argument.
        """
        tokenizer = self.tokenizer
        return super().get_hf_processor(typ, tokenizer=tokenizer, **kwargs)

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Mapping[str, object] = {},
    ) -> Union[BatchFeature, JSONTree]:
        """
        Call `hf_processor` on the prompt `data`
        (text, image, audio...) with configurable options `kwargs`.

        `kwargs` is first merged through the multimodal config and then
        filtered by `get_allowed_kwarg_only_overrides` before the call.
        Floating-point tensors in the result are cast to the model dtype.

        Returns:
            A `BatchFeature` when the processor returned one; otherwise the
            processor's own output structure with its leaves cast (a
            warning is logged once in that case).

        Raises:
            ValueError: If invoking the processor or post-processing its
                output fails; the original exception is chained.
        """
        assert callable(hf_processor)

        merged_kwargs = (self.model_config.get_multimodal_config().
                         merge_mm_processor_kwargs(kwargs))
        allowed_kwargs = get_allowed_kwarg_only_overrides(
            hf_processor,
            merged_kwargs,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        def cast_floating_leaf(leaf):
            # This mimics the behavior of transformers.BatchFeature:
            # only floating-point tensors are converted, the rest pass
            # through untouched.
            is_float_tensor = (isinstance(leaf, torch.Tensor)
                               and leaf.is_floating_point())
            if not is_float_tensor:
                return leaf
            return leaf.to(dtype=self.model_config.dtype)

        try:
            output = hf_processor(**data,
                                  **allowed_kwargs,
                                  return_tensors="pt")

            # The leaf-wise mapping below emulates
            # output.to(dtype=self.model_config.dtype)
            if not isinstance(output, BatchFeature):
                cast_output = json_map_leaves(cast_floating_leaf, output)

                logger.warning_once(
                    f"{type(hf_processor).__name__} did not return `BatchFeature`. "
                    "Make sure to match the behaviour of `ProcessorMixin` when "
                    "implementing custom processors.")
                return cast_output

            return BatchFeature(
                json_map_leaves(cast_floating_leaf, output.data))

        except Exception as exc:
            raise ValueError(
                f"Failed to apply {type(hf_processor).__name__} "
                f"on data={data} with kwargs={allowed_kwargs}") from exc
|
||||
@ -11,7 +11,6 @@ import torch
|
||||
|
||||
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
|
||||
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
@ -283,7 +282,6 @@ try:
|
||||
op_func=_lora_expand,
|
||||
mutates_args=["output_tensor"],
|
||||
fake_impl=_lora_expand_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
lora_expand = torch.ops.vllm.lora_expand
|
||||
|
||||
|
||||
@ -11,7 +11,6 @@ import torch
|
||||
|
||||
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
|
||||
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
@ -237,7 +236,6 @@ try:
|
||||
op_func=_lora_shrink,
|
||||
mutates_args=["output_tensor"],
|
||||
fake_impl=_lora_shrink_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
lora_shrink = torch.ops.vllm.lora_shrink
|
||||
|
||||
|
||||
@ -40,8 +40,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
ssm_state_indices,
|
||||
num_accepted_tokens,
|
||||
scale,
|
||||
N: tl.constexpr, # num of sequences
|
||||
T: tl.constexpr, # num of tokens
|
||||
N: tl.int64, # num of sequences
|
||||
T: tl.int64, # num of tokens
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
HV: tl.constexpr,
|
||||
|
||||
@ -0,0 +1,123 @@
|
||||
{
|
||||
"triton_version": "3.4.0",
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8192": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16384": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@ -98,13 +98,16 @@ def select_experts(
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
elif custom_routing_function is None:
|
||||
assert scoring_func == "softmax"
|
||||
topk_weights = torch.nn.functional.softmax(router_logits,
|
||||
dim=1,
|
||||
dtype=torch.float32)
|
||||
topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
|
||||
topk_logit_vals, topk_idx = torch.topk(router_logits,
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)
|
||||
if renormalize:
|
||||
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
|
||||
return topk_weights, topk_ids.to(torch.int32)
|
||||
topk_vals = torch.softmax(topk_logit_vals, dim=-1)
|
||||
else:
|
||||
logZ = torch.logsumexp(router_logits, dim=-1, keepdim=True)
|
||||
topk_vals = (topk_logit_vals - logZ).exp()
|
||||
return topk_vals.to(torch.float32), topk_idx.to(torch.int32)
|
||||
else:
|
||||
return custom_routing_function(hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
|
||||
@ -92,7 +92,6 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
|
||||
direct_register_custom_op(
|
||||
op_name="flashinfer_fused_moe_blockscale_fp8",
|
||||
op_func=flashinfer_fused_moe_blockscale_fp8,
|
||||
mutates_args=[],
|
||||
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
@ -235,6 +235,5 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
|
||||
direct_register_custom_op(
|
||||
op_name="fused_marlin_moe",
|
||||
op_func=fused_marlin_moe,
|
||||
mutates_args=[],
|
||||
fake_impl=fused_marlin_moe_fake,
|
||||
)
|
||||
|
||||
@ -1256,7 +1256,6 @@ def outplace_fused_experts_fake(
|
||||
direct_register_custom_op(
|
||||
op_name="outplace_fused_experts",
|
||||
op_func=outplace_fused_experts,
|
||||
mutates_args=[],
|
||||
fake_impl=outplace_fused_experts_fake,
|
||||
tags=(() if is_torch_equal_or_newer("2.7.0") else
|
||||
(torch.Tag.needs_fixed_stride_order, )),
|
||||
|
||||
@ -23,7 +23,7 @@ if has_triton_kernels():
|
||||
from triton_kernels.routing import (RoutingData, routing,
|
||||
routing_from_bitmatrix)
|
||||
from triton_kernels.tensor import Bitmatrix
|
||||
except (ModuleNotFoundError, AttributeError) as e:
|
||||
except (AttributeError, ImportError) as e:
|
||||
logger.error(
|
||||
"Failed to import Triton kernels. Please make sure your triton "
|
||||
"version is compatible. Error: %s", e)
|
||||
|
||||
@ -69,8 +69,6 @@ else:
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
rocm_aiter_grouped_topk as grouped_topk)
|
||||
elif current_platform.is_cpu():
|
||||
pass
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
|
||||
if current_platform.is_tpu():
|
||||
@ -2040,7 +2038,6 @@ direct_register_custom_op(
|
||||
op_func=moe_forward,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=moe_forward_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
@ -2071,7 +2068,6 @@ direct_register_custom_op(
|
||||
op_func=moe_forward_shared,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=moe_forward_shared_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
@ -223,17 +223,13 @@ if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_asm_moe_tkw1",
|
||||
op_func=rocm_aiter_asm_moe_tkw1_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=rocm_aiter_asm_moe_tkw1_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_fused_moe",
|
||||
op_func=rocm_aiter_fused_moe_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=rocm_aiter_fused_moe_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
@ -241,7 +237,6 @@ if current_platform.is_rocm():
|
||||
op_func=rocm_aiter_topk_softmax_impl,
|
||||
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
|
||||
fake_impl=rocm_aiter_topk_softmax_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
@ -249,7 +244,6 @@ if current_platform.is_rocm():
|
||||
op_func=rocm_aiter_biased_grouped_topk_impl,
|
||||
mutates_args=["topk_weights", "topk_ids"],
|
||||
fake_impl=rocm_aiter_biased_grouped_topk_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
@ -257,7 +251,6 @@ if current_platform.is_rocm():
|
||||
op_func=rocm_aiter_grouped_topk_impl,
|
||||
mutates_args=["topk_weights", "topk_ids"],
|
||||
fake_impl=rocm_aiter_grouped_topk_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -103,17 +103,13 @@ if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rms_norm",
|
||||
op_func=rocm_aiter_rms_norm_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=rocm_aiter_rms_norm_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
|
||||
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
# yapf: disable
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
BlockQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
PerTensorScaleParameter,
|
||||
@ -34,6 +35,7 @@ from vllm.utils import GiB_bytes
|
||||
logger = init_logger(__name__)
|
||||
|
||||
WEIGHT_LOADER_V2_SUPPORTED = [
|
||||
"UnquantizedLinearMethod",
|
||||
"CompressedTensorsLinearMethod",
|
||||
"CompressedTensorsLinearTransformMethod",
|
||||
"BitBLASLinearMethod",
|
||||
@ -196,10 +198,14 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
# The amount of memory allocated for the weights is
|
||||
# sum(output_partition_sizes) * input_size_per_partition.
|
||||
try:
|
||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
weight_loader = extra_weight_attrs.pop("weight_loader")
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
logger.error("Failed to create unquantized linear weights: %s", e)
|
||||
if torch.cuda.is_available():
|
||||
@ -212,7 +218,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
"Failed to create unquantized linear weights. "
|
||||
"This may be caused by insufficient memory to allocate "
|
||||
"the weight.") from e
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
|
||||
@ -31,7 +31,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
|
||||
|
||||
@ -401,5 +400,4 @@ direct_register_custom_op(
|
||||
op_func=linear_attention,
|
||||
mutates_args=["output"],
|
||||
fake_impl=linear_attention_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
@ -27,7 +27,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn, selective_state_update)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
|
||||
|
||||
@ -464,5 +463,4 @@ direct_register_custom_op(
|
||||
op_func=mamba_mixer,
|
||||
mutates_args=["output"],
|
||||
fake_impl=mamba_mixer_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
@ -34,7 +34,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
LoaderFunction, composed_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
|
||||
|
||||
@ -765,5 +764,4 @@ direct_register_custom_op(
|
||||
op_func=mamba_mixer2,
|
||||
mutates_args=["output"],
|
||||
fake_impl=mamba_mixer2_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
@ -17,7 +17,6 @@ from .mamba_ssm import softplus
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_H': 1}),
|
||||
triton.Config({'BLOCK_SIZE_H': 2}),
|
||||
triton.Config({'BLOCK_SIZE_H': 4}),
|
||||
triton.Config({'BLOCK_SIZE_H': 8}),
|
||||
|
||||
@ -21,7 +21,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.short_conv_attn import (
|
||||
ShortConvAttentionMetadata)
|
||||
@ -251,5 +250,4 @@ direct_register_custom_op(
|
||||
op_func=short_conv,
|
||||
mutates_args=["output"],
|
||||
fake_impl=short_conv_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
|
||||
int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config,
|
||||
int8_w8a16_moe_quant_config, nvfp4_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
is_valid_flashinfer_cutlass_fused_moe)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
|
||||
@ -47,7 +48,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||
|
||||
@ -63,7 +64,7 @@ __all__ = [
|
||||
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
|
||||
"CompressedTensorsW8A8Int8MoEMethod",
|
||||
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
|
||||
"CompressedTensorsW4A4MoeMethod"
|
||||
"CompressedTensorsW4A4MoeMethod", "CompressedTensorsW4A8Int8MoEMethod"
|
||||
]
|
||||
|
||||
|
||||
@ -139,6 +140,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config,
|
||||
layer.moe_config)
|
||||
elif quant_config._is_dynamic_token_w4a8_int(weight_quant,
|
||||
input_quant):
|
||||
return CompressedTensorsW4A8Int8MoEMethod(quant_config,
|
||||
layer.moe_config)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
|
||||
@ -1769,3 +1774,301 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
|
||||
class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
    """CPU-only W4A8 MoE method built on PyTorch's dynamic 4-bit matmul ops.

    Targets the Arm platform (KleidiAI kernels).

    - Weights: int4, loaded as int8 values in [-8, 7] and packed to uint8
      nibbles after loading.
    - Scales: float32 for channel-wise, bfloat16 for group-wise quantization.
    - Bias: same data type as the original weights.
    - Activations: FP32/BF16, dynamically quantized per token (A8 int)
      inside the kernel.
    """

    # Activation name -> integer code handed to ``dynamic_4bit_int_moe``:
    # 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU
    _ACTIVATION_KIND = {"swiglu": 0, "swigluoai": 1, "silu": 2}

    def __init__(
            self,
            quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
            moe: FusedMoEConfig):
        """Validate the quantization scheme and the host platform.

        Raises:
            ValueError: if the "Linear" scheme has no weight or
                input-activation args, activations are not dynamic
                per-token, weights are not 4-bit, or the platform is not CPU.
            RuntimeError: on Arm, if the installed PyTorch does not expose
                the ``_dyn_quant_*`` 4-bit ops.
        """
        super().__init__(moe)
        self.has_bias = self.moe.has_bias
        self.quant_config = quant_config

        # Validate scheme: weights=W4 (channel or group),
        # activations=dynamic TOKEN (A8)
        linear_scheme = self.quant_config.target_scheme_map["Linear"]
        wq = linear_scheme.get("weights")
        aq = linear_scheme.get("input_activations")
        # ``.get`` returns None for a missing entry; fail with a clear
        # message instead of an AttributeError on the checks below.
        if wq is None or aq is None:
            raise ValueError(
                "W4A8-int MoE needs both weight and input-activation "
                "quantization args in the 'Linear' target scheme.")

        # Must be dynamic per-token activations
        if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic:
            raise ValueError(
                "W4A8-int MoE needs dynamic per-token activation quantization."
            )

        # Weight can be channel-wise (group_size=None) or group-wise;
        # -1 is the internal marker for channel-wise.
        self.group_size = wq.group_size if (wq.group_size is not None) else -1
        if wq.num_bits != 4:
            raise ValueError(
                "This method only supports 4-bit weights (num_bits=4).")

        # CPU only
        if not current_platform.is_cpu():
            raise ValueError("CompressedTensorsW4A8Int8MoEMethod is CPU-only.")

        # Arm: check _dyn ops availability
        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
            try:
                _ = torch.ops.aten._dyn_quant_matmul_4bit
                _ = torch.ops.aten._dyn_quant_pack_4bit_weight
            except AttributeError as err:
                # Single-line message: the previous triple-quoted literal
                # embedded a newline plus source indentation.
                raise RuntimeError(
                    f"PyTorch {torch.__version__} lacks _dyn_quant_* 4bit "
                    "ops; install a newer build.") from err
        self.static_input_scales = False  # always dynamic per token

    @staticmethod
    def _pack_int4_nibbles(int4_as_int8_2d: torch.Tensor) -> torch.Tensor:
        """Pack an [out, in] tensor of int4 values in [-8, 7] to uint8.

        Values are shifted to the unsigned range [0, 15]; each pair of
        neighbours along the input dim becomes one byte, with the even
        column in the low nibble and the odd column in the high nibble.
        Returns a uint8 tensor of shape [out, in // 2].
        """
        # Convert to uint8 *before* shifting so 15 << 4 == 240 is
        # representable, rather than relying on int8 wrap-around.
        unsigned = int4_as_int8_2d.add(8).to(torch.uint8)
        return (unsigned[:, 1::2] << 4) | unsigned[:, ::2]

    # ---- parameter creation ----
    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the unpacked weights, scales, optional biases and the
        placeholders for the packed weights on ``layer``.

        Shapes per local rank (TP/EP):
            w13: [E, 2*I_local, H] int8 (int4 values in [-8,7])
            w2 : [E, H, I_local]   int8
        Scales:
            channel-wise: group_size=-1 -> one scale per output row
            group-wise  : group_size=g  -> (in_features/g) scales per
                                           output row
        """
        E = num_experts
        H = hidden_size
        IN = intermediate_size_per_partition
        g = self.group_size

        # Per-row scale columns
        def _n_scale_cols(in_features: int) -> int:
            return 1 if g == -1 else (in_features // g)

        # Register unpacked int4-as-int8 weights the loader will fill.
        w13 = torch.nn.Parameter(torch.empty(E, 2 * IN, H, dtype=torch.int8),
                                 requires_grad=False)
        set_weight_attrs(w13, extra_weight_attrs)
        layer.register_parameter("w13_weight", w13)

        w2 = torch.nn.Parameter(torch.empty(E, H, IN, dtype=torch.int8),
                                requires_grad=False)
        set_weight_attrs(w2, extra_weight_attrs)
        layer.register_parameter("w2_weight", w2)

        # Register scales
        # KleidiAI channel-wise kernels accept float32 scales
        # KleidiAI group-wise kernels accept bfloat16 scales
        scale_dtype = torch.float32 if g == -1 else torch.bfloat16
        scale_attrs = {
            "quant_method": "channel" if g == -1 else "group",
            **extra_weight_attrs
        }

        w13_s = torch.nn.Parameter(torch.ones(E,
                                              2 * IN,
                                              _n_scale_cols(H),
                                              dtype=scale_dtype),
                                   requires_grad=False)
        set_weight_attrs(w13_s, scale_attrs)
        layer.register_parameter("w13_weight_scale", w13_s)

        w2_s = torch.nn.Parameter(torch.ones(E,
                                             H,
                                             _n_scale_cols(IN),
                                             dtype=scale_dtype),
                                  requires_grad=False)
        set_weight_attrs(w2_s, scale_attrs)
        layer.register_parameter("w2_weight_scale", w2_s)

        if self.has_bias:
            w13_bias = torch.nn.Parameter(torch.zeros(E,
                                                      2 * IN,
                                                      dtype=params_dtype),
                                          requires_grad=False)
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, extra_weight_attrs)

            w2_bias = torch.nn.Parameter(torch.zeros(E, H,
                                                     dtype=params_dtype),
                                         requires_grad=False)
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, extra_weight_attrs)

        # Placeholders for packed weights (will be replaced after packing)
        layer.register_parameter(
            "w13_weight_packed",
            torch.nn.Parameter(torch.empty(0), requires_grad=False))
        set_weight_attrs(layer.w13_weight_packed, extra_weight_attrs)

        layer.register_parameter(
            "w2_weight_packed",
            torch.nn.Parameter(torch.empty(0), requires_grad=False))
        set_weight_attrs(layer.w2_weight_packed, extra_weight_attrs)

        # dims for 4 bit fused matmuls
        layer.w13_in_features = H
        layer.w13_out_features = 2 * IN
        layer.w2_in_features = IN
        layer.w2_out_features = H
        layer.group_size = g

    # post-load packing to dyn-4bit KleidiAI kernel's format
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Pack each expert's weights, scales and bias with
        ``_dyn_quant_pack_4bit_weight``, store the stacked results in
        ``w13_weight_packed`` / ``w2_weight_packed`` and replace the raw
        tensors with empty parameters.
        """
        num_experts = layer.w13_weight.shape[0]
        g = layer.group_size
        # KleidiAI channel-wise kernels accept float32 scales,
        # group-wise kernels accept bfloat16 scales.
        scale_dtype = torch.float32 if g == -1 else torch.bfloat16

        def _pack_matrix(int4_as_int8_2d: torch.Tensor,
                         scales_2d: torch.Tensor,
                         bias_1d: Optional[torch.Tensor], in_features: int,
                         out_features: int) -> torch.Tensor:
            # int4 values are stored as int8 in [-8,7]; pack pairs along
            # the input dim into uint8 nibbles -> [out, in//2].
            uint8_nibbles = self._pack_int4_nibbles(int4_as_int8_2d)
            scales = scales_2d.to(scale_dtype)
            bias = None if bias_1d is None else bias_1d.to(torch.float32)
            # Channel-wise quantization is passed as one block spanning the
            # whole input dimension.
            block_size = in_features if g == -1 else g
            return torch.ops.aten._dyn_quant_pack_4bit_weight(
                uint8_nibbles, scales, bias, block_size, in_features,
                out_features)

        has_w13_bias = getattr(layer, "w13_bias", None) is not None
        has_w2_bias = getattr(layer, "w2_bias", None) is not None

        # Pack per expert
        w13_packed_list = [
            _pack_matrix(
                layer.w13_weight[e],  # [2I, H]
                layer.w13_weight_scale[e],  # [2I, H/g or 1]
                layer.w13_bias[e] if has_w13_bias else None,  # [2I]
                layer.w13_in_features,
                layer.w13_out_features) for e in range(num_experts)
        ]
        w2_packed_list = [
            _pack_matrix(
                # w2 shape is [H, IN]; we need [out, in] == [H, IN].
                layer.w2_weight[e],  # [H, IN]
                layer.w2_weight_scale[e],  # [H, IN/g or 1]
                layer.w2_bias[e] if has_w2_bias else None,  # [H]
                layer.w2_in_features,
                layer.w2_out_features) for e in range(num_experts)
        ]

        # each packed tensor has identical shape per expert; stack on dim 0
        w13_packed = torch.stack(w13_packed_list, dim=0)
        w2_packed = torch.stack(w2_packed_list, dim=0)

        replace_parameter(layer, "w13_weight_packed",
                          torch.nn.Parameter(w13_packed, requires_grad=False))
        replace_parameter(layer, "w2_weight_packed",
                          torch.nn.Parameter(w2_packed, requires_grad=False))

        # free raw tensors/scales/bias now that they're packed into the payload.
        stale_names = [
            "w13_weight", "w2_weight", "w13_weight_scale", "w2_weight_scale"
        ]
        if has_w13_bias:
            stale_names.append("w13_bias")
        if has_w2_bias:
            stale_names.append("w2_bias")
        for name in stale_names:
            replace_parameter(
                layer, name,
                torch.nn.Parameter(torch.empty(0), requires_grad=False))

    def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
        """Return ``None``.

        The CPU dynamic 4-bit MoE path does not use modular kernels or
        fused_experts, so no quant config is needed.
        """
        return None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Route tokens with ``select_experts`` and run the fused dynamic
        4-bit MoE custom op on the packed weights.

        EPLB and ``expert_map`` (expert parallelism) are not supported, and
        only the "silu", "swigluoai" and "swiglu" activations are accepted.
        ``global_num_experts`` and the EPLB tensors are unused here.
        """
        assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet."
        assert activation in (
            "silu", "swigluoai",
            "swiglu"), "Only SiLU/SwiGLUGU/SwiGLUUG are supported."
        # Single-line message: the previous triple-quoted literal embedded a
        # newline plus source indentation.
        assert expert_map is None, (
            "expert_map/EP not implemented for CPU dyn-4bit MoE.")

        act_kind = self._ACTIVATION_KIND.get(activation)
        if act_kind is None:
            # Only reachable when asserts are disabled (python -O).
            raise ValueError(f"Unknown activation '{activation}'")

        # Apply topk softmax on router output
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
        )

        return torch.ops._C.dynamic_4bit_int_moe(
            x, topk_ids.to(torch.long), topk_weights, layer.w13_weight_packed,
            layer.w2_weight_packed, layer.w2_out_features,
            layer.w2_in_features, layer.w13_out_features, layer.group_size,
            apply_router_weight_on_input, act_kind)
|
||||
@ -4,7 +4,6 @@ import logging
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.utils.deep_gemm import fp8_gemm_nt
|
||||
@ -75,7 +74,5 @@ def w8a8_deepgemm_block_scaled_mm_fake(
|
||||
direct_register_custom_op(
|
||||
op_name="w8a8_deepgemm_block_scaled_mm",
|
||||
op_func=w8a8_deepgemm_block_scaled_mm,
|
||||
mutates_args=[],
|
||||
fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
@ -161,7 +161,6 @@ try:
|
||||
direct_register_custom_op(
|
||||
op_name="_fused_mul_mat_gguf",
|
||||
op_func=_fused_mul_mat_gguf,
|
||||
mutates_args=[],
|
||||
fake_impl=_fused_mul_mat_gguf_fake,
|
||||
)
|
||||
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
|
||||
@ -273,7 +272,6 @@ try:
|
||||
direct_register_custom_op(
|
||||
op_name="_fused_moe_gguf",
|
||||
op_func=_fused_moe_gguf,
|
||||
mutates_args=[],
|
||||
fake_impl=_fused_moe_gguf_fake,
|
||||
)
|
||||
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
|
||||
@ -319,7 +317,6 @@ try:
|
||||
direct_register_custom_op(
|
||||
op_name="_apply_gguf_embedding",
|
||||
op_func=_apply_gguf_embedding,
|
||||
mutates_args=[],
|
||||
fake_impl=_apply_gguf_embedding_fake,
|
||||
)
|
||||
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
|
||||
|
||||
@ -51,9 +51,7 @@ if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_gemm_w8a8",
|
||||
op_func=rocm_aiter_gemm_w8a8_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=rocm_aiter_gemm_w8a8_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -212,12 +212,15 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 256)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
elif current_platform.is_rocm() or (
|
||||
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
|
||||
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 128)
|
||||
hidden_size = round_up(hidden_size, 128)
|
||||
elif current_platform.is_rocm():
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 256)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 64)
|
||||
|
||||
@ -91,9 +91,7 @@ if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_gemm_w8a8_blockscale",
|
||||
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
|
||||
and current_platform.is_fp8_fnuz()):
|
||||
@ -132,13 +130,14 @@ def _w8a8_triton_block_scaled_mm_fake(
|
||||
device=qx.device)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
"w8a8_triton_block_scaled_mm_func",
|
||||
_w8a8_triton_block_scaled_mm_func,
|
||||
mutates_args=[],
|
||||
fake_impl=_w8a8_triton_block_scaled_mm_fake,
|
||||
dispatch_key="CUDA",
|
||||
)
|
||||
# Note: the check can be removed when CPU torch > 2.7
|
||||
if not current_platform.is_cpu():
|
||||
direct_register_custom_op(
|
||||
"w8a8_triton_block_scaled_mm_func",
|
||||
_w8a8_triton_block_scaled_mm_func,
|
||||
fake_impl=_w8a8_triton_block_scaled_mm_fake,
|
||||
dispatch_key="CUDA",
|
||||
)
|
||||
|
||||
|
||||
# TODO fix ROCm->Triton custom path:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -21,6 +21,10 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.tensor_details.layout import StridedLayout
|
||||
|
||||
value_layout_opts: dict[str, Any] = {}
|
||||
scale_layout_opts: dict[str, Any] = {}
|
||||
|
||||
if (current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(90)
|
||||
and not is_torch_equal_or_newer("2.8.1")):
|
||||
@ -28,8 +32,15 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
|
||||
"Mxfp4 on hopper is running on torch < 2.8.1, "
|
||||
"this cause swizling to be disabled, which may "
|
||||
"cause performance degradation. Please upgrade to torch nightly")
|
||||
value_layout, value_layout_opts = StridedLayout, dict()
|
||||
scale_layout, scale_layout_opts = StridedLayout, dict()
|
||||
value_layout = StridedLayout
|
||||
scale_layout = StridedLayout
|
||||
elif current_platform.is_rocm():
|
||||
from triton_kernels.tensor_details.layout import (GFX950MXScaleLayout,
|
||||
StridedLayout)
|
||||
|
||||
from vllm.platforms.rocm import on_gfx950
|
||||
value_layout = StridedLayout
|
||||
scale_layout = GFX950MXScaleLayout if on_gfx950() else StridedLayout
|
||||
else:
|
||||
value_layout, value_layout_opts = \
|
||||
layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
||||
@ -113,7 +124,6 @@ try:
|
||||
direct_register_custom_op(
|
||||
op_name="dequant_mxfp4",
|
||||
op_func=_dequant_mxfp4,
|
||||
mutates_args=[],
|
||||
fake_impl=_dequant_mxfp4_fake,
|
||||
)
|
||||
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
|
||||
@ -124,7 +134,6 @@ try:
|
||||
direct_register_custom_op(
|
||||
op_name="quant_dequant_mxfp4",
|
||||
op_func=_quant_dequant_mxfp4,
|
||||
mutates_args=[],
|
||||
fake_impl=_quant_dequant_mxfp4_fake,
|
||||
)
|
||||
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
|
||||
|
||||
@ -218,9 +218,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
|
||||
op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -147,5 +147,4 @@ direct_register_custom_op(
|
||||
op_func=_flashinfer_rotary_embedding,
|
||||
mutates_args=["query", "key"], # These tensors are modified in-place
|
||||
fake_impl=_flashinfer_rotary_embedding_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
@ -136,9 +136,7 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_unquantized_gemm_impl",
|
||||
op_func=rocm_unquantized_gemm_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=rocm_unquantized_gemm_impl_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Annotated, Literal, Optional, Union, cast
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -347,12 +347,16 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
|
||||
pixel_values: torch.Tensor,
|
||||
**kwargs) -> torch.Tensor:
|
||||
target_dtype = vision_tower.get_input_embeddings().weight.dtype
|
||||
image_features = vision_tower(pixel_values.to(dtype=target_dtype),
|
||||
**kwargs)
|
||||
def _image_pixels_to_features(
|
||||
self,
|
||||
vision_tower: SiglipVisionModel,
|
||||
pixel_values: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||
target_dtype: torch.dtype = \
|
||||
vision_tower.get_input_embeddings().weight.dtype
|
||||
image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
|
||||
vision_tower(pixel_values.to(dtype=target_dtype), **kwargs)
|
||||
|
||||
def select_features(leaf: torch.Tensor):
|
||||
return self._select_image_features(
|
||||
@ -360,10 +364,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
strategy=self.config.vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
return cast(
|
||||
Union[torch.Tensor, tuple[torch.Tensor, ...]],
|
||||
json_map_leaves(select_features, image_features),
|
||||
)
|
||||
return json_map_leaves(select_features, image_features)
|
||||
|
||||
def _select_image_features(self, image_features: torch.Tensor, *,
|
||||
strategy: str) -> torch.Tensor:
|
||||
|
||||
@ -245,19 +245,6 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
|
||||
}
|
||||
|
||||
|
||||
class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
    """Config hook that raises the capture length to the model length."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Set ``max_seq_len_to_capture`` to ``max_model_len`` and log it."""
        model_config = vllm_config.model_config
        model_config.max_seq_len_to_capture = model_config.max_model_len
        message = ("Setting max_seq_len_to_capture to %d "
                   "to ensure that CUDA graph capture "
                   "covers sequences of length up to max_model_len.")
        logger.info(message, model_config.max_model_len)
|
||||
|
||||
|
||||
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@staticmethod
|
||||
@ -426,7 +413,6 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"XLMRobertaModel": JinaRobertaModelConfig,
|
||||
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,
|
||||
"JambaForSequenceClassification": JambaForSequenceClassificationConfig,
|
||||
"GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
|
||||
"GptOssForCausalLM": GptOssForCausalLMConfig,
|
||||
"MambaForCausalLM": MambaModelConfig,
|
||||
"Mamba2ForCausalLM": MambaModelConfig,
|
||||
|
||||
@ -56,7 +56,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
|
||||
@ -141,9 +140,7 @@ def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
|
||||
direct_register_custom_op(
|
||||
op_name="sequence_parallel_chunk",
|
||||
op_func=sequence_parallel_chunk,
|
||||
mutates_args=[],
|
||||
fake_impl=sequence_parallel_chunk_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
@ -26,6 +26,7 @@ from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
|
||||
GeluAndMul,
|
||||
@ -44,6 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
|
||||
|
||||
from .interfaces import SupportsQuant
|
||||
from .utils import (AutoWeightsLoader, extract_layer_index,
|
||||
@ -51,6 +53,8 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
EPS = torch.tensor(torch.finfo().min)
|
||||
|
||||
|
||||
class Gemma3nAltUp(nn.Module):
|
||||
"""Alternating updates (Altup)
|
||||
@ -532,16 +536,29 @@ class Gemma3nDecoderLayer(nn.Module):
|
||||
return corrected_predictions
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
# This enables torch.compile if --kv-sharing-fast-prefill passed
|
||||
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
||||
kv_sharing_fast_prefill)
|
||||
class Gemma3nSelfDecoder(nn.Module):
|
||||
"""
|
||||
Includes altup embedding and self decoder layers
|
||||
"""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
decoder_layers: list[Gemma3nDecoderLayer],
|
||||
layer_idx_start: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder_layers = decoder_layers
|
||||
self.layer_idx_start = layer_idx_start
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
@ -594,32 +611,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
prefix=f"{prefix}.altup_projections.{idx-1}",
|
||||
) for idx in range(1, self.config.altup_num_inputs)
|
||||
])
|
||||
self.altup_unembed_projections = nn.ModuleList([
|
||||
ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
gather_output=True,
|
||||
return_bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
|
||||
) for idx in range(1, self.config.altup_num_inputs)
|
||||
])
|
||||
|
||||
# Transformer blocks.
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Gemma3nDecoderLayer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.norm = RMSNorm(
|
||||
config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.eps = torch.tensor(torch.finfo().min)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
def get_per_layer_input_embeddings(
|
||||
self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
@ -633,20 +624,11 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
return self.embed_tokens_per_layer(
|
||||
per_layer_inputs_tokens) * self.embed_scale_per_layer
|
||||
|
||||
def forward(
|
||||
def get_per_layer_inputs(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
per_layer_inputs: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states_0 = inputs_embeds
|
||||
else:
|
||||
hidden_states_0 = self.get_input_embeddings(input_ids)
|
||||
|
||||
hidden_states_0: torch.Tensor,
|
||||
per_layer_inputs: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
|
||||
per_layer_projection = per_layer_projection.reshape(
|
||||
*hidden_states_0.shape[:-1],
|
||||
@ -655,14 +637,18 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
)
|
||||
per_layer_projection = self.per_layer_projection_norm(
|
||||
per_layer_projection)
|
||||
|
||||
if per_layer_inputs is not None:
|
||||
# Profiling run does not compute per_layer_inputs
|
||||
per_layer_inputs = per_layer_projection + per_layer_inputs
|
||||
per_layer_inputs *= self.per_layer_input_scale
|
||||
else:
|
||||
per_layer_inputs = per_layer_projection
|
||||
return per_layer_inputs
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
|
||||
# Altup embed.
|
||||
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
|
||||
target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
|
||||
@ -673,11 +659,77 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
dim=-1,
|
||||
keepdim=True)**0.5
|
||||
hidden_states[i] *= target_magnitude / torch.maximum(
|
||||
new_magnitude, self.eps)
|
||||
hidden_states = torch.stack(hidden_states, dim=0)
|
||||
new_magnitude, EPS)
|
||||
hidden_states = torch.stack(hidden_states, dim=-1)
|
||||
return hidden_states
|
||||
|
||||
# Transformer blocks.
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states_0 = inputs_embeds
|
||||
else:
|
||||
hidden_states_0 = self.get_input_embeddings(input_ids)
|
||||
|
||||
adjusted_per_layer_inputs = self.get_per_layer_inputs(
|
||||
hidden_states_0, per_layer_inputs)
|
||||
hidden_states = self.altup_embed(hidden_states_0)
|
||||
|
||||
# [altnum_inputs, num_tokens, hidden_size]
|
||||
hidden_states = hidden_states.permute(2, 0, 1)
|
||||
|
||||
for idx, layer in enumerate(self.decoder_layers):
|
||||
layer_idx = idx + self.layer_idx_start
|
||||
# [altup_num_inputs, num_tokens, hidden_size]
|
||||
hidden_states = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# [num_tokens, hidden_size, altnum_inputs]
|
||||
hidden_states = hidden_states.permute(1, 2, 0)
|
||||
|
||||
return hidden_states, adjusted_per_layer_inputs
|
||||
|
||||
|
||||
# This enables torch.compile if --kv-sharing-fast-prefill passed
|
||||
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
||||
kv_sharing_fast_prefill)
|
||||
class Gemma3nCrossDecoder(nn.Module):
|
||||
"""
|
||||
Cross-decoder layers
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
decoder_layers: list[Gemma3nDecoderLayer],
|
||||
layer_idx_start: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder_layers = decoder_layers
|
||||
self.layer_idx_start = layer_idx_start
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
per_layer_inputs: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# [altnum_inputs, num_tokens, hidden_size]
|
||||
hidden_states = hidden_states.permute(2, 0, 1)
|
||||
for idx, layer in enumerate(self.decoder_layers):
|
||||
layer_idx = idx + self.layer_idx_start
|
||||
# [altup_num_inputs, num_tokens, hidden_size]
|
||||
hidden_states = layer(
|
||||
positions=positions,
|
||||
@ -685,22 +737,249 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
per_layer_input=per_layer_inputs[:, layer_idx, :],
|
||||
**kwargs,
|
||||
)
|
||||
# [num_tokens, hidden_size, altnum_inputs]
|
||||
hidden_states = hidden_states.permute(1, 2, 0)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# This disables torch.compile if --kv-sharing-fast-prefill passed
|
||||
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
|
||||
cache_config.kv_sharing_fast_prefill)
|
||||
class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.altup_unembed_projections = nn.ModuleList([
|
||||
ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
gather_output=True,
|
||||
return_bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
|
||||
) for idx in range(1, self.config.altup_num_inputs)
|
||||
])
|
||||
|
||||
# Allocate config.num_kv_shared_layers layers for self-decoder
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Gemma3nDecoderLayer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
first_kv_shared_layer_idx = (config.num_hidden_layers -
|
||||
config.num_kv_shared_layers)
|
||||
|
||||
# NOTE(sarckk): importing this top level seems to cause issues
|
||||
# during running of tests.
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
|
||||
# Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
|
||||
with set_model_tag("self_decoder"):
|
||||
self.self_decoder = Gemma3nSelfDecoder(
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.self_decoder",
|
||||
decoder_layers=self.layers[:first_kv_shared_layer_idx],
|
||||
layer_idx_start=0,
|
||||
)
|
||||
# Layer idx 20-30 are cross-decoder layers in YOCO
|
||||
with set_model_tag("cross_decoder"):
|
||||
self.cross_decoder = Gemma3nCrossDecoder(
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.cross_decoder",
|
||||
decoder_layers=self.layers[first_kv_shared_layer_idx:],
|
||||
layer_idx_start=first_kv_shared_layer_idx,
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(
|
||||
config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
|
||||
|
||||
if self.fast_prefill_enabled:
|
||||
# Allocate static buffers for CUDAGraph
|
||||
# TODO(sarckk): Extract this functionality to interface
|
||||
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
device = next(self.parameters()).device
|
||||
self.positions = torch.zeros(max_num_tokens,
|
||||
dtype=torch.int64,
|
||||
device=device)
|
||||
self.hidden_states = torch.zeros(
|
||||
(max_num_tokens, config.hidden_size,
|
||||
self.config.altup_num_inputs),
|
||||
dtype=self.embed_tokens.weight.dtype,
|
||||
device=device,
|
||||
)
|
||||
self.per_layer_inputs = torch.zeros(
|
||||
(max_num_tokens, self.config.num_hidden_layers,
|
||||
self.config.hidden_size_per_layer_input),
|
||||
dtype=self.embed_tokens.weight.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@property
|
||||
def embed_tokens(self):
|
||||
return self.self_decoder.embed_tokens
|
||||
|
||||
def get_per_layer_input_embeddings(
|
||||
self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.self_decoder.get_per_layer_input_embeddings(input_ids)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.self_decoder.get_input_embeddings(input_ids)
|
||||
|
||||
def fast_prefill_forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
logits_indices_padded, num_logits_indices = None, None
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
|
||||
# attn_metadata is None during dummy runs
|
||||
if (self.fast_prefill_enabled and attn_metadata is not None):
|
||||
assert isinstance(attn_metadata, dict)
|
||||
# Last layer is a KV sharing layer
|
||||
layer_attn_metadata = attn_metadata[
|
||||
self.layers[-1].self_attn.attn.layer_name]
|
||||
if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)):
|
||||
logits_indices_padded = (
|
||||
layer_attn_metadata.logits_indices_padded)
|
||||
num_logits_indices = layer_attn_metadata.num_logits_indices
|
||||
|
||||
# Copy inputs for cudagraph
|
||||
batch_size = positions.size(0)
|
||||
self.positions[:batch_size].copy_(positions)
|
||||
self_decoder_hidden_states, per_layer_inputs_adjusted = \
|
||||
self.self_decoder(
|
||||
input_ids=input_ids,
|
||||
positions=self.positions[:batch_size],
|
||||
inputs_embeds=inputs_embeds,
|
||||
per_layer_inputs=per_layer_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if logits_indices_padded is None:
|
||||
logits_indices_padded = torch.arange(
|
||||
positions.size(0),
|
||||
dtype=positions.dtype,
|
||||
device=positions.device,
|
||||
)
|
||||
|
||||
# NOTE(sarckk): There is currently a bug caused by
|
||||
# vLLM converting output of last piecewise CUDA graph
|
||||
# to weakref, causing memory to be prematurely freed
|
||||
# when there are multiple compilation units
|
||||
# Keep .clone() until fix in
|
||||
# https://github.com/vllm-project/vllm/pull/22282
|
||||
hidden_states = self_decoder_hidden_states.clone()
|
||||
|
||||
# Copy inputs for cudagraph
|
||||
num_padded_logits_indices = logits_indices_padded.size(0)
|
||||
self.positions[:num_padded_logits_indices].copy_(
|
||||
positions[logits_indices_padded])
|
||||
self.hidden_states[:num_padded_logits_indices].copy_(
|
||||
self_decoder_hidden_states[logits_indices_padded])
|
||||
self.per_layer_inputs[:num_padded_logits_indices].copy_(
|
||||
per_layer_inputs_adjusted[logits_indices_padded])
|
||||
cross_decoder_hidden_states = self.cross_decoder(
|
||||
positions=self.positions[:num_padded_logits_indices],
|
||||
hidden_states=self.hidden_states[:num_padded_logits_indices],
|
||||
per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if num_logits_indices is not None:
|
||||
assert num_logits_indices > 0
|
||||
# Merge cross-decoder and self-decoder hidden states
|
||||
hidden_states[logits_indices_padded[:num_logits_indices]] = (
|
||||
cross_decoder_hidden_states[:num_logits_indices])
|
||||
else:
|
||||
hidden_states = cross_decoder_hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
def normal_forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states, per_layer_inputs = self.self_decoder(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
per_layer_inputs=per_layer_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.cross_decoder(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
per_layer_inputs=per_layer_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def altup_unembed(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Altup unembed.
|
||||
target_magnitude = torch.mean(hidden_states[0]**2,
|
||||
target_magnitude = torch.mean(hidden_states[..., 0]**2,
|
||||
dim=-1,
|
||||
keepdim=True)**0.5
|
||||
for i in range(1, self.config.altup_num_inputs):
|
||||
hidden_states[i] = self.altup_unembed_projections[i - 1](
|
||||
hidden_states[i])
|
||||
new_magnitude = torch.mean(hidden_states[i]**2,
|
||||
hidden_states[..., i] = self.altup_unembed_projections[i - 1](
|
||||
hidden_states[..., i])
|
||||
new_magnitude = torch.mean(hidden_states[..., i]**2,
|
||||
dim=-1,
|
||||
keepdim=True)**0.5
|
||||
hidden_states[i] *= target_magnitude / torch.maximum(
|
||||
new_magnitude, self.eps)
|
||||
# [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
|
||||
hidden_states = torch.mean(hidden_states, dim=0)
|
||||
hidden_states[..., i] *= target_magnitude / torch.maximum(
|
||||
new_magnitude, EPS)
|
||||
# [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size]
|
||||
hidden_states = torch.mean(hidden_states, dim=-1)
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if self.fast_prefill_enabled:
|
||||
hidden_states = self.fast_prefill_forward(
|
||||
input_ids,
|
||||
positions,
|
||||
inputs_embeds,
|
||||
per_layer_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = self.normal_forward(
|
||||
input_ids,
|
||||
positions,
|
||||
inputs_embeds,
|
||||
per_layer_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = self.altup_unembed(hidden_states)
|
||||
return self.norm(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
@ -716,6 +995,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
# decoder layer weights, altup_unembed_projections and rmsnorm
|
||||
# are initialized in text model, others are in self decoder
|
||||
if (not name.startswith('layers')
|
||||
and not name.startswith('altup_unembed_projections')
|
||||
and not name.startswith('norm')):
|
||||
name = f"self_decoder.{name}"
|
||||
|
||||
if (self.quant_config is not None and
|
||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
|
||||
@ -308,13 +308,11 @@ class GraniteModel(nn.Module):
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
|
||||
hidden_states *= self.config.embedding_multiplier
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
@ -322,7 +320,6 @@ class GraniteModel(nn.Module):
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
@ -475,10 +472,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
"residual":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
|
||||
@ -298,17 +298,14 @@ class GraniteMoeModel(nn.Module):
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states *= self.embedding_multiplier
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
@ -523,10 +520,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
"residual":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
|
||||
@ -195,17 +195,14 @@ class GraniteMoeSharedModel(nn.Module):
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states *= self.embedding_multiplier
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
@ -323,10 +320,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
"residual":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
|
||||
@ -29,7 +29,6 @@ from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
|
||||
from transformers.modeling_utils import no_init_weights
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
||||
@ -37,8 +36,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate)
|
||||
BaseProcessingInfo,
|
||||
InputProcessingContext,
|
||||
PromptReplacement, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
|
||||
Union, cast)
|
||||
Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -15,7 +15,6 @@ from transformers.models.llava import LlavaProcessor
|
||||
from transformers.models.pixtral import PixtralProcessor
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@ -28,8 +27,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
BaseProcessingInfo,
|
||||
InputProcessingContext,
|
||||
PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
@ -622,7 +623,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the vision tower
|
||||
image_features = vision_tower(pixel_values)
|
||||
image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
|
||||
vision_tower(pixel_values)
|
||||
|
||||
def select_features(leaf: torch.Tensor):
|
||||
return self._select_image_features(
|
||||
@ -630,10 +632,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
strategy=self.config.vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
return cast(
|
||||
Union[torch.Tensor, tuple[torch.Tensor, ...]],
|
||||
json_map_leaves(select_features, image_features),
|
||||
)
|
||||
return json_map_leaves(select_features, image_features)
|
||||
|
||||
def _process_image_pixels(
|
||||
self,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user