mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-22 03:34:28 +08:00
Merge branch 'mlm-full-lora-support' of https://github.com/jeejeelee/vllm into mlm-full-lora-support
This commit is contained in:
commit
f67ccfae9c
@ -1,46 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
# HTML index page that links to the x86_64 and the aarch64 wheel. The href
# values point one directory up and use the '+'-escaped file names.
template = """<!DOCTYPE html>
<html>
<body>
<h1>Links for vLLM</h1>
<a href="../{x86_wheel_html_escaped}">{x86_wheel}</a><br/>
<a href="../{arm_wheel_html_escaped}">{arm_wheel}</a><br/>
</body>
</html>
"""

parser = argparse.ArgumentParser()
parser.add_argument("--wheel", help="The wheel path.", required=True)
args = parser.parse_args()

filename = os.path.basename(args.wheel)

# Derive both wheel names before index.html is opened, so an unsupported wheel
# raises without leaving a truncated index.html behind.
# sync the abi tag with .buildkite/scripts/upload-wheels.sh
if "x86_64" in filename:
    x86_wheel = filename
    arm_wheel = filename.replace("x86_64", "aarch64").replace(
        "manylinux1", "manylinux2014"
    )
elif "aarch64" in filename:
    x86_wheel = filename.replace("aarch64", "x86_64").replace(
        "manylinux2014", "manylinux1"
    )
    arm_wheel = filename
else:
    raise ValueError(f"Unsupported wheel: {filename}")

with open("index.html", "w") as f:
    # cloudfront requires escaping the '+' character
    f.write(
        template.format(
            x86_wheel=x86_wheel,
            x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"),
            arm_wheel=arm_wheel,
            arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"),
        )
    )

# Report success only once the file has actually been written.
print(f"Generated index.html for {args.wheel}")
|
||||
@ -7,13 +7,14 @@
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
import regex as re
|
||||
|
||||
if not sys.version_info >= (3, 12):
|
||||
raise RuntimeError("This script requires Python 3.12 or higher.")
|
||||
|
||||
|
||||
@ -74,6 +74,7 @@ FROM ${BASE_IMAGE_NAME}
|
||||
|
||||
# Define environments
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV SOC_VERSION="ascend910b1"
|
||||
|
||||
RUN pip config set global.index-url http://cache-service-vllm.nginx-pypi-cache.svc.cluster.local:${PYPI_CACHE_PORT}/pypi/simple && \
|
||||
pip config set global.trusted-host cache-service-vllm.nginx-pypi-cache.svc.cluster.local && \
|
||||
|
||||
@ -81,7 +81,7 @@ else
|
||||
alias_arg=""
|
||||
fi
|
||||
|
||||
$PYTHON .buildkite/scripts/generate-nightly-index.py --version "$SUBPATH" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" $alias_arg
|
||||
# pip must be run as a module of the chosen interpreter ("$PYTHON pip ..." tries
# to execute a file named "pip"), and the index script must run under $PYTHON too.
$PYTHON -m pip install regex && $PYTHON .buildkite/scripts/generate-nightly-index.py --version "$SUBPATH" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" $alias_arg
|
||||
|
||||
# copy indices to /<commit>/ unconditionally
|
||||
echo "Uploading indices to $S3_COMMIT_PREFIX"
|
||||
|
||||
@ -987,7 +987,8 @@ steps:
|
||||
commands:
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 1
|
||||
- label: Multi-Modal Models Test (Extended) 1 # 60min
|
||||
timeout_in_minutes: 120
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
@ -1011,7 +1012,8 @@ steps:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 3
|
||||
- label: Multi-Modal Models Test (Extended) 3 # 75min
|
||||
timeout_in_minutes: 150
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
# grade: Blocking
|
||||
|
||||
@ -387,6 +387,7 @@ steps:
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
source_file_dependencies:
|
||||
- vllm/entrypoints
|
||||
- vllm/multimodal
|
||||
- examples/
|
||||
commands:
|
||||
- pip install tensorizer # for tensorizer test
|
||||
|
||||
@ -137,6 +137,7 @@ Compute Resources:
|
||||
- Alibaba Cloud
|
||||
- AMD
|
||||
- Anyscale
|
||||
- Arm
|
||||
- AWS
|
||||
- Crusoe Cloud
|
||||
- Databricks
|
||||
|
||||
120
benchmarks/benchmark_hash.py
Normal file
120
benchmarks/benchmark_hash.py
Normal file
@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Micro benchmark comparing built-in hash(), SHA-256, and xxHash.
|
||||
|
||||
This focuses on a single test payload shaped like the prefix-cache hash input:
|
||||
(32-byte bytes object, 32-int tuple)
|
||||
|
||||
Usage:
|
||||
python benchmarks/benchmark_hash.py --iterations 20000
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import random
|
||||
import statistics
|
||||
import time
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
from vllm.utils.hashing import sha256, xxhash
|
||||
|
||||
|
||||
def _generate_test_data(seed: int) -> tuple[bytes, tuple[int, ...]]:
    """Generate a deterministic test payload.

    A private ``random.Random`` instance is seeded instead of the module-level
    generator, so building the payload does not reset the random state of the
    calling process. The payload for a given seed is the same as the one
    obtained after ``random.seed(seed)``.

    Args:
        seed: Seed for the payload generator.

    Returns:
        A ``(bytes_data, int_tuple)`` pair: 32 random bytes and a tuple of 32
        integers drawn from ``[1, 1_000_000]``.
    """
    rng = random.Random(seed)
    bytes_data = bytes(rng.getrandbits(8) for _ in range(32))
    int_tuple = tuple(rng.randint(1, 1_000_000) for _ in range(32))
    return (bytes_data, int_tuple)
|
||||
|
||||
|
||||
def _benchmark_func(
    func: Callable[[tuple], object], data: tuple, iterations: int
) -> tuple[float, float]:
    """Return (avg_seconds, std_seconds) for hashing `data` `iterations` times.

    Args:
        func: Hash callable under test; called as ``func(data)``.
        data: Payload passed unchanged to every call.
        iterations: Number of timed calls; must be at least 1.

    Returns:
        Mean and sample standard deviation of the per-call durations measured
        with ``time.perf_counter``. The deviation is 0.0 for a single sample.

    Raises:
        ValueError: If `iterations` is less than 1.
    """
    # Fail before the warm-up instead of letting statistics.mean() reject an
    # empty sample list afterwards.
    if iterations < 1:
        raise ValueError(f"iterations must be at least 1, got {iterations}")

    # Warm-up to avoid first-run noise.
    for _ in range(200):
        func(data)

    times: list[float] = []
    for _ in range(iterations):
        start = time.perf_counter()
        func(data)
        end = time.perf_counter()
        times.append(end - start)

    avg = statistics.mean(times)
    # statistics.stdev() needs at least two samples.
    std = statistics.stdev(times) if len(times) > 1 else 0.0
    return avg, std
|
||||
|
||||
|
||||
def _run_benchmarks(
    benchmarks: Iterable[tuple[str, Callable[[tuple], object]]],
    data: tuple,
    iterations: int,
):
    """Lazily time every ``(label, hash function)`` pair.

    Yields one ``(label, avg, std)`` tuple per benchmark whose timing run
    completes. A run that raises ``ModuleNotFoundError`` is reported on stdout
    and left out of the results.
    """
    for label, hash_fn in benchmarks:
        try:
            timing = _benchmark_func(hash_fn, data, iterations)
        except ModuleNotFoundError as err:
            print(f"Skipping {label}: {err}")
        else:
            mean_s, std_s = timing
            yield label, mean_s, std_s
|
||||
|
||||
|
||||
def builtin_hash(data: tuple) -> int:
    """Hash `data` with Python's built-in ``hash()``.

    A named wrapper, so the built-in can be listed alongside the other hash
    callables in the benchmark table.
    """
    digest = hash(data)
    return digest
|
||||
|
||||
|
||||
def main() -> None:
    """Parse the CLI options, time each hash function and print a report.

    The report lists the mean and standard deviation per function in
    microseconds, followed by each function's mean relative to the built-in
    ``hash()`` when that baseline result is available.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--iterations",
        type=int,
        default=10_000,
        help="Number of measured iterations per hash function.",
    )
    cli.add_argument(
        "--seed", type=int, default=42, help="Random seed for test payload."
    )
    opts = cli.parse_args()

    payload = _generate_test_data(opts.seed)
    baseline_name = "built-in hash()"
    candidates = (
        ("SHA256 (pickle)", sha256),
        ("xxHash (pickle)", xxhash),
        (baseline_name, builtin_hash),
    )
    rule = "=" * 60

    print(rule)
    print("HASH FUNCTION MICRO BENCHMARK")
    print(rule)
    print("Test data: (32-byte bytes object, 32-int tuple)")
    print(f"Iterations: {opts.iterations:,}")
    print(rule)

    results = list(_run_benchmarks(candidates, payload, opts.iterations))
    baseline = next((row for row in results if row[0] == baseline_name), None)

    print("\nResults:")
    for label, mean_s, std_s in results:
        print(f" {label:16s}: {mean_s * 1e6:8.2f} ± {std_s * 1e6:6.2f} μs")

    # Without a baseline measurement there is nothing to compare against.
    if not baseline:
        print("\nBuilt-in hash() result missing; cannot compute speed ratios.")
        return

    baseline_mean = baseline[1]
    print("\n" + rule)
    print("SUMMARY (relative to built-in hash())")
    print(rule)
    for label, mean_s, _ in results:
        if label != baseline_name:
            ratio = mean_s / baseline_mean
            print(f"• {label} is {ratio:.1f}x slower than built-in hash()")


if __name__ == "__main__":
    main()
|
||||
110
benchmarks/benchmark_prefix_block_hash.py
Normal file
110
benchmarks/benchmark_prefix_block_hash.py
Normal file
@ -0,0 +1,110 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Simple benchmark to compare prefix-cache block hashing algorithms.
|
||||
|
||||
Example:
|
||||
python benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import random
|
||||
import statistics
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
|
||||
from vllm.utils.hashing import get_hash_fn_by_name
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash, hash_block_tokens, init_none_hash
|
||||
|
||||
SUPPORTED_ALGOS = ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor")
|
||||
|
||||
|
||||
def _generate_blocks(
    num_blocks: int, block_size: int, vocab_size: int, seed: int
) -> list[list[int]]:
    """Build `num_blocks` lists of `block_size` pseudo-random token ids.

    Ids are drawn with ``randrange(vocab_size)`` from a generator seeded with
    `seed`, one block after another, so equal arguments give equal blocks.
    """
    token_rng = random.Random(seed)
    blocks: list[list[int]] = []
    for _ in range(num_blocks):
        tokens: list[int] = []
        for _ in range(block_size):
            tokens.append(token_rng.randrange(vocab_size))
        blocks.append(tokens)
    return blocks
|
||||
|
||||
|
||||
def _hash_all_blocks(
    hash_fn: Callable[[object], bytes],
    blocks: Iterable[Sequence[int]],
) -> float:
    """Time one chained hashing pass over `blocks`.

    Every block is hashed with ``hash_block_tokens``; the value returned for
    one block is passed as the parent hash of the next, starting from ``None``.

    Returns:
        Seconds elapsed for the whole pass, measured with
        ``time.perf_counter``.
    """
    prev_hash: BlockHash | None = None
    t0 = time.perf_counter()
    for tokens in blocks:
        prev_hash = hash_block_tokens(hash_fn, prev_hash, tokens, extra_keys=None)
    return time.perf_counter() - t0
|
||||
|
||||
|
||||
def _benchmark(
    hash_algo: str,
    blocks: list[list[int]],
    trials: int,
) -> tuple[float, float, float] | None:
    """Time `trials` full hashing passes over `blocks` with one algorithm.

    Args:
        hash_algo: Name passed to ``get_hash_fn_by_name``.
        blocks: Token-id blocks to hash; may be empty or of unequal length.
        trials: Number of timed passes; must be at least 1.

    Returns:
        ``(avg_seconds, best_seconds, tokens_per_second)``, with the throughput
        computed from the fastest pass, or ``None`` when the algorithm is
        skipped because a ``ModuleNotFoundError`` was raised.

    Raises:
        ValueError: If `trials` is less than 1.
    """
    # statistics.mean()/min() cannot work on an empty timing list.
    if trials < 1:
        raise ValueError(f"trials must be at least 1, got {trials}")

    try:
        hash_fn = get_hash_fn_by_name(hash_algo)
        init_none_hash(hash_fn)
        timings = [_hash_all_blocks(hash_fn, blocks) for _ in range(trials)]
    except ModuleNotFoundError as exc:
        print(f"Skipping {hash_algo}: {exc}", file=sys.stderr)
        return None

    avg = statistics.mean(timings)
    best = min(timings)
    # throughput: tokens / second. Tokens are counted per block instead of
    # assuming every block has len(blocks[0]) entries, which also keeps an
    # empty block list from raising IndexError.
    tokens_hashed = sum(len(block) for block in blocks)
    if best > 0:
        throughput = tokens_hashed / best
    else:
        # The pass was faster than the timer resolution.
        throughput = float("inf") if tokens_hashed else 0.0
    return avg, best, throughput
|
||||
|
||||
|
||||
def main() -> None:
    """Parse the CLI options, build random blocks and report each algorithm.

    One line with average time, best time and throughput is printed per
    algorithm; algorithms for which ``_benchmark`` returns ``None`` are
    omitted.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--num-blocks", type=int, default=10000, help="Block count.")
    cli.add_argument("--block-size", type=int, default=32, help="Tokens per block.")
    cli.add_argument(
        "--vocab-size", type=int, default=32000, help="Token id range [0, vocab_size)."
    )
    cli.add_argument("--seed", type=int, default=0, help="Random seed.")
    cli.add_argument(
        "--trials", type=int, default=5, help="Number of timed trials per algorithm."
    )
    cli.add_argument(
        "--algorithms",
        nargs="+",
        default=SUPPORTED_ALGOS,
        choices=SUPPORTED_ALGOS,
        help="Hash algorithms to benchmark.",
    )
    opts = cli.parse_args()

    blocks = _generate_blocks(
        opts.num_blocks, opts.block_size, opts.vocab_size, opts.seed
    )
    print(
        f"Benchmarking {len(opts.algorithms)} algorithms on "
        f"{opts.num_blocks} blocks (block size={opts.block_size})."
    )

    for algo in opts.algorithms:
        outcome = _benchmark(algo, blocks, opts.trials)
        if outcome is not None:
            avg, best, throughput = outcome
            print(
                f"{algo:14s} avg: {avg:.6f}s best: {best:.6f}s "
                f"throughput: {throughput / 1e6:.2f}M tokens/s"
            )


if __name__ == "__main__":
    main()
|
||||
244
benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py
Normal file
244
benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py
Normal file
@ -0,0 +1,244 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from itertools import product
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
from torch.utils.benchmark import Measurement as TMeasurement
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
_per_token_group_quant_fp8_colmajor,
|
||||
silu_mul_per_token_group_quant_fp8_colmajor,
|
||||
)
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||
|
||||
from .utils import ArgPool, Bench, CudaGraphBenchParams
|
||||
|
||||
GROUP_SIZE = 128
|
||||
FLOAT8_T = torch.float8_e4m3fn
|
||||
|
||||
|
||||
def print_timers(timers: list[TMeasurement], cuda_graph_nops: int):
    """Print a scaling note, then a comparison table built from `timers`.

    The note tells the reader that each timing covers `cuda_graph_nops`
    consecutive invocations of the benchmarked function.
    """
    note = (
        f"Note : The timings reported above is for {cuda_graph_nops} "
        "consecutive invocations of the benchmarking functions. "
        f"Please divide by {cuda_graph_nops} for single invocation "
        "timings."
    )
    print(note)
    TBenchmark.Compare(timers).print()
|
||||
|
||||
|
||||
class ImplType(Enum):
    """Implementations compared by this benchmark."""

    SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1
    REFERENCE = 2

    def get_impl(self):
        """Return the callable that implements this variant.

        Raises:
            ValueError: If the member has no associated implementation.
        """
        if self is ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return silu_mul_per_token_group_quant_fp8_colmajor
        if self is ImplType.REFERENCE:
            return reference
        raise ValueError(f"Unrecognized ImplType {self}")
|
||||
|
||||
|
||||
@dataclass
class BenchmarkTensors:
    """Input tensor plus the output buffers used by one benchmark invocation."""

    input: torch.Tensor
    # Output buffer passed to silu_mul_per_token_group_quant_fp8_colmajor.
    output: torch.Tensor

    # Activation and quantized-output buffers passed to the reference path.
    ref_act_out: torch.Tensor
    ref_quant_out: torch.Tensor

    @staticmethod
    def make(T: int, N: int) -> "BenchmarkTensors":
        """Allocate CUDA tensors for a random (T, N) bfloat16 input.

        T must be a multiple of GROUP_SIZE and N a multiple of 2 * GROUP_SIZE;
        all output buffers have shape (T, N // 2).
        """
        assert T % GROUP_SIZE == 0
        assert N % (GROUP_SIZE * 2) == 0

        out_shape = (T, N // 2)
        bf16_cuda = {"dtype": torch.bfloat16, "device": "cuda"}

        inp = torch.rand((T, N), **bf16_cuda)

        # silu_mul_per_token_group_quant_fp8_colmajor output.
        fused_out = torch.rand(out_shape, **bf16_cuda).to(FLOAT8_T)

        # reference output.
        act_out = torch.empty(out_shape, **bf16_cuda)
        quant_out = torch.empty(out_shape, **bf16_cuda).to(FLOAT8_T)

        return BenchmarkTensors(
            input=inp,
            output=fused_out,
            ref_act_out=act_out,
            ref_quant_out=quant_out,
        )

    @property
    def T(self):
        """Size of the input's first dimension."""
        return self.input.size(0)

    @property
    def N(self):
        """Size of the input's second dimension."""
        return self.input.size(1)

    def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]:
        """Return the keyword arguments for the callable behind `impl_type`.

        Raises:
            ValueError: If `impl_type` is not a known implementation.
        """
        if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return dict(
                input=self.input,
                output=self.output,
                use_ue8m0=is_deep_gemm_e8m0_used(),
            )
        if impl_type == ImplType.REFERENCE:
            return dict(
                input=self.input,
                act_out=self.ref_act_out,
                quant_out=self.ref_quant_out,
                use_ue8m0=is_deep_gemm_e8m0_used(),
            )
        raise ValueError(f"Unrecognized impl_type {impl_type}")
|
||||
|
||||
|
||||
def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool):
    """
    Launch the reference triton quant kernel from,
    vllm.model_executor.layers.quantization.utils.fp8_utils

    `quant_out` must have the same size as `x`; it is handed to the kernel as
    `x_q`. A float32 scale tensor `x_s` is allocated here.

    Returns the pair (x_q, x_s).
    """
    assert quant_out.size() == x.size()
    # Allocate the scale tensor column-major format.
    # The buffer is created with the group dimension first and then viewed with
    # its last two dimensions swapped.
    shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    x_q = quant_out
    x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)

    # One kernel program per group of GROUP_SIZE elements (launch grid is (M,)).
    M = x.numel() // GROUP_SIZE
    N = GROUP_SIZE
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1

    # Representable range of the FP8 output dtype.
    finfo = torch.finfo(FLOAT8_T)
    fp8_min = finfo.min
    fp8_max = finfo.max

    # NOTE(review): the positional arguments presumably map to the kernel's
    # input, output, scale, group size, row length and stride parameters;
    # confirm against the kernel signature in fp8_utils.
    _per_token_group_quant_fp8_colmajor[(M,)](
        x,
        x_q,
        x_s,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        x_s.stride(1),
        eps=1e-10,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_q, x_s
|
||||
|
||||
|
||||
def reference(
    input: torch.Tensor,
    act_out: torch.Tensor,
    quant_out: torch.Tensor,
    use_ue8m0: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Unfused baseline: ``silu_and_mul`` followed by ``reference_quant``.

    ``torch.ops._C.silu_and_mul`` is called with `act_out` and `input`; then
    `act_out` is passed to ``reference_quant`` together with `quant_out`.

    Returns:
        The ``(x_q, x_s)`` pair returned by ``reference_quant``.
    """
    torch.ops._C.silu_and_mul(act_out, input)
    quantized = reference_quant(act_out, quant_out, use_ue8m0)
    return quantized
|
||||
|
||||
|
||||
def bench_impl(
    bench_tensors: list[BenchmarkTensors], impl_type: ImplType
) -> TMeasurement:
    """Benchmark one implementation over a pool of pre-allocated tensors.

    Each entry of `bench_tensors` is run once as warm-up. The per-entry keyword
    arguments are then merged into one ``ArgPool`` per argument name and handed
    to ``Bench`` together with ``CudaGraphBenchParams(len(bench_tensors))``.

    Args:
        bench_tensors: Non-empty list of tensor sets; the first entry supplies
            the T and N shown in the benchmark description.
        impl_type: Implementation to run.

    Returns:
        The measurement returned by ``Bench.run()``.

    Raises:
        ValueError: If `bench_tensors` is empty.
    """
    if not bench_tensors:
        raise ValueError("bench_tensors must contain at least one entry")

    # Resolve the callable once instead of on every use.
    impl = impl_type.get_impl()
    T = bench_tensors[0].T
    N = bench_tensors[0].N

    arg_pool_size = len(bench_tensors)
    kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors]

    # warmup
    for kwargs in kwargs_list:
        impl(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    pooled_kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for impl_kwargs in kwargs_list:
        for k, v in impl_kwargs.items():
            pooled_kwargs[k].values.append(v)

    cuda_graph_params = CudaGraphBenchParams(arg_pool_size)
    with Bench(
        cuda_graph_params,
        "silu-mul-quant",
        f"num_tokens={T}, N={N}",
        impl_type.name,
        impl,
        **pooled_kwargs,
    ) as bench:
        timer = bench.run()
    return timer
|
||||
|
||||
|
||||
def test_correctness(T: int, N: int):
    """Compare the fused kernel with the reference on fresh (T, N) tensors.

    Both implementations are run on the same ``BenchmarkTensors``; the
    quantized outputs (cast to float32) and the scales are compared with
    ``torch.testing.assert_close``, which raises on a mismatch.
    """
    print(f"Testing num_tokens={T}, N={N} ...")

    tensors = BenchmarkTensors.make(T, N)

    def run_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]:
        fn = impl.get_impl()
        return fn(**tensors.make_impl_kwargs(impl))

    # reference output
    expected_q, expected_s = run_impl(ImplType.REFERENCE)

    # test output
    actual_q, actual_s = run_impl(
        ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
    )

    torch.testing.assert_close(
        expected_q.to(torch.float32), actual_q.to(torch.float32)
    )
    torch.testing.assert_close(expected_s, actual_s)
|
||||
|
||||
|
||||
def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]:
    """Check and benchmark both implementations for every (N, T) combination.

    For each combination the correctness check runs first, then the fused
    kernel and the reference are timed on the same pool of `arg_pool_size`
    tensor sets and that pair of timings is printed. A table with all timings
    is printed at the end.

    Returns:
        All measurements, fused kernel before reference for each combination.
    """
    fused = ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
    all_timers = []
    for N, T in product(Ns, Ts):
        test_correctness(T, N)

        pool: list[BenchmarkTensors] = []
        for _ in range(arg_pool_size):
            pool.append(BenchmarkTensors.make(T, N))

        pair = [bench_impl(pool, fused), bench_impl(pool, ImplType.REFERENCE)]
        all_timers.extend(pair)

        print_timers(pair, cuda_graph_nops=arg_pool_size)

    print_timers(all_timers, cuda_graph_nops=arg_pool_size)

    return all_timers


if __name__ == "__main__":
    T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)]
    N = [2048, 4096, 8192]

    print(f"T = {T}, N = {N}")
    run(T, N, arg_pool_size=8)
|
||||
@ -150,6 +150,97 @@ ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0 12.0'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
|
||||
#################### CSRC BUILD IMAGE ####################
|
||||
FROM base AS csrc-build
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
|
||||
# install build dependencies
|
||||
COPY requirements/build.txt requirements/build.txt
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
# Use copy mode to avoid hardlink failures with Docker cache mounts
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
COPY pyproject.toml setup.py CMakeLists.txt ./
|
||||
COPY cmake cmake/
|
||||
COPY csrc csrc/
|
||||
COPY vllm/envs.py vllm/envs.py
|
||||
COPY vllm/__init__.py vllm/__init__.py
|
||||
|
||||
# max jobs used by Ninja to build extensions
|
||||
ARG max_jobs=2
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
# number of threads used by nvcc
|
||||
ARG nvcc_threads=8
|
||||
ENV NVCC_THREADS=$nvcc_threads
|
||||
|
||||
ARG USE_SCCACHE
|
||||
ARG SCCACHE_DOWNLOAD_URL=https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz
|
||||
ARG SCCACHE_ENDPOINT
|
||||
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
|
||||
ARG SCCACHE_REGION_NAME=us-west-2
|
||||
ARG SCCACHE_S3_NO_CREDENTIALS=0
|
||||
|
||||
# Flag to control whether to use pre-built vLLM wheels
|
||||
ARG VLLM_USE_PRECOMPILED=""
|
||||
ARG VLLM_MERGE_BASE_COMMIT=""
|
||||
ARG VLLM_MAIN_CUDA_VERSION=""
|
||||
|
||||
# Use dummy version for csrc-build wheel (only .so files are extracted, version doesn't matter)
|
||||
ENV SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0+csrc.build"
|
||||
|
||||
# if USE_SCCACHE is set, use sccache to speed up compilation
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$USE_SCCACHE" = "1" ]; then \
|
||||
echo "Installing sccache..." \
|
||||
&& curl -L -o sccache.tar.gz ${SCCACHE_DOWNLOAD_URL} \
|
||||
&& tar -xzf sccache.tar.gz \
|
||||
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
|
||||
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
|
||||
&& if [ ! -z ${SCCACHE_ENDPOINT} ] ; then export SCCACHE_ENDPOINT=${SCCACHE_ENDPOINT} ; fi \
|
||||
&& export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \
|
||||
&& export SCCACHE_REGION=${SCCACHE_REGION_NAME} \
|
||||
&& export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \
|
||||
&& export SCCACHE_IDLE_TIMEOUT=0 \
|
||||
&& export CMAKE_BUILD_TYPE=Release \
|
||||
&& export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \
|
||||
&& export VLLM_PRECOMPILED_WHEEL_COMMIT="${VLLM_MERGE_BASE_COMMIT}" \
|
||||
&& export VLLM_MAIN_CUDA_VERSION="${VLLM_MAIN_CUDA_VERSION}" \
|
||||
&& export VLLM_DOCKER_BUILD_CONTEXT=1 \
|
||||
&& sccache --show-stats \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \
|
||||
&& sccache --show-stats; \
|
||||
fi
|
||||
|
||||
ARG vllm_target_device="cuda"
|
||||
ENV VLLM_TARGET_DEVICE=${vllm_target_device}
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$USE_SCCACHE" != "1" ]; then \
|
||||
# Clean any existing CMake artifacts
|
||||
rm -rf .deps && \
|
||||
mkdir -p .deps && \
|
||||
export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" && \
|
||||
export VLLM_PRECOMPILED_WHEEL_COMMIT="${VLLM_MERGE_BASE_COMMIT}" && \
|
||||
export VLLM_DOCKER_BUILD_CONTEXT=1 && \
|
||||
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
|
||||
fi
|
||||
#################### CSRC BUILD IMAGE ####################
|
||||
|
||||
#################### WHEEL BUILD IMAGE ####################
|
||||
FROM base AS build
|
||||
ARG TARGETPLATFORM
|
||||
@ -172,66 +263,28 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
COPY --from=csrc-build /workspace/dist /precompiled-wheels
|
||||
|
||||
COPY . .
|
||||
|
||||
ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi
|
||||
|
||||
# max jobs used by Ninja to build extensions
|
||||
ARG max_jobs=2
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
# number of threads used by nvcc
|
||||
ARG nvcc_threads=8
|
||||
ENV NVCC_THREADS=$nvcc_threads
|
||||
|
||||
ARG USE_SCCACHE
|
||||
ARG SCCACHE_DOWNLOAD_URL=https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz
|
||||
ARG SCCACHE_ENDPOINT
|
||||
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
|
||||
ARG SCCACHE_REGION_NAME=us-west-2
|
||||
ARG SCCACHE_S3_NO_CREDENTIALS=0
|
||||
|
||||
# Flag to control whether to use pre-built vLLM wheels
|
||||
ARG VLLM_USE_PRECOMPILED=""
|
||||
ARG VLLM_MAIN_CUDA_VERSION=""
|
||||
|
||||
# if USE_SCCACHE is set, use sccache to speed up compilation
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
if [ "$USE_SCCACHE" = "1" ]; then \
|
||||
echo "Installing sccache..." \
|
||||
&& curl -L -o sccache.tar.gz ${SCCACHE_DOWNLOAD_URL} \
|
||||
&& tar -xzf sccache.tar.gz \
|
||||
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
|
||||
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
|
||||
&& if [ ! -z ${SCCACHE_ENDPOINT} ] ; then export SCCACHE_ENDPOINT=${SCCACHE_ENDPOINT} ; fi \
|
||||
&& export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \
|
||||
&& export SCCACHE_REGION=${SCCACHE_REGION_NAME} \
|
||||
&& export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \
|
||||
&& export SCCACHE_IDLE_TIMEOUT=0 \
|
||||
&& export CMAKE_BUILD_TYPE=Release \
|
||||
&& export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \
|
||||
&& export VLLM_MAIN_CUDA_VERSION="${VLLM_MAIN_CUDA_VERSION}" \
|
||||
&& export VLLM_DOCKER_BUILD_CONTEXT=1 \
|
||||
&& sccache --show-stats \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \
|
||||
&& sccache --show-stats; \
|
||||
fi
|
||||
|
||||
ARG vllm_target_device="cuda"
|
||||
ENV VLLM_TARGET_DEVICE=${vllm_target_device}
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
if [ "$USE_SCCACHE" != "1" ]; then \
|
||||
# Clean any existing CMake artifacts
|
||||
rm -rf .deps && \
|
||||
mkdir -p .deps && \
|
||||
export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" && \
|
||||
export VLLM_DOCKER_BUILD_CONTEXT=1 && \
|
||||
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
|
||||
fi
|
||||
|
||||
# Skip adding +precompiled suffix to version (preserves git-derived version)
|
||||
ENV VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX=1
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
if [ "${vllm_target_device}" = "cuda" ]; then \
|
||||
export VLLM_PRECOMPILED_WHEEL_LOCATION=$(ls /precompiled-wheels/*.whl); \
|
||||
fi && \
|
||||
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38
|
||||
|
||||
# Install DeepGEMM from source
|
||||
ARG DEEPGEMM_GIT_REF
|
||||
@ -527,7 +580,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
else \
|
||||
BITSANDBYTES_VERSION="0.46.1"; \
|
||||
fi; \
|
||||
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.0'
|
||||
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.3'
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 146 KiB After Width: | Height: | Size: 174 KiB |
@ -670,6 +670,35 @@ vllm bench serve \
|
||||
|
||||
</details>
|
||||
|
||||
### 🧪 Hashing Benchmarks
|
||||
|
||||
<details class="admonition abstract" markdown="1">
|
||||
<summary>Show more</summary>
|
||||
|
||||
Two helper scripts live in `benchmarks/` to compare hashing options used by prefix caching and related utilities. They are standalone (no server required) and help choose a hash algorithm before enabling prefix caching in production.
|
||||
|
||||
- `benchmarks/benchmark_hash.py`: Micro-benchmark that measures per-call latency of three implementations on a representative `(bytes, tuple[int])` payload.
|
||||
|
||||
```bash
|
||||
python benchmarks/benchmark_hash.py --iterations 20000 --seed 42
|
||||
```
|
||||
|
||||
- `benchmarks/benchmark_prefix_block_hash.py`: End-to-end block hashing benchmark that runs the full prefix-cache hash pipeline (`hash_block_tokens`) across many fake blocks and reports throughput.
|
||||
|
||||
```bash
|
||||
python benchmarks/benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32 --trials 5
|
||||
```
|
||||
|
||||
Supported algorithms: `sha256`, `sha256_cbor`, `xxhash`, `xxhash_cbor`. Install optional deps to exercise all variants:
|
||||
|
||||
```bash
|
||||
uv pip install xxhash cbor2
|
||||
```
|
||||
|
||||
If an algorithm’s dependency is missing, the script will skip it and continue.
|
||||
|
||||
</details>
|
||||
|
||||
### ⚡ Request Prioritization Benchmark
|
||||
|
||||
<details class="admonition abstract" markdown="1">
|
||||
|
||||
@ -18,6 +18,7 @@ Compute Resources:
|
||||
- Alibaba Cloud
|
||||
- AMD
|
||||
- Anyscale
|
||||
- Arm
|
||||
- AWS
|
||||
- Crusoe Cloud
|
||||
- Databricks
|
||||
|
||||
@ -57,15 +57,15 @@ vLLM also provides [a reference example](../../examples/online_serving/prometheu
|
||||
The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important:
|
||||
|
||||
- `vllm:e2e_request_latency_seconds_bucket` - End to end request latency measured in seconds.
|
||||
- `vllm:prompt_tokens_total` - Prompt tokens.
|
||||
- `vllm:generation_tokens_total` - Generation tokens.
|
||||
- `vllm:prompt_tokens` - Prompt tokens.
|
||||
- `vllm:generation_tokens` - Generation tokens.
|
||||
- `vllm:time_per_output_token_seconds` - Inter-token latency (Time Per Output Token, TPOT) in seconds.
|
||||
- `vllm:time_to_first_token_seconds` - Time to First Token (TTFT) latency in seconds.
|
||||
- `vllm:num_requests_running` (also, `_swapped` and `_waiting`) - Number of requests in the RUNNING, WAITING, and SWAPPED states.
|
||||
- `vllm:gpu_cache_usage_perc` - Percentage of used cache blocks by vLLM.
|
||||
- `vllm:request_prompt_tokens` - Request prompt length.
|
||||
- `vllm:request_generation_tokens` - Request generation length.
|
||||
- `vllm:request_success_total` - Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.
|
||||
- `vllm:request_success` - Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.
|
||||
- `vllm:request_queue_time_seconds` - Queue time.
|
||||
- `vllm:request_prefill_time_seconds` - Requests prefill time.
|
||||
- `vllm:request_decode_time_seconds` - Requests decode time.
|
||||
@ -571,9 +571,9 @@ model and then validate those tokens with the larger model.
|
||||
|
||||
- `vllm:spec_decode_draft_acceptance_rate` (Gauge)
|
||||
- `vllm:spec_decode_efficiency` (Gauge)
|
||||
- `vllm:spec_decode_num_accepted_tokens_total` (Counter)
|
||||
- `vllm:spec_decode_num_draft_tokens_total` (Counter)
|
||||
- `vllm:spec_decode_num_emitted_tokens_total` (Counter)
|
||||
- `vllm:spec_decode_num_accepted_tokens` (Counter)
|
||||
- `vllm:spec_decode_num_draft_tokens` (Counter)
|
||||
- `vllm:spec_decode_num_emitted_tokens` (Counter)
|
||||
|
||||
There is a PR under review (<https://github.com/vllm-project/vllm/pull/12193>) to add "prompt lookup (ngram)"
|
||||
speculative decoding to v1. Other techniques will follow. We should
|
||||
|
||||
@ -90,7 +90,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
|
||||
| cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],</br>[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlassBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
|
||||
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
|
||||
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
|
||||
| deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
|
||||
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
|
||||
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
|
||||
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
|
||||
@ -114,5 +113,5 @@ The following table shows "families" of modular kernels that are intended to wor
|
||||
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|
||||
|---------|-----------------------------------------|----------------------------------------------|
|
||||
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
|
||||
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
|
||||
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
|
||||
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
|
||||
|
||||
@ -54,7 +54,7 @@ th:not(:first-child) {
|
||||
| beam-search | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | |
|
||||
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
|
||||
|
||||
\* Chunked prefill and prefix caching are only applicable to last-token pooling.
|
||||
\* Chunked prefill and prefix caching are only applicable to last-token or all pooling with causal attention.
|
||||
<sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.
|
||||
|
||||
### Feature x Hardware
|
||||
|
||||
58
docs/features/mooncake_connector_usage.md
Normal file
58
docs/features/mooncake_connector_usage.md
Normal file
@ -0,0 +1,58 @@
|
||||
# MooncakeConnector Usage Guide
|
||||
|
||||
## About Mooncake
|
||||
|
||||
Mooncake aims to enhance the inference efficiency of large language models (LLMs), especially in slow object storage environments, by constructing a multi-level caching pool on high-speed interconnected DRAM/SSD resources. Compared to traditional caching systems, Mooncake utilizes (GPUDirect) RDMA technology to transfer data directly in a zero-copy manner, while maximizing the use of multi-NIC resources on a single machine.
|
||||
|
||||
For more details about Mooncake, please refer to [Mooncake project](https://github.com/kvcache-ai/Mooncake) and [Mooncake documents](https://kvcache-ai.github.io/Mooncake/).
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Installation
|
||||
|
||||
Install mooncake through pip: `uv pip install mooncake-transfer-engine`.
|
||||
|
||||
Refer to the [official Mooncake repository](https://github.com/kvcache-ai/Mooncake) for more installation instructions.
|
||||
|
||||
## Usage
|
||||
|
||||
### Prefiller Node (192.168.0.2)
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8010 --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_producer"}'
|
||||
```
|
||||
|
||||
### Decoder Node (192.168.0.3)
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8020 --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}'
|
||||
```
|
||||
|
||||
### Proxy
|
||||
|
||||
```bash
|
||||
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --prefiller-host 192.168.0.2 --prefiller-port 8010 --decoder-host 192.168.0.3 --decoder-port 8020
|
||||
```
|
||||
|
||||
> NOTE: The Mooncake Connector currently uses the proxy from nixl_integration. This will be replaced with a self-developed proxy in the future.
|
||||
|
||||
Now you can send requests to the proxy server through port 8000.
|
||||
|
||||
## Environment Variables
|
||||
|
||||
- `VLLM_MOONCAKE_BOOTSTRAP_PORT`: Port for Mooncake bootstrap server
|
||||
- Default: 8998
|
||||
- Required only for prefiller instances
|
||||
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
|
||||
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank
|
||||
- Used for the decoder notifying the prefiller
|
||||
|
||||
- `VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller’s KV cache for a particular request. (Optional)
|
||||
- Default: 480
|
||||
- If a request is aborted and the decoder has not yet notified the prefiller, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely.
|
||||
|
||||
## KV Role Options
|
||||
|
||||
- **kv_producer**: For prefiller instances that generate KV caches
|
||||
- **kv_consumer**: For decoder instances that consume KV caches from prefiller
|
||||
- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined.
|
||||
@ -795,14 +795,12 @@ The following example demonstrates how to pass image embeddings to the OpenAI se
|
||||
??? code
|
||||
|
||||
```python
|
||||
from vllm.utils.serial_utils import tensor2base64
|
||||
|
||||
image_embedding = torch.load(...)
|
||||
grid_thw = torch.load(...) # Required by Qwen/Qwen2-VL-2B-Instruct
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.save(image_embedding, buffer)
|
||||
buffer.seek(0)
|
||||
binary_data = buffer.read()
|
||||
base64_image_embedding = base64.b64encode(binary_data).decode('utf-8')
|
||||
base64_image_embedding = tensor2base64(image_embedding)
|
||||
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
@ -4,9 +4,6 @@ vLLM has experimental support for macOS with Apple Silicon. For now, users must
|
||||
|
||||
Currently the CPU implementation for macOS supports FP32 and FP16 datatypes.
|
||||
|
||||
!!! warning
|
||||
There are no pre-built wheels or images for this device, so you must build vLLM from source.
|
||||
|
||||
# --8<-- [end:installation]
|
||||
# --8<-- [start:requirements]
|
||||
|
||||
@ -20,6 +17,8 @@ Currently the CPU implementation for macOS supports FP32 and FP16 datatypes.
|
||||
# --8<-- [end:set-up-using-python]
|
||||
# --8<-- [start:pre-built-wheels]
|
||||
|
||||
Currently, there are no pre-built Apple silicon CPU wheels.
|
||||
|
||||
# --8<-- [end:pre-built-wheels]
|
||||
# --8<-- [start:build-wheel-from-source]
|
||||
|
||||
@ -78,6 +77,8 @@ uv pip install -e .
|
||||
# --8<-- [end:build-wheel-from-source]
|
||||
# --8<-- [start:pre-built-images]
|
||||
|
||||
Currently, there are no pre-built Apple silicon CPU images.
|
||||
|
||||
# --8<-- [end:pre-built-images]
|
||||
# --8<-- [start:build-image-from-source]
|
||||
|
||||
|
||||
@ -1,11 +1,6 @@
|
||||
# --8<-- [start:installation]
|
||||
|
||||
vLLM has been adapted to work on ARM64 CPUs with NEON support, leveraging the CPU backend initially developed for the x86 platform.
|
||||
|
||||
ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes.
|
||||
|
||||
!!! warning
|
||||
There are no pre-built wheels or images for this device, so you must build vLLM from source.
|
||||
vLLM offers basic model inferencing and serving on the Arm CPU platform, with NEON support and the FP32, FP16 and BF16 data types.
|
||||
|
||||
# --8<-- [end:installation]
|
||||
# --8<-- [start:requirements]
|
||||
@ -20,6 +15,23 @@ ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes.
|
||||
# --8<-- [end:set-up-using-python]
|
||||
# --8<-- [start:pre-built-wheels]
|
||||
|
||||
Pre-built vLLM wheels for Arm are available since version 0.11.2. These wheels contain pre-compiled C++ binaries.
|
||||
Please replace `<version>` in the commands below with a specific version string (e.g., `0.11.2`).
|
||||
|
||||
```bash
|
||||
uv pip install --pre vllm==<version>+cpu --extra-index-url https://wheels.vllm.ai/<version>%2Bcpu/
|
||||
```
|
||||
|
||||
??? console "pip"
|
||||
```bash
|
||||
pip install --pre vllm==<version>+cpu --extra-index-url https://wheels.vllm.ai/<version>%2Bcpu/
|
||||
```
|
||||
|
||||
The `uv` approach works for vLLM `v0.6.6` and later. A unique feature of `uv` is that packages in `--extra-index-url` have [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). If the latest public release is `v0.6.6.post1`, `uv`'s behavior allows installing a commit before `v0.6.6.post1` by specifying the `--extra-index-url`. In contrast, `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version.
|
||||
|
||||
!!! note
|
||||
Nightly wheels (useful, e.g., for bisecting a behavior change or performance regression) are currently unsupported for this architecture.
|
||||
|
||||
# --8<-- [end:pre-built-wheels]
|
||||
# --8<-- [start:build-wheel-from-source]
|
||||
|
||||
@ -69,6 +81,8 @@ Testing has been conducted on AWS Graviton3 instances for compatibility.
|
||||
# --8<-- [end:build-wheel-from-source]
|
||||
# --8<-- [start:pre-built-images]
|
||||
|
||||
Currently, there are no pre-built Arm CPU images.
|
||||
|
||||
# --8<-- [end:pre-built-images]
|
||||
# --8<-- [start:build-image-from-source]
|
||||
```bash
|
||||
|
||||
@ -46,11 +46,25 @@ vLLM is a Python library that supports the following CPU variants. Select your C
|
||||
|
||||
### Pre-built wheels
|
||||
|
||||
Please refer to the instructions for [pre-built wheels on GPU](./gpu.md#pre-built-wheels).
|
||||
|
||||
When specifying the index URL, please make sure to use the `cpu` variant subdirectory.
|
||||
For example, the nightly build index is: `https://wheels.vllm.ai/nightly/cpu/`.
|
||||
|
||||
=== "Intel/AMD x86"
|
||||
|
||||
--8<-- "docs/getting_started/installation/cpu.x86.inc.md:pre-built-wheels"
|
||||
|
||||
=== "ARM AArch64"
|
||||
|
||||
--8<-- "docs/getting_started/installation/cpu.arm.inc.md:pre-built-wheels"
|
||||
|
||||
=== "Apple silicon"
|
||||
|
||||
--8<-- "docs/getting_started/installation/cpu.apple.inc.md:pre-built-wheels"
|
||||
|
||||
=== "IBM Z (S390X)"
|
||||
|
||||
--8<-- "docs/getting_started/installation/cpu.s390x.inc.md:pre-built-wheels"
|
||||
|
||||
### Build wheel from source
|
||||
|
||||
#### Set up using Python-only build (without compilation) {#python-only-build}
|
||||
@ -87,6 +101,18 @@ VLLM_USE_PRECOMPILED=1 VLLM_PRECOMPILED_WHEEL_VARIANT=cpu VLLM_TARGET_DEVICE=cpu
|
||||
|
||||
--8<-- "docs/getting_started/installation/cpu.x86.inc.md:pre-built-images"
|
||||
|
||||
=== "ARM AArch64"
|
||||
|
||||
--8<-- "docs/getting_started/installation/cpu.arm.inc.md:pre-built-images"
|
||||
|
||||
=== "Apple silicon"
|
||||
|
||||
--8<-- "docs/getting_started/installation/cpu.apple.inc.md:pre-built-images"
|
||||
|
||||
=== "IBM Z (S390X)"
|
||||
|
||||
--8<-- "docs/getting_started/installation/cpu.s390x.inc.md:pre-built-images"
|
||||
|
||||
### Build image from source
|
||||
|
||||
=== "Intel/AMD x86"
|
||||
|
||||
@ -4,9 +4,6 @@ vLLM has experimental support for s390x architecture on IBM Z platform. For now,
|
||||
|
||||
Currently, the CPU implementation for s390x architecture supports FP32 datatype only.
|
||||
|
||||
!!! warning
|
||||
There are no pre-built wheels or images for this device, so you must build vLLM from source.
|
||||
|
||||
# --8<-- [end:installation]
|
||||
# --8<-- [start:requirements]
|
||||
|
||||
@ -21,6 +18,8 @@ Currently, the CPU implementation for s390x architecture supports FP32 datatype
|
||||
# --8<-- [end:set-up-using-python]
|
||||
# --8<-- [start:pre-built-wheels]
|
||||
|
||||
Currently, there are no pre-built IBM Z CPU wheels.
|
||||
|
||||
# --8<-- [end:pre-built-wheels]
|
||||
# --8<-- [start:build-wheel-from-source]
|
||||
|
||||
@ -69,6 +68,8 @@ Execute the following commands to build and install vLLM from source.
|
||||
# --8<-- [end:build-wheel-from-source]
|
||||
# --8<-- [start:pre-built-images]
|
||||
|
||||
Currently, there are no pre-built IBM Z CPU images.
|
||||
|
||||
# --8<-- [end:pre-built-images]
|
||||
# --8<-- [start:build-image-from-source]
|
||||
|
||||
|
||||
@ -17,6 +17,8 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data
|
||||
# --8<-- [end:set-up-using-python]
|
||||
# --8<-- [start:pre-built-wheels]
|
||||
|
||||
Currently, there are no pre-built x86 CPU wheels.
|
||||
|
||||
# --8<-- [end:pre-built-wheels]
|
||||
# --8<-- [start:build-wheel-from-source]
|
||||
|
||||
|
||||
@ -5,9 +5,6 @@ vLLM supports AMD GPUs with ROCm 6.3 or above, and torch 2.8.0 and above.
|
||||
!!! tip
|
||||
[Docker](#set-up-using-docker) is the recommended way to use vLLM on ROCm.
|
||||
|
||||
!!! warning
|
||||
There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source.
|
||||
|
||||
# --8<-- [end:installation]
|
||||
# --8<-- [start:requirements]
|
||||
|
||||
|
||||
@ -2,9 +2,6 @@
|
||||
|
||||
vLLM initially supports basic model inference and serving on Intel GPU platform.
|
||||
|
||||
!!! warning
|
||||
There are no pre-built wheels for this device, so you need build vLLM from source. Or you can use pre-built images which are based on vLLM released versions.
|
||||
|
||||
# --8<-- [end:installation]
|
||||
# --8<-- [start:requirements]
|
||||
|
||||
|
||||
@ -711,7 +711,6 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ |
|
||||
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ |
|
||||
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ |
|
||||
| `PixtralForConditionalGeneration` | Ministral 3 (Mistral format), Mistral 3 (Mistral format), Mistral Large 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Mistral-Large-3-675B-Instruct-2512`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ |
|
||||
| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ |
|
||||
| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ |
|
||||
|
||||
@ -23,31 +23,23 @@ def create_test_prompts(
|
||||
# this is an example of using quantization without LoRA
|
||||
(
|
||||
"My name is",
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
|
||||
),
|
||||
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
|
||||
None,
|
||||
),
|
||||
# the next three examples use quantization with LoRA
|
||||
(
|
||||
"my name is",
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
|
||||
),
|
||||
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
|
||||
LoRARequest("lora-test-1", 1, lora_path),
|
||||
),
|
||||
(
|
||||
"The capital of USA is",
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
|
||||
),
|
||||
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
|
||||
LoRARequest("lora-test-2", 1, lora_path),
|
||||
),
|
||||
(
|
||||
"The capital of France is",
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
|
||||
),
|
||||
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
|
||||
LoRARequest("lora-test-3", 1, lora_path),
|
||||
),
|
||||
]
|
||||
|
||||
@ -27,9 +27,7 @@ def create_test_prompts(
|
||||
return [
|
||||
(
|
||||
"A robot may not injure a human being",
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
|
||||
),
|
||||
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
|
||||
None,
|
||||
),
|
||||
(
|
||||
@ -41,22 +39,12 @@ def create_test_prompts(
|
||||
),
|
||||
(
|
||||
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
|
||||
SamplingParams(
|
||||
temperature=0.0,
|
||||
logprobs=1,
|
||||
prompt_logprobs=1,
|
||||
max_tokens=128,
|
||||
),
|
||||
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
|
||||
LoRARequest("sql-lora", 1, lora_path),
|
||||
),
|
||||
(
|
||||
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
|
||||
SamplingParams(
|
||||
temperature=0.0,
|
||||
logprobs=1,
|
||||
prompt_logprobs=1,
|
||||
max_tokens=128,
|
||||
),
|
||||
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
|
||||
LoRARequest("sql-lora2", 2, lora_path),
|
||||
),
|
||||
]
|
||||
|
||||
@ -28,13 +28,11 @@ Dependencies:
|
||||
- openai
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from openai import OpenAI
|
||||
|
||||
from vllm.utils.serial_utils import tensor2base64
|
||||
|
||||
|
||||
def main():
|
||||
client = OpenAI(
|
||||
@ -58,11 +56,7 @@ def main():
|
||||
prompt_embeds = embedding_layer(token_ids).squeeze(0)
|
||||
|
||||
# Prompt embeddings
|
||||
buffer = io.BytesIO()
|
||||
torch.save(prompt_embeds, buffer)
|
||||
buffer.seek(0)
|
||||
binary_data = buffer.read()
|
||||
encoded_embeds = base64.b64encode(binary_data).decode("utf-8")
|
||||
encoded_embeds = tensor2base64(prompt_embeds)
|
||||
|
||||
completion = client.completions.create(
|
||||
model=model_name,
|
||||
|
||||
@ -150,7 +150,8 @@ def run_siglip(client: OpenAI, model: str):
|
||||
Start the server using:
|
||||
|
||||
vllm serve google/siglip-base-patch16-224 \
|
||||
--runner pooling
|
||||
--runner pooling \
|
||||
--chat-template template_basic.jinja
|
||||
"""
|
||||
|
||||
response = create_chat_embeddings(
|
||||
|
||||
@ -46,6 +46,7 @@ scipy # Required for phi-4-multimodal-instruct
|
||||
ninja # Required for xgrammar, rocm, tpu, xpu
|
||||
pybase64 # fast base64 implementation
|
||||
cbor2 # Required for cross-language serialization of hashable objects
|
||||
ijson # Required for mistral streaming tool parser
|
||||
setproctitle # Used to set process names for better debugging and monitoring
|
||||
openai-harmony >= 0.0.3 # Required for gpt-oss
|
||||
anthropic == 0.71.0
|
||||
|
||||
@ -3,7 +3,6 @@ ninja
|
||||
packaging>=24.2
|
||||
setuptools>=77.0.3,<81.0.0
|
||||
setuptools-scm>=8
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch==2.9.1+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
|
||||
torch==2.9.1; platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "aarch64"
|
||||
scons; platform_machine == "aarch64" # needed to build Arm Compute Library (ACL)
|
||||
|
||||
@ -4,7 +4,6 @@
|
||||
numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative decoding
|
||||
|
||||
# Dependencies for CPUs
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch==2.9.1+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
|
||||
torch==2.9.1; platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "aarch64"
|
||||
|
||||
|
||||
@ -42,6 +42,6 @@ tritonclient==2.51.0
|
||||
|
||||
numba == 0.61.2 # Required for N-gram speculative decoding
|
||||
numpy
|
||||
runai-model-streamer[s3,gcs]==0.15.0
|
||||
runai-model-streamer[s3,gcs]==0.15.3
|
||||
fastsafetensors>=0.1.10
|
||||
pydantic>=2.12 # 2.11 leads to error on python 3.13
|
||||
|
||||
@ -12,7 +12,7 @@ tensorizer==2.10.1
|
||||
packaging>=24.2
|
||||
setuptools>=77.0.3,<80.0.0
|
||||
setuptools-scm>=8
|
||||
runai-model-streamer[s3,gcs]==0.15.0
|
||||
runai-model-streamer[s3,gcs]==0.15.3
|
||||
conch-triton-kernels==1.2.1
|
||||
timm>=1.0.17
|
||||
fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.git@d6f998a03432b2452f8de2bb5cefb5af9795d459
|
||||
|
||||
@ -51,7 +51,7 @@ tritonclient==2.51.0
|
||||
arctic-inference == 0.1.1 # Required for suffix decoding test
|
||||
numba == 0.61.2 # Required for N-gram speculative decoding
|
||||
numpy
|
||||
runai-model-streamer[s3,gcs]==0.15.0
|
||||
runai-model-streamer[s3,gcs]==0.15.3
|
||||
fastsafetensors>=0.1.10
|
||||
pydantic>=2.12 # 2.11 leads to error on python 3.13
|
||||
decord==0.6.0
|
||||
|
||||
@ -965,11 +965,11 @@ rsa==4.9.1
|
||||
# via google-auth
|
||||
rtree==1.4.0
|
||||
# via torchgeo
|
||||
runai-model-streamer==0.15.0
|
||||
runai-model-streamer==0.15.3
|
||||
# via -r requirements/test.in
|
||||
runai-model-streamer-gcs==0.15.0
|
||||
runai-model-streamer-gcs==0.15.3
|
||||
# via runai-model-streamer
|
||||
runai-model-streamer-s3==0.15.0
|
||||
runai-model-streamer-s3==0.15.3
|
||||
# via runai-model-streamer
|
||||
s3transfer==0.10.3
|
||||
# via boto3
|
||||
|
||||
41
setup.py
41
setup.py
@ -346,10 +346,13 @@ class precompiled_wheel_utils:
|
||||
The order of preference is:
|
||||
1. user-specified wheel location (can be either local or remote, via
|
||||
VLLM_PRECOMPILED_WHEEL_LOCATION)
|
||||
2. user-specified variant from nightly repo (current main commit via
|
||||
VLLM_PRECOMPILED_WHEEL_VARIANT)
|
||||
2. user-specified variant (VLLM_PRECOMPILED_WHEEL_VARIANT) from nightly repo
|
||||
3. the variant corresponding to VLLM_MAIN_CUDA_VERSION from nightly repo
|
||||
4. the default variant from nightly repo (current main commit)
|
||||
4. the default variant from nightly repo
|
||||
|
||||
If downloading from the nightly repo, the commit can be specified via
|
||||
VLLM_PRECOMPILED_WHEEL_COMMIT; otherwise, the head commit in the main branch
|
||||
is used.
|
||||
"""
|
||||
wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None)
|
||||
if wheel_location is not None:
|
||||
@ -362,10 +365,13 @@ class precompiled_wheel_utils:
|
||||
# try to fetch the wheel metadata from the nightly wheel repo
|
||||
main_variant = "cu" + envs.VLLM_MAIN_CUDA_VERSION.replace(".", "")
|
||||
variant = os.getenv("VLLM_PRECOMPILED_WHEEL_VARIANT", main_variant)
|
||||
commit = os.getenv(
|
||||
"VLLM_PRECOMPILED_WHEEL_COMMIT",
|
||||
precompiled_wheel_utils.get_base_commit_in_main_branch(),
|
||||
)
|
||||
commit = os.getenv("VLLM_PRECOMPILED_WHEEL_COMMIT", "").lower()
|
||||
if not commit or len(commit) != 40:
|
||||
print(
|
||||
f"VLLM_PRECOMPILED_WHEEL_COMMIT not valid: {commit}"
|
||||
", trying to fetch base commit in main branch"
|
||||
)
|
||||
commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
|
||||
print(f"Using precompiled wheel commit {commit} with variant {variant}")
|
||||
try_default = False
|
||||
wheels, repo_url, download_filename = None, None, None
|
||||
@ -461,14 +467,22 @@ class precompiled_wheel_utils:
|
||||
"vllm/cumem_allocator.abi3.so",
|
||||
]
|
||||
|
||||
compiled_regex = re.compile(
|
||||
flash_attn_regex = re.compile(
|
||||
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
|
||||
)
|
||||
triton_kernels_regex = re.compile(
|
||||
r"vllm/third_party/triton_kernels/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
|
||||
)
|
||||
file_members = list(
|
||||
filter(lambda x: x.filename in files_to_copy, wheel.filelist)
|
||||
)
|
||||
file_members += list(
|
||||
filter(lambda x: compiled_regex.match(x.filename), wheel.filelist)
|
||||
filter(lambda x: flash_attn_regex.match(x.filename), wheel.filelist)
|
||||
)
|
||||
file_members += list(
|
||||
filter(
|
||||
lambda x: triton_kernels_regex.match(x.filename), wheel.filelist
|
||||
)
|
||||
)
|
||||
|
||||
for file in file_members:
|
||||
@ -494,10 +508,6 @@ class precompiled_wheel_utils:
|
||||
|
||||
@staticmethod
|
||||
def get_base_commit_in_main_branch() -> str:
|
||||
# Force to use the nightly wheel. This is mainly used for CI testing.
|
||||
if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL:
|
||||
return "nightly"
|
||||
|
||||
try:
|
||||
# Get the latest commit hash of the upstream main branch.
|
||||
resp_json = subprocess.check_output(
|
||||
@ -508,6 +518,7 @@ class precompiled_wheel_utils:
|
||||
]
|
||||
).decode("utf-8")
|
||||
upstream_main_commit = json.loads(resp_json)["sha"]
|
||||
print(f"Upstream main branch latest commit: {upstream_main_commit}")
|
||||
|
||||
# In Docker build context, .git may be immutable or missing.
|
||||
if envs.VLLM_DOCKER_BUILD_CONTEXT:
|
||||
@ -648,7 +659,7 @@ def get_vllm_version() -> str:
|
||||
if envs.VLLM_TARGET_DEVICE == "empty":
|
||||
version += f"{sep}empty"
|
||||
elif _is_cuda():
|
||||
if envs.VLLM_USE_PRECOMPILED:
|
||||
if envs.VLLM_USE_PRECOMPILED and not envs.VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX:
|
||||
version += f"{sep}precompiled"
|
||||
else:
|
||||
cuda_version = str(get_nvcc_cuda_version())
|
||||
@ -786,7 +797,7 @@ setup(
|
||||
"bench": ["pandas", "matplotlib", "seaborn", "datasets"],
|
||||
"tensorizer": ["tensorizer==2.10.1"],
|
||||
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
||||
"runai": ["runai-model-streamer[s3,gcs] >= 0.15.0"],
|
||||
"runai": ["runai-model-streamer[s3,gcs] >= 0.15.3"],
|
||||
"audio": [
|
||||
"librosa",
|
||||
"soundfile",
|
||||
|
||||
@ -392,39 +392,48 @@ def test_pass_config_deprecation(caplog_vllm):
|
||||
assert "enable_fusion is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_norm_quant is True
|
||||
assert config.fuse_act_quant is True
|
||||
assert config.enable_fusion is None
|
||||
assert config.enable_fusion is True
|
||||
|
||||
# Test enable_attn_fusion -> fuse_attn_quant
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_attn_fusion=True)
|
||||
assert "enable_attn_fusion is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_attn_quant is True
|
||||
assert config.enable_attn_fusion is None
|
||||
assert config.enable_attn_fusion is True
|
||||
|
||||
# Test enable_noop -> eliminate_noops
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_noop=True)
|
||||
assert "enable_noop is deprecated" in caplog_vllm.text
|
||||
assert config.eliminate_noops is True
|
||||
assert config.enable_noop is None
|
||||
assert config.enable_noop is True
|
||||
|
||||
# Test enable_sequence_parallelism -> enable_sp
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_sequence_parallelism=True)
|
||||
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
|
||||
assert config.enable_sp is True
|
||||
assert config.enable_sequence_parallelism is None
|
||||
assert config.enable_sequence_parallelism is True
|
||||
|
||||
# Test enable_async_tp -> fuse_gemm_comms
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_async_tp=True)
|
||||
assert "enable_async_tp is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_gemm_comms is True
|
||||
assert config.enable_async_tp is None
|
||||
assert config.enable_async_tp is True
|
||||
|
||||
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
|
||||
caplog_vllm.clear()
|
||||
config = PassConfig(enable_fi_allreduce_fusion=True)
|
||||
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
|
||||
assert config.fuse_allreduce_rms is True
|
||||
assert config.enable_fi_allreduce_fusion is None
|
||||
assert config.enable_fi_allreduce_fusion is True
|
||||
|
||||
# Test hash consistency
|
||||
config_old = PassConfig(enable_fusion=True)
|
||||
config_new = PassConfig(fuse_norm_quant=True, fuse_act_quant=True)
|
||||
assert config_old.compute_hash() == config_new.compute_hash()
|
||||
|
||||
config_old = PassConfig(enable_async_tp=True)
|
||||
config_new = PassConfig(fuse_gemm_comms=True)
|
||||
assert config_old.compute_hash() == config_new.compute_hash()
|
||||
|
||||
@ -6,6 +6,7 @@ import lm_eval
|
||||
import pytest
|
||||
|
||||
from tests.utils import large_gpu_mark
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def get_model_args(
|
||||
@ -45,6 +46,12 @@ def get_model_args(
|
||||
return model_args
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
current_platform.is_rocm(),
|
||||
reason="EPLB with Spec Decode is a work in progress on ROCm.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_setup",
|
||||
[
|
||||
|
||||
@ -69,9 +69,20 @@ async def test_anthropic_streaming(client: anthropic.AsyncAnthropic):
|
||||
stream=True,
|
||||
)
|
||||
|
||||
first_chunk = None
|
||||
chunk_count = 0
|
||||
async for chunk in resp:
|
||||
chunk_count += 1
|
||||
if first_chunk is None and chunk.type == "message_start":
|
||||
first_chunk = chunk
|
||||
print(chunk.model_dump_json())
|
||||
|
||||
assert chunk_count > 0
|
||||
assert first_chunk is not None, "message_start chunk was never observed"
|
||||
assert first_chunk.usage is not None, "first chunk should include usage stats"
|
||||
assert first_chunk.usage["output_tokens"] == 0
|
||||
assert first_chunk.usage["input_tokens"] > 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anthropic_tool_call(client: anthropic.AsyncAnthropic):
|
||||
|
||||
@ -2,64 +2,47 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import base64
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from vllm.utils.serial_utils import tensor2base64
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
|
||||
DTYPE = "float16"
|
||||
|
||||
|
||||
def _terratorch_dummy_inputs(model_name: str):
|
||||
def _terratorch_dummy_messages():
|
||||
pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
|
||||
location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
|
||||
|
||||
buffer_tiff = io.BytesIO()
|
||||
torch.save(pixel_values, buffer_tiff)
|
||||
buffer_tiff.seek(0)
|
||||
binary_data = buffer_tiff.read()
|
||||
base64_tensor_embedding = base64.b64encode(binary_data).decode("utf-8")
|
||||
|
||||
buffer_coord = io.BytesIO()
|
||||
torch.save(location_coords, buffer_coord)
|
||||
buffer_coord.seek(0)
|
||||
binary_data = buffer_coord.read()
|
||||
base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8")
|
||||
|
||||
return {
|
||||
"model": model_name,
|
||||
"additional_data": {"prompt_token_ids": [1]},
|
||||
"encoding_format": "base64",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_embeds",
|
||||
"image_embeds": {
|
||||
"pixel_values": base64_tensor_embedding,
|
||||
"location_coords": base64_coord_embedding,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_embeds",
|
||||
"image_embeds": {
|
||||
"pixel_values": tensor2base64(pixel_values),
|
||||
"location_coords": tensor2base64(location_coords),
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_single_request(model_name: str):
|
||||
@pytest.mark.parametrize(
|
||||
"model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
|
||||
)
|
||||
def test_single_request(model_name: str):
|
||||
args = [
|
||||
"--runner",
|
||||
"pooling",
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
DTYPE,
|
||||
"float16",
|
||||
"--enforce-eager",
|
||||
"--trust-remote-code",
|
||||
"--max-num-seqs",
|
||||
@ -70,11 +53,15 @@ async def test_single_request(model_name: str):
|
||||
"--enable-mm-embeds",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as server:
|
||||
prompt = _terratorch_dummy_inputs(model_name)
|
||||
|
||||
# test single pooling
|
||||
response = requests.post(server.url_for("pooling"), json=prompt)
|
||||
with RemoteOpenAIServer(model_name, args) as server:
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"messages": _terratorch_dummy_messages(),
|
||||
"encoding_format": "base64",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
output = response.json()["data"][0]["data"]
|
||||
|
||||
@ -61,11 +61,8 @@ def test_pooling_params(llm: LLM):
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_encode_api(llm: LLM):
|
||||
# chunked prefill does not support all pooling
|
||||
err_msg = "pooling_task must be one of.+"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
|
||||
def test_token_classify(llm: LLM):
|
||||
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
|
||||
|
||||
|
||||
def test_score_api(llm: LLM):
|
||||
|
||||
@ -255,21 +255,21 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
|
||||
# token_classify uses ALL pooling, which does not support chunked prefill.
|
||||
task = "token_classify"
|
||||
input_text = ["This product was excellent and exceeded my expectations"]
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": "test",
|
||||
"input": input_text,
|
||||
"encoding_format": "float",
|
||||
"task": task,
|
||||
},
|
||||
)
|
||||
assert response.json()["error"]["type"] == "BadRequestError"
|
||||
assert response.json()["error"]["message"].startswith(
|
||||
f"Task {task} is not supported"
|
||||
)
|
||||
poolings = PoolingResponse.model_validate(response.json())
|
||||
assert len(poolings.data) == 1
|
||||
assert len(poolings.data[0].data) == 8
|
||||
assert len(poolings.data[0].data[0]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -42,7 +42,7 @@ def llm():
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_encode_api(llm: LLM):
|
||||
def test_token_embed(llm: LLM):
|
||||
outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
|
||||
multi_vector = outputs[0].outputs.data
|
||||
assert multi_vector.shape == (11, 384)
|
||||
|
||||
@ -36,6 +36,13 @@ def llm():
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_config(llm: LLM):
|
||||
vllm_config = llm.llm_engine.vllm_config
|
||||
assert vllm_config.cache_config.enable_prefix_caching
|
||||
assert vllm_config.scheduler_config.enable_chunked_prefill
|
||||
|
||||
|
||||
def test_pooling_params(llm: LLM):
|
||||
def get_outputs(use_activation):
|
||||
outputs = llm.reward(
|
||||
|
||||
@ -29,6 +29,7 @@ from vllm.multimodal.utils import (
|
||||
encode_video_base64,
|
||||
)
|
||||
from vllm.tokenizers import MistralTokenizer, get_tokenizer
|
||||
from vllm.utils.serial_utils import tensor2base64
|
||||
|
||||
from ..models.registry import HF_EXAMPLE_MODELS
|
||||
from ..utils import VLLM_PATH
|
||||
@ -85,11 +86,6 @@ def phi3v_model_config_image_embeds():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def phi3v_tokenizer():
|
||||
return get_tokenizer(PHI3V_MODEL_ID)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def qwen2_audio_model_config():
|
||||
return ModelConfig(
|
||||
@ -115,11 +111,6 @@ def audio_embeds_model_config():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def qwen2_audio_tokenizer():
|
||||
return get_tokenizer(QWEN2AUDIO_MODEL_ID)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def qwen25omni_model_config_mm_interleaved():
|
||||
return ModelConfig(
|
||||
@ -134,11 +125,6 @@ def qwen25omni_model_config_mm_interleaved():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def qwen25omni_tokenizer():
|
||||
return get_tokenizer(QWEN25OMNI_MODEL_ID)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def mistral_model_config():
|
||||
return ModelConfig(
|
||||
@ -150,11 +136,6 @@ def mistral_model_config():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mistral_tokenizer():
|
||||
return get_tokenizer(MISTRAL_MODEL_ID)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def image_url():
|
||||
image = ImageAsset("cherry_blossom")
|
||||
@ -239,7 +220,6 @@ def _assert_mm_data_inputs(
|
||||
|
||||
def test_parse_chat_messages_single_image(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
@ -253,7 +233,6 @@ def test_parse_chat_messages_single_image(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -266,7 +245,6 @@ def test_parse_chat_messages_single_image(
|
||||
|
||||
def test_parse_chat_messages_single_image_with_uuid(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
image_uuid = str(hash(image_url))
|
||||
@ -287,7 +265,6 @@ def test_parse_chat_messages_single_image_with_uuid(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -300,7 +277,6 @@ def test_parse_chat_messages_single_image_with_uuid(
|
||||
|
||||
def test_parse_chat_messages_single_empty_image_with_uuid(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
image_uuid = str(hash(image_url))
|
||||
@ -319,7 +295,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -332,7 +307,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid(
|
||||
|
||||
def test_parse_chat_messages_single_image_with_bad_uuid_format(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
image_uuid = str(hash(image_url))
|
||||
@ -354,7 +328,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -367,7 +340,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format(
|
||||
|
||||
def test_parse_chat_messages_multiple_images_with_uuids(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
image_uuid1 = "my_uuid_1"
|
||||
@ -397,7 +369,6 @@ def test_parse_chat_messages_multiple_images_with_uuids(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -413,7 +384,6 @@ def test_parse_chat_messages_multiple_images_with_uuids(
|
||||
|
||||
def test_parse_chat_messages_multiple_empty_images_with_uuids(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
image_uuid1 = "my_uuid_1"
|
||||
@ -439,7 +409,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -455,7 +424,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids(
|
||||
|
||||
def test_parse_chat_messages_mixed_empty_images_with_uuids(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
image_uuid1 = "my_uuid_1"
|
||||
@ -483,7 +451,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -500,7 +467,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids(
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_single_image_with_uuid_async(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
image_uuid = str(hash(image_url))
|
||||
@ -519,7 +485,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -533,7 +498,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async(
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_empty_image_with_uuid_async(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
image_uuid = str(hash(image_url))
|
||||
@ -552,7 +516,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -566,7 +529,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async(
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_multiple_images_with_uuids_async(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
image_uuid1 = "my_uuid_1"
|
||||
@ -592,7 +554,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -609,7 +570,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async(
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
image_uuid1 = "my_uuid_1"
|
||||
@ -635,7 +595,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -652,7 +611,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
image_uuid2 = "my_uuid_2"
|
||||
@ -676,7 +634,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -692,7 +649,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
|
||||
|
||||
def test_parse_chat_messages_empty_system(
|
||||
mistral_model_config,
|
||||
mistral_tokenizer,
|
||||
):
|
||||
# Test string format
|
||||
conversation, _, _ = parse_chat_messages(
|
||||
@ -704,7 +660,6 @@ def test_parse_chat_messages_empty_system(
|
||||
},
|
||||
],
|
||||
mistral_model_config,
|
||||
mistral_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
assert conversation == [
|
||||
@ -722,7 +677,6 @@ def test_parse_chat_messages_empty_system(
|
||||
},
|
||||
],
|
||||
mistral_model_config,
|
||||
mistral_tokenizer,
|
||||
content_format="openai",
|
||||
)
|
||||
assert conversation == [
|
||||
@ -734,7 +688,6 @@ def test_parse_chat_messages_empty_system(
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_single_image_async(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
|
||||
@ -748,7 +701,6 @@ async def test_parse_chat_messages_single_image_async(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -761,7 +713,6 @@ async def test_parse_chat_messages_single_image_async(
|
||||
|
||||
def test_parse_chat_messages_multiple_images(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
@ -779,7 +730,6 @@ def test_parse_chat_messages_multiple_images(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -795,7 +745,6 @@ def test_parse_chat_messages_multiple_images(
|
||||
|
||||
def test_parse_chat_messages_empty_pil_image_with_uuid(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
):
|
||||
uuid = "abcd"
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
@ -809,7 +758,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -825,7 +773,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid(
|
||||
|
||||
def test_parse_chat_messages_empty_image_embeds_with_uuid(
|
||||
phi3v_model_config_image_embeds,
|
||||
phi3v_tokenizer,
|
||||
):
|
||||
uuid = "abcd"
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
@ -839,7 +786,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
|
||||
}
|
||||
],
|
||||
phi3v_model_config_image_embeds,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -857,7 +803,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
|
||||
|
||||
def test_parse_chat_messages_empty_audio_embeds_with_uuid(
|
||||
audio_embeds_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
):
|
||||
"""Test audio_embeds with UUID (no actual embeds data)."""
|
||||
uuid = "test-audio-uuid-123"
|
||||
@ -873,7 +818,6 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
|
||||
}
|
||||
],
|
||||
audio_embeds_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -889,11 +833,8 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
|
||||
|
||||
def test_parse_chat_messages_audio_embeds_with_string(
|
||||
audio_embeds_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
):
|
||||
"""Test audio_embeds with base64 string embedding data."""
|
||||
import base64
|
||||
import io
|
||||
|
||||
import torch
|
||||
|
||||
@ -901,11 +842,7 @@ def test_parse_chat_messages_audio_embeds_with_string(
|
||||
audio_embedding = torch.randn(1, 128, 768)
|
||||
|
||||
# Encode it as base64
|
||||
buffer = io.BytesIO()
|
||||
torch.save(audio_embedding, buffer)
|
||||
buffer.seek(0)
|
||||
binary_data = buffer.read()
|
||||
base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
|
||||
base64_audio_embedding = tensor2base64(audio_embedding)
|
||||
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
[
|
||||
@ -921,7 +858,6 @@ def test_parse_chat_messages_audio_embeds_with_string(
|
||||
}
|
||||
],
|
||||
audio_embeds_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -939,11 +875,8 @@ def test_parse_chat_messages_audio_embeds_with_string(
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_audio_embeds_async(
|
||||
audio_embeds_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
):
|
||||
"""Test audio_embeds with async futures."""
|
||||
import base64
|
||||
import io
|
||||
|
||||
import torch
|
||||
|
||||
@ -951,11 +884,7 @@ async def test_parse_chat_messages_audio_embeds_async(
|
||||
audio_embedding = torch.randn(1, 128, 768)
|
||||
|
||||
# Encode it as base64
|
||||
buffer = io.BytesIO()
|
||||
torch.save(audio_embedding, buffer)
|
||||
buffer.seek(0)
|
||||
binary_data = buffer.read()
|
||||
base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
|
||||
base64_audio_embedding = tensor2base64(audio_embedding)
|
||||
|
||||
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
|
||||
[
|
||||
@ -971,7 +900,6 @@ async def test_parse_chat_messages_audio_embeds_async(
|
||||
}
|
||||
],
|
||||
audio_embeds_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -990,7 +918,6 @@ async def test_parse_chat_messages_audio_embeds_async(
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
|
||||
phi3v_model_config_image_embeds,
|
||||
phi3v_tokenizer,
|
||||
):
|
||||
uuid = "abcd"
|
||||
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
|
||||
@ -1004,7 +931,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
|
||||
}
|
||||
],
|
||||
phi3v_model_config_image_embeds,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1024,7 +950,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_multiple_images_async(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
|
||||
@ -1042,7 +967,6 @@ async def test_parse_chat_messages_multiple_images_async(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1058,7 +982,6 @@ async def test_parse_chat_messages_multiple_images_async(
|
||||
|
||||
def test_parse_chat_messages_placeholder_already_in_prompt(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
@ -1076,7 +999,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
assert conversation == [
|
||||
@ -1091,7 +1013,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
|
||||
|
||||
def test_parse_chat_messages_placeholder_one_already_in_prompt(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
@ -1110,7 +1031,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1127,7 +1047,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
|
||||
|
||||
def test_parse_chat_messages_multiple_images_across_messages(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
@ -1149,7 +1068,6 @@ def test_parse_chat_messages_multiple_images_across_messages(
|
||||
},
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1164,7 +1082,6 @@ def test_parse_chat_messages_multiple_images_across_messages(
|
||||
|
||||
def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
image_uuid = str(hash(image_url))
|
||||
@ -1195,7 +1112,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
|
||||
},
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1210,7 +1126,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
|
||||
|
||||
def test_parse_chat_messages_context_text_format(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
):
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
[
|
||||
@ -1222,7 +1137,6 @@ def test_parse_chat_messages_context_text_format(
|
||||
{"role": "user", "content": "What about this one?"},
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="openai",
|
||||
)
|
||||
|
||||
@ -1246,7 +1160,6 @@ def test_parse_chat_messages_context_text_format(
|
||||
|
||||
def test_parse_chat_messages_rejects_too_many_images_in_one_message(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
with warnings.catch_warnings():
|
||||
@ -1277,14 +1190,12 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
|
||||
def test_parse_chat_messages_rejects_too_many_images_across_messages(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
with warnings.catch_warnings():
|
||||
@ -1322,14 +1233,12 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
|
||||
},
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
|
||||
def test_parse_chat_messages_multiple_images_uncommon_input(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
@ -1344,7 +1253,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
|
||||
}
|
||||
],
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1360,7 +1268,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
|
||||
|
||||
def test_parse_chat_messages_multiple_images_interleave(
|
||||
phi3v_model_config_mm_interleaved,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
@ -1380,7 +1287,6 @@ def test_parse_chat_messages_multiple_images_interleave(
|
||||
}
|
||||
],
|
||||
phi3v_model_config_mm_interleaved,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1398,7 +1304,6 @@ def test_parse_chat_messages_multiple_images_interleave(
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_multiple_images_interleave_async(
|
||||
phi3v_model_config_mm_interleaved,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages_futures(
|
||||
@ -1418,7 +1323,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
|
||||
}
|
||||
],
|
||||
phi3v_model_config_mm_interleaved,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1436,7 +1340,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
|
||||
phi3v_model_config_mm_interleaved,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
image_uuid = str(hash(image_url))
|
||||
@ -1465,7 +1368,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
|
||||
}
|
||||
],
|
||||
phi3v_model_config_mm_interleaved,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1482,7 +1384,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
|
||||
|
||||
def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
|
||||
phi3v_model_config_mm_interleaved,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
@ -1505,7 +1406,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
|
||||
},
|
||||
],
|
||||
phi3v_model_config_mm_interleaved,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1523,7 +1423,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
|
||||
|
||||
def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave(
|
||||
phi3v_model_config_mm_interleaved,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
image_uuid = str(hash(image_url))
|
||||
@ -1555,7 +1454,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl
|
||||
},
|
||||
],
|
||||
phi3v_model_config_mm_interleaved,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1573,7 +1471,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl
|
||||
|
||||
def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
|
||||
qwen25omni_model_config_mm_interleaved,
|
||||
qwen25omni_tokenizer,
|
||||
image_url,
|
||||
video_url,
|
||||
audio_url,
|
||||
@ -1601,7 +1498,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
|
||||
},
|
||||
],
|
||||
qwen25omni_model_config_mm_interleaved,
|
||||
qwen25omni_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1627,7 +1523,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
|
||||
|
||||
def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave(
|
||||
qwen25omni_model_config_mm_interleaved,
|
||||
qwen25omni_tokenizer,
|
||||
image_url,
|
||||
video_url,
|
||||
audio_url,
|
||||
@ -1671,7 +1566,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl
|
||||
},
|
||||
],
|
||||
qwen25omni_model_config_mm_interleaved,
|
||||
qwen25omni_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1699,7 +1593,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl
|
||||
|
||||
def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave( # noqa: E501
|
||||
qwen25omni_model_config_mm_interleaved,
|
||||
qwen25omni_tokenizer,
|
||||
image_url,
|
||||
video_url,
|
||||
audio_url,
|
||||
@ -1743,7 +1636,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes
|
||||
},
|
||||
],
|
||||
qwen25omni_model_config_mm_interleaved,
|
||||
qwen25omni_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1775,7 +1667,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes
|
||||
|
||||
def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave( # noqa: E501
|
||||
qwen25omni_model_config_mm_interleaved,
|
||||
qwen25omni_tokenizer,
|
||||
image_url,
|
||||
video_url,
|
||||
audio_url,
|
||||
@ -1811,7 +1702,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message
|
||||
},
|
||||
],
|
||||
qwen25omni_model_config_mm_interleaved,
|
||||
qwen25omni_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -1837,7 +1727,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message
|
||||
|
||||
def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
|
||||
phi3v_model_config_mm_interleaved,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
with pytest.raises(
|
||||
@ -1861,7 +1750,6 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
|
||||
}
|
||||
],
|
||||
phi3v_model_config_mm_interleaved,
|
||||
phi3v_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -2237,9 +2125,7 @@ def test_resolve_content_format_examples(template_path, expected_format):
|
||||
assert resolved_format == expected_format
|
||||
|
||||
|
||||
def test_parse_chat_messages_include_thinking_chunk(
|
||||
mistral_model_config, mistral_tokenizer
|
||||
):
|
||||
def test_parse_chat_messages_include_thinking_chunk(mistral_model_config):
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
@ -2269,7 +2155,6 @@ def test_parse_chat_messages_include_thinking_chunk(
|
||||
conversation_with_thinking, _, _ = parse_chat_messages(
|
||||
messages,
|
||||
mistral_model_config,
|
||||
mistral_tokenizer,
|
||||
content_format="openai",
|
||||
)
|
||||
|
||||
@ -2353,7 +2238,6 @@ def test_apply_mistral_chat_template_thinking_chunk():
|
||||
|
||||
def test_parse_chat_messages_single_empty_audio_with_uuid(
|
||||
qwen2_audio_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
):
|
||||
audio_uuid = "abcd"
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
@ -2371,7 +2255,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid(
|
||||
}
|
||||
],
|
||||
qwen2_audio_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@ -2389,7 +2272,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid(
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_single_empty_audio_with_uuid_async(
|
||||
qwen2_audio_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
):
|
||||
audio_uuid = "abcd"
|
||||
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
|
||||
@ -2407,7 +2289,6 @@ async def test_parse_chat_messages_single_empty_audio_with_uuid_async(
|
||||
}
|
||||
],
|
||||
qwen2_audio_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
|
||||
@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@ -286,16 +283,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
|
||||
needs_matching_quant=False,
|
||||
needs_deep_gemm=True,
|
||||
)
|
||||
register_experts(
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
batched_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=False,
|
||||
supports_expert_map=False,
|
||||
needs_matching_quant=True,
|
||||
needs_deep_gemm=True,
|
||||
)
|
||||
register_experts(
|
||||
TritonOrDeepGemmExperts,
|
||||
standard_format,
|
||||
@ -457,10 +444,6 @@ def make_fused_experts(
|
||||
kwargs = batch_kwargs | quant_kwargs
|
||||
print(f"Making BatchedTritonExperts {kwargs} ...")
|
||||
experts = BatchedTritonExperts(**kwargs)
|
||||
elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
|
||||
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
|
||||
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
|
||||
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
|
||||
elif fused_experts_type == DeepGemmExperts:
|
||||
print(f"Making DeepGemmExperts {quant_config} ...")
|
||||
experts = DeepGemmExperts(quant_config)
|
||||
|
||||
@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
_per_token_group_quant_fp8_colmajor,
|
||||
silu_mul_per_token_group_quant_fp8_colmajor,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||
|
||||
FLOAT8_DTYPE = torch.float8_e4m3fn
|
||||
GROUP_SIZE = 128
|
||||
|
||||
|
||||
def reference_quant(x: torch.Tensor, use_ue8m0: bool):
|
||||
"""
|
||||
Reference triton quant kernel from,
|
||||
vllm.model_executor.layers.quantization.utils.fp8_utils
|
||||
"""
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=FLOAT8_DTYPE)
|
||||
|
||||
# Allocate the scale tensor in column-major format.
|
||||
shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
|
||||
|
||||
M = x.numel() // GROUP_SIZE
|
||||
N = GROUP_SIZE
|
||||
BLOCK = triton.next_power_of_2(N)
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK // 256, 1), 8)
|
||||
num_stages = 1
|
||||
|
||||
finfo = torch.finfo(FLOAT8_DTYPE)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
_per_token_group_quant_fp8_colmajor[(M,)](
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
GROUP_SIZE,
|
||||
x.shape[1],
|
||||
x.stride(0),
|
||||
x_s.stride(1),
|
||||
eps=1e-10,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
use_ue8m0=use_ue8m0,
|
||||
BLOCK=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
def reference(x: torch.Tensor, use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
T, N = x.size()
|
||||
ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
|
||||
torch.ops._C.silu_and_mul(ref_act_out, x)
|
||||
return reference_quant(ref_act_out, use_ue8m0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("T", [128, 256, 512])
|
||||
@pytest.mark.parametrize("N", [128 * 2, 256 * 2, 768 * 2, 2048 * 2, 7168 * 2])
|
||||
def test_silu_mul_fp8_quant_deep_gemm(T: int, N: int):
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
use_ue8m0 = is_deep_gemm_e8m0_used()
|
||||
|
||||
# Test
|
||||
output, output_scales = silu_mul_per_token_group_quant_fp8_colmajor(
|
||||
input, use_ue8m0=use_ue8m0
|
||||
)
|
||||
|
||||
# Reference
|
||||
ref_output, ref_output_scales = reference(input, use_ue8m0)
|
||||
|
||||
torch.testing.assert_close(output.to(torch.float32), ref_output.to(torch.float32))
|
||||
torch.testing.assert_close(output_scales, ref_output_scales)
|
||||
@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.v1.executor import UniProcExecutor
|
||||
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
|
||||
# This is a dummy executor for patching in test_runai_model_streamer_s3.py.
|
||||
# We cannot use vllm_runner fixture here, because it spawns worker process.
|
||||
# The worker process reimports the patched entities, and the patch is not applied.
|
||||
class RunaiDummyExecutor(UniProcExecutor):
|
||||
def _init_executor(self) -> None:
|
||||
distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
|
||||
|
||||
local_rank = 0
|
||||
rank = 0
|
||||
is_driver_worker = True
|
||||
|
||||
device_info = self.vllm_config.device_config.device.__str__().split(":")
|
||||
if len(device_info) > 1:
|
||||
local_rank = int(device_info[1])
|
||||
|
||||
worker_rpc_kwargs = dict(
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker,
|
||||
)
|
||||
|
||||
wrapper_kwargs = {
|
||||
"vllm_config": self.vllm_config,
|
||||
}
|
||||
|
||||
self.driver_worker = WorkerWrapperBase(**wrapper_kwargs)
|
||||
|
||||
self.collective_rpc("init_worker", args=([worker_rpc_kwargs],))
|
||||
self.collective_rpc("init_device")
|
||||
@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from runai_model_streamer.safetensors_streamer.streamer_mock import StreamerPatcher
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
from .conftest import RunaiDummyExecutor
|
||||
|
||||
load_format = "runai_streamer"
|
||||
test_model = "openai-community/gpt2"
|
||||
|
||||
|
||||
def test_runai_model_loader_download_files_s3_mocked_with_patch(
|
||||
vllm_runner,
|
||||
tmp_path: Path,
|
||||
monkeypatch,
|
||||
):
|
||||
patcher = StreamerPatcher(str(tmp_path))
|
||||
|
||||
test_mock_s3_model = "s3://my-mock-bucket/gpt2/"
|
||||
|
||||
# Download model from HF
|
||||
mock_model_dir = f"{tmp_path}/gpt2"
|
||||
snapshot_download(repo_id=test_model, local_dir=mock_model_dir)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"vllm.transformers_utils.runai_utils.runai_list_safetensors",
|
||||
patcher.shim_list_safetensors,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"vllm.transformers_utils.runai_utils.runai_pull_files",
|
||||
patcher.shim_pull_files,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"vllm.model_executor.model_loader.weight_utils.SafetensorsStreamer",
|
||||
patcher.create_mock_streamer,
|
||||
)
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=test_mock_s3_model,
|
||||
load_format=load_format,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
|
||||
executor = RunaiDummyExecutor(vllm_config)
|
||||
executor.driver_worker.load_model()
|
||||
@ -0,0 +1,53 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModel
|
||||
|
||||
from tests.models.utils import check_embeddings_close
|
||||
from vllm import TokensPrompt
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
["Qwen/Qwen3-Embedding-0.6B"],
|
||||
)
|
||||
@torch.inference_mode
|
||||
def test_embed_models(hf_runner, vllm_runner, model: str):
|
||||
chunk_size = 10
|
||||
n_prompt_tokens = [55, 56, 57]
|
||||
token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
max_model_len=128,
|
||||
max_num_batched_tokens=chunk_size,
|
||||
enforce_eager=True,
|
||||
# `enable_chunked_prefill`: Set to `False` instead of `None` in VllmRunner
|
||||
enable_chunked_prefill=True,
|
||||
enable_prefix_caching=True,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.token_embed(
|
||||
[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
|
||||
)
|
||||
|
||||
with hf_runner(
|
||||
model,
|
||||
auto_cls=AutoModel,
|
||||
) as hf_model:
|
||||
hf_outputs = []
|
||||
for token_prompt in token_prompts:
|
||||
inputs = hf_model.wrap_device({"input_ids": torch.tensor([token_prompt])})
|
||||
input_ids = inputs["input_ids"]
|
||||
output = hf_model.model(input_ids)
|
||||
hf_outputs.append(output.last_hidden_state.cpu().float()[0])
|
||||
|
||||
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=hf_output,
|
||||
embeddings_1_lst=vllm_output,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
tol=1e-2,
|
||||
)
|
||||
@ -20,7 +20,6 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
|
||||
max_model_len=128,
|
||||
enforce_eager=True,
|
||||
runner="pooling",
|
||||
enable_chunked_prefill=False,
|
||||
enable_prefix_caching=True,
|
||||
) as vllm_model:
|
||||
pooling_outputs = vllm_model.llm.encode(
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Pytest configuration for vLLM tests."""
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
@ -14,6 +16,20 @@ def pytest_configure(config):
|
||||
if not current_platform.is_rocm():
|
||||
return
|
||||
|
||||
skip_patterns = ["test_granite_speech.py"]
|
||||
if any(pattern in str(arg) for arg in config.args for pattern in skip_patterns):
|
||||
# Skip disabling SDP for Granite Speech tests on ROCm
|
||||
return
|
||||
|
||||
# Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
|
||||
# accuracy issues
|
||||
# TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
warnings.warn(
|
||||
"ROCm: Disabled flash_sdp and mem_efficient_sdp, enabled math_sdp "
|
||||
"to avoid HuggingFace Transformers accuracy issues",
|
||||
UserWarning,
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
@ -403,12 +403,13 @@ VLM_TEST_SETTINGS = {
|
||||
# So, we need to reduce the number of tokens for the test to pass.
|
||||
max_tokens=8,
|
||||
num_logprobs=10,
|
||||
auto_cls=AutoModelForCausalLM,
|
||||
marks=[large_gpu_mark(min_gb=32)],
|
||||
),
|
||||
"glm4_1v": VLMTestInfo(
|
||||
models=["zai-org/GLM-4.1V-9B-Thinking"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>",
|
||||
prompt_formatter=lambda img_prompt: f"[gMASK]<|user|>\n{img_prompt}<|assistant|>\n", # noqa: E501
|
||||
img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>",
|
||||
video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>",
|
||||
max_model_len=2048,
|
||||
@ -423,6 +424,7 @@ VLM_TEST_SETTINGS = {
|
||||
models=["zai-org/GLM-4.1V-9B-Thinking"],
|
||||
# GLM4.1V require include video metadata for input
|
||||
test_type=VLMTestType.CUSTOM_INPUTS,
|
||||
prompt_formatter=lambda vid_prompt: f"[gMASK]<|user|>\n{vid_prompt}<|assistant|>\n", # noqa: E501
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
@ -737,7 +739,13 @@ VLM_TEST_SETTINGS = {
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
marks=[large_gpu_mark(min_gb=48)],
|
||||
marks=[
|
||||
large_gpu_mark(min_gb=48),
|
||||
pytest.mark.skipif(
|
||||
current_platform.is_rocm(),
|
||||
reason="Model produces a vector of <UNK> output in HF on ROCm",
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen_vl": VLMTestInfo(
|
||||
models=["Qwen/Qwen-VL"],
|
||||
|
||||
@ -8,6 +8,7 @@ from transformers import AutoModelForSpeechSeq2Seq
|
||||
|
||||
from vllm.logprobs import SampleLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ....conftest import AudioTestAssets, HfRunner, PromptAudioInput, VllmRunner
|
||||
from ...registry import HF_EXAMPLE_MODELS
|
||||
@ -34,6 +35,12 @@ audio_lora_path = MODEL_NAME
|
||||
models = [MODEL_NAME]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_attention_backend_for_rocm(monkeypatch):
|
||||
if current_platform.is_rocm():
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
|
||||
|
||||
|
||||
def run_test(
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
@ -111,8 +118,12 @@ def run_test(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_model_len", [2048])
|
||||
@pytest.mark.parametrize(
|
||||
"dtype", ["float16"] if current_platform.is_rocm() else ["bfloat16"]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"max_model_len", [512] if current_platform.is_rocm() else [2048]
|
||||
)
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_models(
|
||||
|
||||
@ -1,281 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
|
||||
import librosa
|
||||
import pytest
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
|
||||
from ....conftest import (
|
||||
IMAGE_ASSETS,
|
||||
HfRunner,
|
||||
PromptAudioInput,
|
||||
PromptImageInput,
|
||||
VllmRunner,
|
||||
)
|
||||
from ....utils import large_gpu_test
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
|
||||
{
|
||||
"stop_sign": "<|user|>\n<|image|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501
|
||||
"cherry_blossom": "<|user|>\n<|image|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501
|
||||
}
|
||||
)
|
||||
HF_MULTIIMAGE_IMAGE_PROMPT = (
|
||||
"<|user|>\n<|image|>\n<|image|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501
|
||||
)
|
||||
|
||||
model_path = snapshot_download(
|
||||
"microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70"
|
||||
)
|
||||
# Since the vision-lora and speech-lora co-exist with the base model,
|
||||
# we have to manually specify the path of the lora weights.
|
||||
vision_lora_path = os.path.join(model_path, "vision-lora")
|
||||
speech_question = os.path.join(
|
||||
model_path, "examples", "what_is_shown_in_this_image.wav"
|
||||
)
|
||||
models = [model_path]
|
||||
|
||||
target_dtype = "half"
|
||||
|
||||
|
||||
def run_test(
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
inputs: Sequence[tuple[list[str], PromptImageInput, PromptAudioInput | None]],
|
||||
model: str,
|
||||
*,
|
||||
max_model_len: int,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
mm_limit: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: str | None = None,
|
||||
):
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test are from IMAGE_ASSETS.
|
||||
For huggingface runner, we provide the PIL images as input.
|
||||
For vllm runner, we provide MultiModalDataDict objects
|
||||
and corresponding MultiModalConfig as input.
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||
# vLLM needs a fresh new process without cuda initialization.
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
# will hurt multiprocessing backend with fork method (the default method).
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(
|
||||
model,
|
||||
task="generate",
|
||||
max_model_len=max_model_len,
|
||||
max_num_seqs=2,
|
||||
dtype=dtype,
|
||||
limit_mm_per_prompt={"image": mm_limit},
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enable_lora=True,
|
||||
max_lora_rank=320,
|
||||
gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI
|
||||
enforce_eager=True,
|
||||
trust_remote_code=False,
|
||||
) as vllm_model:
|
||||
lora_request = LoRARequest("vision", 1, vision_lora_path)
|
||||
vllm_outputs_per_case = [
|
||||
vllm_model.generate_greedy_logprobs(
|
||||
prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
audios=audios,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
for prompts, images, audios in inputs
|
||||
]
|
||||
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_model.model.load_adapter(
|
||||
vision_lora_path,
|
||||
adapter_name="vision",
|
||||
)
|
||||
hf_processor = hf_model.processor
|
||||
eos_token_id = hf_processor.tokenizer.eos_token_id
|
||||
hf_outputs_per_case = [
|
||||
hf_model.generate_greedy_logprobs_limit(
|
||||
prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
audios=audios,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
for prompts, images, audios in inputs
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
# No image
|
||||
[],
|
||||
# Single-scale
|
||||
[1.0],
|
||||
# Single-scale, batched
|
||||
[1.0, 1.0, 1.0],
|
||||
# Multi-scale
|
||||
[0.25, 0.5, 1.0],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
@pytest.mark.parametrize("max_model_len", [12800])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model,
|
||||
size_factors,
|
||||
dtype: str,
|
||||
max_model_len: int,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [
|
||||
(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
None,
|
||||
)
|
||||
for image, prompt in zip(images, HF_IMAGE_PROMPTS)
|
||||
]
|
||||
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
inputs_per_image,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
mm_limit=1,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
@large_gpu_test(min_gb=48)
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
# No image
|
||||
# [],
|
||||
# Single-scale
|
||||
[1.0],
|
||||
# Single-scale, batched
|
||||
[1.0, 1.0, 1.0],
|
||||
# Multi-scale
|
||||
[0.25, 0.5, 1.0],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
@pytest.mark.parametrize("max_model_len", [25600])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_multi_images_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model,
|
||||
size_factors,
|
||||
dtype: str,
|
||||
max_model_len: int,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_case = [
|
||||
(
|
||||
[HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
|
||||
[
|
||||
[rescale_image_size(image, factor) for image in images]
|
||||
for factor in size_factors
|
||||
],
|
||||
None,
|
||||
),
|
||||
]
|
||||
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
inputs_per_case,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
mm_limit=2,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
@pytest.mark.parametrize("max_model_len", [12800])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_vision_speech_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
model,
|
||||
dtype: str,
|
||||
max_model_len: int,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
# use the example speech question so that the model outputs are reasonable
|
||||
audio = librosa.load(speech_question, sr=16000)
|
||||
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
|
||||
|
||||
inputs_vision_speech = [
|
||||
(
|
||||
["<|user|><|image|><|audio|><|end|><|assistant|>"],
|
||||
[image],
|
||||
[audio],
|
||||
),
|
||||
]
|
||||
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
inputs_vision_speech,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
mm_limit=1,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
@ -15,6 +15,7 @@ from transformers import AutoProcessor
|
||||
from vllm import SamplingParams, TextPrompt, TokensPrompt
|
||||
from vllm.logprobs import Logprob, SampleLogprobs
|
||||
from vllm.multimodal import MultiModalDataBuiltins
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ....utils import VLLM_PATH, large_gpu_test
|
||||
from ...utils import check_logprobs_close
|
||||
@ -165,6 +166,15 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
|
||||
def test_chat(
|
||||
vllm_runner, max_model_len: int, model: str, dtype: str, local_asset_server
|
||||
) -> None:
|
||||
if (
|
||||
model == MISTRAL_SMALL_3_1_ID
|
||||
and max_model_len == 65536
|
||||
and current_platform.is_rocm()
|
||||
):
|
||||
pytest.skip(
|
||||
"OOM on ROCm: 24B model with 65536 context length exceeds GPU memory"
|
||||
)
|
||||
|
||||
EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT[model])
|
||||
with vllm_runner(
|
||||
model,
|
||||
|
||||
@ -140,7 +140,7 @@ def video_with_metadata_glm4_1v():
|
||||
metadata = VIDEO_ASSETS[0].metadata
|
||||
question = "Describe the video."
|
||||
video_prompt = "<|begin_of_video|><|video|><|end_of_video|>"
|
||||
formatted_prompt = f"<|user|>\n{video_prompt}{question}<|assistant|>\n"
|
||||
formatted_prompt = f"[gMASK]<|user|>\n{video_prompt}{question}<|assistant|>\n"
|
||||
|
||||
scales = [0.1, 0.2, 0.25]
|
||||
video_input = [
|
||||
|
||||
@ -25,6 +25,7 @@ from transformers import (
|
||||
from transformers.video_utils import VideoMetadata
|
||||
|
||||
from vllm.logprobs import SampleLogprobs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
|
||||
from .....conftest import HfRunner, ImageAsset, ImageTestAssets
|
||||
@ -366,6 +367,40 @@ def gemma3_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOut
|
||||
|
||||
def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
"""Patches and returns an instance of the HfRunner to use for GLM4V."""
|
||||
if current_platform.is_rocm():
|
||||
import types
|
||||
|
||||
config = hf_model.model.config
|
||||
if hasattr(config, "num_layers") and not hasattr(config, "num_hidden_layers"):
|
||||
config.num_hidden_layers = config.num_layers
|
||||
config.output_hidden_states = True
|
||||
|
||||
def patched_prepare_cache(
|
||||
self, generation_config, model_kwargs, *args, **kwargs
|
||||
):
|
||||
model_kwargs["past_key_values"] = None
|
||||
model_kwargs["use_cache"] = False
|
||||
return model_kwargs
|
||||
|
||||
hf_model.model._prepare_cache_for_generation = types.MethodType(
|
||||
patched_prepare_cache, hf_model.model
|
||||
)
|
||||
original_generate = hf_model.model.generate
|
||||
|
||||
def patched_generate(*args, **kwargs):
|
||||
kwargs["output_hidden_states"] = True
|
||||
kwargs["return_dict_in_generate"] = True
|
||||
return original_generate(*args, **kwargs)
|
||||
|
||||
hf_model.model.generate = patched_generate
|
||||
original_forward = hf_model.model.forward
|
||||
|
||||
def patched_forward(*args, **kwargs):
|
||||
kwargs["output_hidden_states"] = True
|
||||
return original_forward(*args, **kwargs)
|
||||
|
||||
hf_model.model.forward = patched_forward
|
||||
|
||||
hf_processor = hf_model.processor
|
||||
|
||||
def processor(*args, text="", images=None, **kwargs):
|
||||
@ -406,7 +441,15 @@ def glm4_1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
if videos is not None and is_list_of(videos, tuple):
|
||||
# If videos is a list of tuples, we assume each tuple contains
|
||||
# (video_array, metadata) as in the case of GLM4.1V.
|
||||
video_metadata = [[VideoMetadata(**video[1])] for video in videos]
|
||||
# Filter out 'do_sample_frames' as it's not a valid VideoMetadata arg
|
||||
video_metadata = [
|
||||
[
|
||||
VideoMetadata(
|
||||
**{k: v for k, v in video[1].items() if k != "do_sample_frames"}
|
||||
)
|
||||
]
|
||||
for video in videos
|
||||
]
|
||||
videos = [[video[0]] for video in videos]
|
||||
else:
|
||||
video_metadata = None
|
||||
|
||||
24
tests/models/multimodal/pooling/conftest.py
Normal file
24
tests/models/multimodal/pooling/conftest.py
Normal file
@ -0,0 +1,24 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Pytest configuration for vLLM pooling tests."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
"""Set FLEX_ATTENTION backend for SigLIP tests on ROCm."""
|
||||
if not current_platform.is_rocm():
|
||||
return
|
||||
|
||||
siglip_tests = [item for item in items if "test_siglip" in item.nodeid]
|
||||
|
||||
if siglip_tests:
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
|
||||
warnings.warn(
|
||||
"ROCm: Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION for SigLIP tests",
|
||||
UserWarning,
|
||||
stacklevel=1,
|
||||
)
|
||||
@ -396,28 +396,6 @@ def test_processing_correctness(
|
||||
)
|
||||
|
||||
|
||||
# Phi4MultimodalForCausalLM share same model repo with original format
|
||||
# Phi4MMForCausalLM, so we add it as a separate test case
|
||||
# Remove this test after conversion PR merged:
|
||||
# https://huggingface.co/microsoft/Phi-4-multimodal-instruct/discussions/70
|
||||
@pytest.mark.parametrize("model_arch", ["Phi4MultimodalForCausalLM"])
|
||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("num_batches", [32])
|
||||
@pytest.mark.parametrize("simplify_rate", [1.0])
|
||||
def test_processing_correctness_phi4_multimodal(
|
||||
model_arch: str,
|
||||
hit_rate: float,
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
):
|
||||
_test_processing_correctness(
|
||||
model_arch,
|
||||
hit_rate=hit_rate,
|
||||
num_batches=num_batches,
|
||||
simplify_rate=simplify_rate,
|
||||
)
|
||||
|
||||
|
||||
def _assert_inputs_equal(
|
||||
a: MultiModalInputs,
|
||||
b: MultiModalInputs,
|
||||
|
||||
@ -667,6 +667,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"moonshotai/Kimi-VL-A3B-Instruct",
|
||||
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},
|
||||
trust_remote_code=True,
|
||||
max_transformers_version="4.53.3",
|
||||
transformers_version_reason="HF model uses deprecated transformers API "
|
||||
"(PytorchGELUTanh, DynamicCache.seen_tokens, and more). See: "
|
||||
"https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/discussions/31",
|
||||
),
|
||||
"LightOnOCRForConditionalGeneration": _HfExamplesInfo(
|
||||
"lightonai/LightOnOCR-1B",
|
||||
@ -767,10 +771,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"Phi4MMForCausalLM": _HfExamplesInfo(
|
||||
"microsoft/Phi-4-multimodal-instruct", trust_remote_code=True
|
||||
),
|
||||
"Phi4MultimodalForCausalLM": _HfExamplesInfo(
|
||||
"microsoft/Phi-4-multimodal-instruct",
|
||||
revision="refs/pr/70",
|
||||
),
|
||||
"PixtralForConditionalGeneration": _HfExamplesInfo(
|
||||
"mistralai/Pixtral-12B-2409",
|
||||
extras={
|
||||
|
||||
@ -112,7 +112,7 @@ class TestBaseThinkingReasoningParserMethods:
|
||||
"""Test the is_reasoning_end method."""
|
||||
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||
end_token_id = parser.end_token_id
|
||||
|
||||
start_token_id = parser.start_token_id
|
||||
# Test with end token present
|
||||
assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True
|
||||
|
||||
@ -122,6 +122,16 @@ class TestBaseThinkingReasoningParserMethods:
|
||||
# Test with empty list
|
||||
assert parser.is_reasoning_end([]) is False
|
||||
|
||||
# Test with interleaved thinking
|
||||
assert parser.is_reasoning_end([1, start_token_id, 2, end_token_id]) is True
|
||||
assert parser.is_reasoning_end([1, start_token_id, 2, 3]) is False
|
||||
assert (
|
||||
parser.is_reasoning_end(
|
||||
[1, start_token_id, 2, end_token_id, 2, 2, start_token_id]
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_extract_content_ids(self, test_tokenizer):
|
||||
"""Test the extract_content_ids method."""
|
||||
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||
|
||||
@ -5,6 +5,10 @@
|
||||
set -e
|
||||
set -x
|
||||
|
||||
merge_base_commit=$(git merge-base HEAD origin/main)
|
||||
echo "Current merge base commit with main: $merge_base_commit"
|
||||
git show --oneline -s $merge_base_commit
|
||||
|
||||
cd /vllm-workspace/
|
||||
|
||||
# uninstall vllm
|
||||
@ -18,7 +22,7 @@ apt autoremove -y
|
||||
|
||||
echo 'import os; os.system("touch /tmp/changed.file")' >> vllm/__init__.py
|
||||
|
||||
VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL=1 VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e .
|
||||
VLLM_PRECOMPILED_WHEEL_COMMIT=$merge_base_commit VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e .
|
||||
|
||||
# Run the script
|
||||
python3 -c 'import vllm'
|
||||
|
||||
@ -629,8 +629,8 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
|
||||
(
|
||||
"internlm/internlm2-1_8b-reward",
|
||||
"decoder",
|
||||
False,
|
||||
"Pooling models with all pooling does not support chunked prefill.",
|
||||
True,
|
||||
"Pooling models with causal attn and all pooling support chunked prefill.",
|
||||
),
|
||||
(
|
||||
"BAAI/bge-base-en",
|
||||
@ -748,8 +748,8 @@ def test_is_chunked_prefill_supported(
|
||||
(
|
||||
"internlm/internlm2-1_8b-reward",
|
||||
"decoder",
|
||||
False,
|
||||
"Pooling models with all pooling does not support prefix caching.",
|
||||
True,
|
||||
"Pooling models with causal attn and all pooling support prefix caching.",
|
||||
),
|
||||
(
|
||||
"BAAI/bge-base-en",
|
||||
|
||||
@ -365,3 +365,54 @@ class TestEnvSetWithChoices:
|
||||
with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}):
|
||||
env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
|
||||
assert env_func() == {"option1", "option2"}
|
||||
|
||||
|
||||
class TestVllmConfigureLogging:
    """Checks how ``envs.VLLM_CONFIGURE_LOGGING`` is derived from the environment."""

    @staticmethod
    def _clear_envs_cache():
        """Invalidate the ``envs.__getattr__`` cache when it exposes one."""
        if hasattr(envs.__getattr__, "cache_clear"):
            envs.__getattr__.cache_clear()

    def test_configure_logging_defaults_to_true(self):
        """An unset VLLM_CONFIGURE_LOGGING must read back as the bool ``True``."""
        # patch.dict restores os.environ on exit, so removing the key here
        # only affects this test.
        with patch.dict(os.environ, {}, clear=False):
            os.environ.pop("VLLM_CONFIGURE_LOGGING", None)
            self._clear_envs_cache()

            value = envs.VLLM_CONFIGURE_LOGGING
            assert value is True
            assert isinstance(value, bool)

    def test_configure_logging_with_zero_string(self):
        """VLLM_CONFIGURE_LOGGING='0' must read back as the bool ``False``."""
        with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "0"}):
            self._clear_envs_cache()

            value = envs.VLLM_CONFIGURE_LOGGING
            assert value is False
            assert isinstance(value, bool)

    def test_configure_logging_with_one_string(self):
        """VLLM_CONFIGURE_LOGGING='1' must read back as the bool ``True``."""
        with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "1"}):
            self._clear_envs_cache()

            value = envs.VLLM_CONFIGURE_LOGGING
            assert value is True
            assert isinstance(value, bool)

    def test_configure_logging_with_invalid_value_raises_error(self):
        """A non-integer VLLM_CONFIGURE_LOGGING must raise ``ValueError`` on access."""
        with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "invalid"}):
            self._clear_envs_cache()

            with pytest.raises(ValueError, match="invalid literal for int"):
                _ = envs.VLLM_CONFIGURE_LOGGING
|
||||
|
||||
847
tests/tool_use/test_mistral_tool_parser.py
Normal file
847
tests/tool_use/test_mistral_tool_parser.py
Normal file
@ -0,0 +1,847 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
import partial_json_parser
|
||||
import pytest
|
||||
from mistral_common.protocol.instruct.messages import AssistantMessage
|
||||
from mistral_common.protocol.instruct.request import InstructRequest
|
||||
from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall
|
||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolParser
|
||||
from vllm.tokenizers import (
|
||||
MistralTokenizer,
|
||||
TokenizerLike,
|
||||
get_tokenizer,
|
||||
)
|
||||
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def mistral_pre_v11_tokenizer():
    """Load the Mistral-7B-Instruct-v0.3 tokenizer once per test module."""
    return get_tokenizer(tokenizer_name="mistralai/Mistral-7B-Instruct-v0.3")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def mistral_tokenizer():
    """Load the Mistral-Small-3.2 tokenizer in ``mistral`` mode once per test module."""
    return get_tokenizer(
        tokenizer_name="mistralai/Mistral-Small-3.2-24B-Instruct-2506",
        tokenizer_mode="mistral",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mistral_pre_v11_tool_parser(mistral_pre_v11_tokenizer):
    """Build a fresh ``MistralToolParser`` per test around the pre-v11 tokenizer fixture."""
    parser = MistralToolParser(mistral_pre_v11_tokenizer)
    return parser
|
||||
|
||||
|
||||
@pytest.fixture
def mistral_tool_parser(mistral_tokenizer):
    """Build a fresh ``MistralToolParser`` per test around the ``mistral_tokenizer`` fixture."""
    parser = MistralToolParser(mistral_tokenizer)
    return parser
|
||||
|
||||
|
||||
def assert_tool_calls(
    actual_tool_calls: list[ToolCall] | list[DeltaToolCall],
    expected_tool_calls: list[ToolCall],
):
    """Assert that the parsed tool calls match the expected ones pairwise.

    Ids are only checked for shape (a 9-character string). Function names
    and argument strings must equal those of the expected call at the same
    position.

    Args:
        actual_tool_calls: Tool calls produced by the parser under test.
        expected_tool_calls: Reference tool calls, in the expected order.

    Raises:
        AssertionError: If the lengths differ or any pair does not match.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    # strict=True is redundant with the length assert above but keeps the
    # pairing honest if that assert is ever removed.
    for actual_tool_call, expected_tool_call in zip(
        actual_tool_calls, expected_tool_calls, strict=True
    ):
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) == 9

        if isinstance(actual_tool_call, ToolCall):
            assert actual_tool_call.type == "function"
        elif isinstance(actual_tool_call, DeltaToolCall):
            assert actual_tool_call.function is not None
            assert actual_tool_call.function.name is not None
            assert actual_tool_call.function.arguments is not None
        assert actual_tool_call.function is not None
        # The messages previously used "${...}", which printed a stray "$".
        assert actual_tool_call.function.name == expected_tool_call.function.name, (
            f"got wrong function name: {actual_tool_call.function.name}"
        )
        assert (
            actual_tool_call.function.arguments == expected_tool_call.function.arguments
        ), f"got wrong function argument: {actual_tool_call.function.arguments}"
|
||||
|
||||
|
||||
def fix_tool_call_tokenization(
    tokens: list[int],
    mistral_tool_parser: MistralToolParser,
    mistral_tokenizer: TokenizerLike,
) -> list[int]:
    """
    Replaces the textual token sequence for [TOOL_CALLS]
    with its single special token ID.

    Args:
        tokens: Token ids to scan.
        mistral_tool_parser: Parser providing ``bot_token`` (the marker text)
            and ``bot_token_id`` (the id substituted for it).
        mistral_tokenizer: Tokenizer used to encode the marker text.

    Returns:
        A new list with every occurrence of the encoded marker collapsed to
        ``bot_token_id``, or ``tokens`` itself when no replacement is possible.
    """
    textual_tool_call_token_ids = mistral_tokenizer.encode(
        text=mistral_tool_parser.bot_token,
        add_special_tokens=False,
    )
    # textual_tool_call_token_ids must not contain special tokens like bos, eos etc
    special_tool_call_token_ids = [mistral_tool_parser.bot_token_id]
    target_len = len(textual_tool_call_token_ids)

    # An empty target would match at every position without advancing the
    # index, so the scan below would never terminate.
    if target_len == 0:
        return tokens

    # If the input is too short to contain the sequence, no replacement is possible
    if not tokens or len(tokens) < target_len:
        return tokens

    result_tokens: list[int] = []
    i = 0
    while i < len(tokens):
        # Check if the slice from the current position matches the target sequence
        if tokens[i : i + target_len] == textual_tool_call_token_ids:
            # If it matches, add the replacement and jump the index forward
            result_tokens.extend(special_tool_call_token_ids)
            i += target_len
        else:
            # Otherwise, just add the current token and move to the next one
            result_tokens.append(tokens[i])
            i += 1

    return result_tokens
|
||||
|
||||
|
||||
def stream_delta_message_generator(
    mistral_tool_parser: MistralToolParser,
    mistral_tokenizer: TokenizerLike,
    model_output: str | None,
    tools: list[tuple[str, str]] | None,
) -> Generator[DeltaMessage, None, None]:
    """Drive the streaming tool parser one token at a time.

    The token ids come from ``tools`` (for ``MistralTokenizer`` version >= 11)
    or from ``model_output`` (otherwise). They are then incrementally
    detokenized, and every non-empty delta message returned by
    ``extract_tool_calls_streaming`` is yielded.

    Args:
        mistral_tool_parser: Parser under test.
        mistral_tokenizer: Tokenizer used for encoding and detokenizing.
        model_output: Free-text model output; required on the free-text path.
        tools: ``(name, arguments)`` pairs; required on the v11+ path.
    """
    is_mistral = isinstance(mistral_tokenizer, MistralTokenizer)

    if is_mistral and mistral_tokenizer.version >= 11:
        # With the newer versions of the tokenizer,
        # we cannot tokenize free text
        # so we need to create a list of messages to get tokenized
        assert tools is not None
        calls = [
            ToolCall(function=FunctionCall(name=name, arguments=arg))
            for (name, arg) in tools
        ]
        request = InstructRequest(messages=[AssistantMessage(tool_calls=calls)])
        all_token_ids = mistral_tokenizer.instruct.encode_instruct(request).tokens
    else:
        # Older versions of the tokenizer are
        # able to encode directly the model's output (free text) into tokens
        assert model_output is not None
        all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)

    all_token_ids = fix_tool_call_tokenization(
        all_token_ids, mistral_tool_parser, mistral_tokenizer
    )

    text_so_far = ""
    tokens_so_far = None
    prefix_offset = 0
    read_offset = 0
    for position, token_id in enumerate(all_token_ids):
        ids_before = all_token_ids[:position]
        ids_now = all_token_ids[: position + 1]

        step = detokenize_incrementally(
            tokenizer=mistral_tokenizer,
            all_input_ids=ids_now,
            prev_tokens=tokens_so_far,
            prefix_offset=prefix_offset,
            read_offset=read_offset,
            skip_special_tokens=is_mistral,
            spaces_between_special_tokens=True,
        )
        new_tokens, delta_text, next_prefix_offset, next_read_offset = step

        text_now = text_so_far + delta_text

        delta_message = mistral_tool_parser.extract_tool_calls_streaming(
            text_so_far,
            text_now,
            delta_text,
            ids_before,
            ids_now,
            [token_id],
            request=None,  # type: ignore[arg-type]
        )
        if delta_message:
            yield delta_message

        # Carry the detokenizer state over to the next token.
        text_so_far = text_now
        if tokens_so_far:
            tokens_so_far = tokens_so_far + new_tokens
        else:
            tokens_so_far = new_tokens
        prefix_offset = next_prefix_offset
        read_offset = next_read_offset
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser):
    """Plain text must yield no tool calls and be returned unchanged as content."""
    plain_text = "This is a test"
    result = mistral_pre_v11_tool_parser.extract_tool_calls(plain_text, request=None)  # type: ignore[arg-type]
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == plain_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("model_output", "expected_tool_calls", "expected_content"),
    [
        pytest.param(
            """[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                )
            ],
            None,
            id="single_tool_add",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
            id="single_tool_weather",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
            id="argument_before_name",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps(
                            {
                                "name": "John Doe",
                            }
                        ),
                    )
                )
            ],
            None,
            id="argument_before_name_and_name_in_argument",
        ),
    ],
)
def test_extract_tool_calls_pre_v11_tokenizer(
    mistral_pre_v11_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction of JSON-list tool calls with the pre-v11 parser fixture."""
    result = mistral_pre_v11_tool_parser.extract_tool_calls(model_output, request=None)  # type: ignore[arg-type]

    assert result.tools_called
    assert_tool_calls(result.tool_calls, expected_tool_calls)
    assert result.content == expected_content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("model_output", "expected_tool_calls", "expected_content"),
    [
        pytest.param(
            """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            None,
            id="single_tool_add",
        ),
        pytest.param(
            """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
            id="single_tool_weather",
        ),
        pytest.param(
            """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="multiply", arguments=json.dumps({"a": 3, "b": 6})
                    )
                ),
            ],
            None,
            id="multiple_tool_calls",
        ),
    ],
)
def test_extract_tool_calls(
    mistral_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction of ``[TOOL_CALLS]name{args}`` output."""
    result = mistral_tool_parser.extract_tool_calls(model_output, request=None)  # type: ignore[arg-type]

    assert result.tools_called
    assert_tool_calls(result.tool_calls, expected_tool_calls)
    assert result.content == expected_content
|
||||
|
||||
|
||||
def _test_extract_tool_calls_streaming(
    tool_parser, tokenizer, model_output, tools, expected_tool_calls, expected_content
):
    """Stream the output through the parser and compare the reassembled result.

    Content deltas are concatenated and compared with ``expected_content``.
    Tool-call deltas are accumulated per tool index (id, name, argument text)
    and the rebuilt calls are checked with ``assert_tool_calls``.
    """
    content_seen: str = ""
    names_seen: list[str] = []
    args_by_index: list[str] = []
    current_index: int = -1
    ids_by_index: list[str | None] = []

    for delta_message in stream_delta_message_generator(
        tool_parser, tokenizer, model_output, tools
    ):
        # role should never be streamed from tool parser
        assert not delta_message.role

        if delta_message.content:
            content_seen += delta_message.content

        deltas = delta_message.tool_calls
        if not deltas:
            continue

        # make sure only one diff is present - correct even for parallel
        assert len(deltas) == 1
        tool_call = deltas[0]

        assert len(tool_parser.prev_tool_call_arr) > 0

        # if a new tool is being called, set up empty arguments
        if tool_call.index != current_index:
            current_index = tool_call.index
            args_by_index.append("")
            ids_by_index.append(None)

        # if a tool call ID is streamed, make sure one hasn't been already
        if tool_call.id and not ids_by_index[tool_call.index]:
            ids_by_index[tool_call.index] = tool_call.id

        # if parts of the function start being streamed
        function = tool_call.function
        if not function:
            continue

        # if the function name is defined, set it. it should be streamed
        # IN ENTIRETY, exactly one time.
        if function.name:
            assert isinstance(function.name, str)
            names_seen.append(function.name)

        if function.arguments:
            # make sure they're a string and then add them to the list
            assert isinstance(function.arguments, str)
            args_by_index[tool_call.index] += function.arguments

    assert content_seen == expected_content

    actual_tool_calls = []
    for call_id, name, args_text in zip(ids_by_index, names_seen, args_by_index):
        arguments = partial_json_parser.ensure_json(args_text, Allow.OBJ | Allow.STR)
        actual_tool_calls.append(
            ToolCall(id=call_id, function=FunctionCall(name=name, arguments=arguments))
        )
    assert_tool_calls(actual_tool_calls, expected_tool_calls)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("model_output", "expected_tool_calls", "expected_content"),
    [
        pytest.param(
            """This is a test""", [], """This is a test""", id="no_tools"
        ),
        pytest.param(
            """[TOOL_CALLS]  [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3, "b": 4})
                    )
                )
            ],
            "",
            id="single_tool_add",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": "3", "b": "4"})
                    )
                )
            ],
            "",
            id="single_tool_add_strings",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
            id="single_tool_weather",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
            id="argument_before_name",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps(
                            {
                                "name": "John Doe",
                            }
                        ),
                    )
                )
            ],
            "",
            id="argument_before_name_and_name_in_argument",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            "",
            id="multiple_tools",
        ),
    ],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer(
    mistral_pre_v11_tool_parser,
    mistral_pre_v11_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Token-by-token streaming of free-text output with the pre-v11 fixtures."""
    _test_extract_tool_calls_streaming(
        tool_parser=mistral_pre_v11_tool_parser,
        tokenizer=mistral_pre_v11_tokenizer,
        model_output=model_output,
        tools=None,
        expected_tool_calls=expected_tool_calls,
        expected_content=expected_content,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("tools", "expected_tool_calls", "expected_content"),
    [
        pytest.param(
            [("add", '{"a": 3, "b": 4}')],
            # [TOOL_CALLS]add{"a": 3, "b": 4}
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3, "b": 4})
                    )
                )
            ],
            "",
            id="single_tool_add",
        ),
        pytest.param(
            [("add_two_strings", '{"a": "3", "b": "4"}')],
            # [TOOL_CALLS]add_two_strings{"a": "3", "b": "4"}
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_two_strings",
                        arguments=json.dumps({"a": "3", "b": "4"}),
                    )
                )
            ],
            "",
            id="single_tool_add_strings",
        ),
        pytest.param(
            [
                ("add", '{"a": 3.5, "b": 4}'),
                (
                    "get_current_weather",
                    '{"city": "San Francisco", "state": "CA", "unit": "celsius"}',  # noqa: E501
                ),
            ],
            # [TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"} # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            "",
            id="multiple_tools",
        ),
    ],
)
def test_extract_tool_calls_streaming(
    mistral_tool_parser,
    mistral_tokenizer,
    tools,
    expected_tool_calls,
    expected_content,
):
    """Token-by-token streaming where the input is given as ``(name, arguments)`` pairs."""
    _test_extract_tool_calls_streaming(
        tool_parser=mistral_tool_parser,
        tokenizer=mistral_tokenizer,
        model_output=None,
        tools=tools,
        expected_tool_calls=expected_tool_calls,
        expected_content=expected_content,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("model_output", "expected_tool_calls", "expected_content"),
    [
        pytest.param(
            """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            "",
            id="single_tool_add",
        ),
        pytest.param(
            """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
            id="single_tool_weather",
        ),
        pytest.param(
            """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="multiply", arguments=json.dumps({"a": 3, "b": 6})
                    )
                ),
            ],
            "",
            id="multiple_tool_calls",
        ),
        pytest.param(
            # Additional content should not be after the tool calls
            """bla[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            "bla",
            id="content_before_tool",
        ),
    ],
)
def test_extract_tool_calls_streaming_one_chunk(
    mistral_tool_parser,
    mistral_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Pass the whole output to the streaming parser as a single delta."""
    # MistralTokenizer.encode is called without add_special_tokens here;
    # any other tokenizer is asked not to add special tokens.
    encode_kwargs = (
        {}
        if isinstance(mistral_tokenizer, MistralTokenizer)
        else {"add_special_tokens": False}
    )
    all_token_ids = fix_tool_call_tokenization(
        mistral_tokenizer.encode(model_output, **encode_kwargs),
        mistral_tool_parser,
        mistral_tokenizer,
    )

    delta_message = mistral_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=model_output,
        delta_text=model_output,
        previous_token_ids=[],
        current_token_ids=all_token_ids,
        delta_token_ids=all_token_ids,
        request=None,  # type: ignore[arg-type]
    )
    assert isinstance(delta_message, DeltaMessage)
    assert len(delta_message.tool_calls) == len(expected_tool_calls)

    assert_tool_calls(delta_message.tool_calls, expected_tool_calls)

    # A missing content field is treated as empty content.
    content = delta_message.content
    actual_content = "" if content is None else content
    assert actual_content == expected_content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("model_output", "expected_tool_calls", "expected_content"),
    [
        pytest.param(
            """This is a test""", [], """This is a test""", id="no_tools"
        ),
        pytest.param(
            """[TOOL_CALLS]  [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3, "b": 4})
                    )
                )
            ],
            "",
            id="single_tool_add",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": "3", "b": "4"})
                    )
                )
            ],
            "",
            id="single_tool_add_strings",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
            id="single_tool_weather",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
            id="argument_before_name",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps(
                            {
                                "name": "John Doe",
                            }
                        ),
                    )
                )
            ],
            "",
            id="argument_before_name_and_name_in_argument",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"arguments": {"a": 3.5, "b": 4}, "name": "add"}, {"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            "",
            id="multiple_tools",
        ),
    ],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk(
    mistral_pre_v11_tool_parser,
    mistral_pre_v11_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Pass the whole output to the pre-v11 streaming parser as a single delta."""
    # MistralTokenizer.encode is called without add_special_tokens here;
    # any other tokenizer is asked not to add special tokens.
    encode_kwargs = (
        {}
        if isinstance(mistral_pre_v11_tokenizer, MistralTokenizer)
        else {"add_special_tokens": False}
    )
    all_token_ids = fix_tool_call_tokenization(
        mistral_pre_v11_tokenizer.encode(model_output, **encode_kwargs),
        mistral_pre_v11_tool_parser,
        mistral_pre_v11_tokenizer,
    )

    delta_message = mistral_pre_v11_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=model_output,
        delta_text=model_output,
        previous_token_ids=[],
        current_token_ids=all_token_ids,
        delta_token_ids=all_token_ids,
        request=None,  # type: ignore[arg-type]
    )
    assert isinstance(delta_message, DeltaMessage)
    assert len(delta_message.tool_calls) == len(expected_tool_calls)

    assert_tool_calls(delta_message.tool_calls, expected_tool_calls)

    # A missing content field is treated as empty content.
    content = delta_message.content
    actual_content = "" if content is None else content
    assert actual_content == expected_content
|
||||
@ -123,7 +123,7 @@ CONFIGS: dict[str, ServerConfig] = {
|
||||
"supports_parallel": True,
|
||||
"extended": True,
|
||||
},
|
||||
"mistral": {
|
||||
"mistral-7b": {
|
||||
"model": "mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"arguments": [
|
||||
"--enforce-eager",
|
||||
@ -145,6 +145,32 @@ CONFIGS: dict[str, ServerConfig] = {
|
||||
"call the tool. Otherwise, answer the user's query directly "
|
||||
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
|
||||
"to the user's question - just respond to it normally.",
|
||||
"supports_parallel": True,
|
||||
},
|
||||
"mistral-small-3.2": {
|
||||
"model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
|
||||
"arguments": [
|
||||
"--enforce-eager",
|
||||
"--no-enable-prefix-caching",
|
||||
"--tool-call-parser",
|
||||
"mistral",
|
||||
"--tokenizer-mode",
|
||||
"mistral",
|
||||
"--config-format",
|
||||
"mistral",
|
||||
"--load-format",
|
||||
"mistral",
|
||||
"--tensor-parallel-size",
|
||||
"4",
|
||||
'--ignore-patterns="consolidated.safetensors"',
|
||||
],
|
||||
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
|
||||
" that you have would be helpful to answer a user query, "
|
||||
"call the tool. Otherwise, answer the user's query directly "
|
||||
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
|
||||
"to the user's question - just respond to it normally.",
|
||||
"supports_parallel": True,
|
||||
"extended": True,
|
||||
},
|
||||
# FIXME: This test currently fails, need to debug why.
|
||||
# "granite20b": {
|
||||
|
||||
@ -11,7 +11,9 @@ PROMPTS = [
|
||||
]
|
||||
|
||||
|
||||
def test_reset_prefix_cache_e2e():
|
||||
def test_reset_prefix_cache_e2e(monkeypatch):
|
||||
# "spawn" is required for test to be deterministic
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
engine_args = EngineArgs(
|
||||
model="Qwen/Qwen3-0.6B",
|
||||
gpu_memory_utilization=0.2,
|
||||
|
||||
@ -9,6 +9,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.hashing import _xxhash
|
||||
|
||||
|
||||
def test_prefix_caching_from_cli():
|
||||
@ -48,6 +49,21 @@ def test_prefix_caching_from_cli():
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(_xxhash is None, reason="xxhash not installed")
def test_prefix_caching_xxhash_from_cli():
    """Both xxhash-based ``--prefix-caching-hash-algo`` values must reach the cache config."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

    # "xxhash" is the pickle-based variant, "xxhash_cbor" the CBOR-based one.
    for algo in ("xxhash", "xxhash_cbor"):
        cli_args = parser.parse_args(["--prefix-caching-hash-algo", algo])
        engine_config = EngineArgs.from_cli_args(args=cli_args).create_engine_config()
        assert engine_config.cache_config.prefix_caching_hash_algo == algo
|
||||
|
||||
|
||||
def test_defaults_with_usage_context():
|
||||
engine_args = EngineArgs(model="facebook/opt-125m")
|
||||
vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS)
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
@ -13,6 +11,7 @@ from transformers import AutoConfig
|
||||
|
||||
from tests.conftest import ImageTestAssets
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.utils.serial_utils import tensor2base64
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
|
||||
@ -50,18 +49,6 @@ async def client_with_image_embeds(server_with_image_embeds):
|
||||
yield async_client
|
||||
|
||||
|
||||
def encode_image_embedding_to_base64(image_embedding) -> str:
    """
    Encode image embedding to base64 string

    The embedding is serialized with ``torch.save`` into an in-memory buffer
    and the resulting bytes are base64-encoded and decoded as UTF-8 text.
    """
    buffer = io.BytesIO()
    torch.save(image_embedding, buffer)
    # getvalue() returns the whole buffer regardless of the stream position,
    # so no explicit rewind is needed before reading it back.
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("dtype", [torch.half, torch.float16, torch.float32])
|
||||
@ -73,7 +60,7 @@ async def test_completions_with_image_embeds(
|
||||
):
|
||||
# Test case: Single image embeds input
|
||||
image_embeds = image_assets[0].image_embeds.to(dtype=dtype)
|
||||
base64_image_embedding = encode_image_embedding_to_base64(image_embeds)
|
||||
base64_image_embedding = tensor2base64(image_embeds)
|
||||
chat_completion = await client_with_image_embeds.chat.completions.create(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
|
||||
@ -3,12 +3,14 @@
|
||||
from dataclasses import asdict
|
||||
from typing import NamedTuple
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from vllm import LLM, EngineArgs, SamplingParams
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.multimodal.utils import encode_image_base64
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8"
|
||||
|
||||
@ -108,6 +110,13 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_rocm(),
|
||||
reason=(
|
||||
"hipErrorLaunchFailure when running this test, see issue:"
|
||||
"https://github.com/ROCm/pytorch/issues/2822"
|
||||
),
|
||||
)
|
||||
def test_shared_storage_connector_hashes(tmp_path):
|
||||
"""
|
||||
Tests that SharedStorageConnector saves KV to the storage locations
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
@ -10,7 +9,6 @@ from tests.utils import create_new_process_for_each_test
|
||||
from tests.v1.logits_processors.utils import (
|
||||
DUMMY_LOGITPROC_ARG,
|
||||
DUMMY_LOGITPROC_FQCN,
|
||||
DUMMY_LOGITPROC_MODULE,
|
||||
MAX_TOKENS,
|
||||
MODEL_NAME,
|
||||
POOLING_MODEL_NAME,
|
||||
@ -18,7 +16,6 @@ from tests.v1.logits_processors.utils import (
|
||||
CustomLogitprocSource,
|
||||
DummyLogitsProcessor,
|
||||
WrappedPerReqLogitsProcessor,
|
||||
dummy_module,
|
||||
prompts,
|
||||
)
|
||||
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
||||
@ -162,8 +159,6 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource
|
||||
kwargs: dict[str, list[str | type[LogitsProcessor]]] = {}
|
||||
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
|
||||
# Scenario: load logitproc based on fully-qualified class name (FQCN)
|
||||
# Inject dummy module which defines logitproc
|
||||
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
|
||||
kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
|
||||
elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
|
||||
# Scenario: load logitproc from provided class object
|
||||
|
||||
@ -14,11 +14,9 @@ from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_te
|
||||
from tests.v1.logits_processors.utils import (
|
||||
DUMMY_LOGITPROC_ARG,
|
||||
DUMMY_LOGITPROC_FQCN,
|
||||
DUMMY_LOGITPROC_MODULE,
|
||||
MAX_TOKENS,
|
||||
MODEL_NAME,
|
||||
TEMP_GREEDY,
|
||||
dummy_module,
|
||||
prompts,
|
||||
)
|
||||
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
||||
@ -47,20 +45,14 @@ def _server_with_logitproc_entrypoint(
|
||||
main.main()
|
||||
|
||||
|
||||
def _server_with_logitproc_module(
|
||||
def _server_with_logitproc_fqcn(
|
||||
env_dict: dict[str, str] | None,
|
||||
model: str,
|
||||
vllm_serve_args: list[str],
|
||||
) -> None:
|
||||
"""Start vLLM server, inject module with dummy logitproc"""
|
||||
|
||||
# Patch `modules` to inject dummy logitproc module
|
||||
from vllm.entrypoints.cli import main
|
||||
|
||||
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
|
||||
|
||||
# fork is required for workers to see entrypoint patch
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork"
|
||||
if env_dict is not None:
|
||||
os.environ.update(env_dict)
|
||||
|
||||
@ -99,7 +91,7 @@ def server(default_server_args, request, monkeypatch):
|
||||
if request.param:
|
||||
# Launch server, append FQCN argument, inject dummy logitproc module
|
||||
args = default_server_args + request.param
|
||||
_server_fxn = _server_with_logitproc_module
|
||||
_server_fxn = _server_with_logitproc_fqcn
|
||||
else:
|
||||
# Launch server, inject dummy logitproc entrypoint
|
||||
args = default_server_args
|
||||
|
||||
@ -27,7 +27,7 @@ DUMMY_LOGITPROC_ARG = "target_token"
|
||||
TEMP_GREEDY = 0.0
|
||||
MAX_TOKENS = 20
|
||||
DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc"
|
||||
DUMMY_LOGITPROC_MODULE = "DummyModule"
|
||||
DUMMY_LOGITPROC_MODULE = "tests.v1.logits_processors.utils"
|
||||
DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor"
|
||||
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ import torch
|
||||
|
||||
from vllm.config import SpeculativeConfig
|
||||
from vllm.model_executor.models.interfaces import supports_eagle3
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -21,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3
|
||||
pytest.param(
|
||||
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
|
||||
id="qwen3-eagle3-speculator-w4a16-verifier",
|
||||
marks=pytest.mark.skipif(
|
||||
current_platform.is_rocm(),
|
||||
reason="The tests are skipped on rocm platform.",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@ -761,6 +761,10 @@ def test_init_kv_cache_with_kv_sharing_valid():
|
||||
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[1] == layer_1
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_rocm(),
|
||||
reason="Attention backend FLASHINFER is not supported on ROCm.",
|
||||
)
|
||||
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
"""
|
||||
The GPU model runner creates different views into the
|
||||
|
||||
@ -283,6 +283,28 @@ def _rocm_aiter_grouped_topk_fake(
|
||||
pass
|
||||
|
||||
|
||||
# Cache whether aiter supports FP8 MLA parameters
|
||||
_AITER_MLA_SUPPORTS_FP8: bool | None = None
|
||||
|
||||
|
||||
def _check_aiter_mla_fp8_support() -> bool:
|
||||
"""Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters."""
|
||||
global _AITER_MLA_SUPPORTS_FP8
|
||||
if _AITER_MLA_SUPPORTS_FP8 is None:
|
||||
try:
|
||||
import inspect
|
||||
|
||||
from aiter.mla import mla_decode_fwd
|
||||
|
||||
sig = inspect.signature(mla_decode_fwd)
|
||||
_AITER_MLA_SUPPORTS_FP8 = (
|
||||
"q_scale" in sig.parameters and "kv_scale" in sig.parameters
|
||||
)
|
||||
except Exception:
|
||||
_AITER_MLA_SUPPORTS_FP8 = False
|
||||
return _AITER_MLA_SUPPORTS_FP8
|
||||
|
||||
|
||||
def _rocm_aiter_mla_decode_fwd_impl(
|
||||
q: torch.Tensor,
|
||||
kv_buffer: torch.Tensor,
|
||||
@ -299,6 +321,16 @@ def _rocm_aiter_mla_decode_fwd_impl(
|
||||
) -> None:
|
||||
from aiter.mla import mla_decode_fwd
|
||||
|
||||
kwargs = {
|
||||
"sm_scale": sm_scale,
|
||||
"logit_cap": logit_cap,
|
||||
}
|
||||
|
||||
# Only pass q_scale and kv_scale if the aiter library supports them
|
||||
if _check_aiter_mla_fp8_support():
|
||||
kwargs["q_scale"] = q_scale
|
||||
kwargs["kv_scale"] = kv_scale
|
||||
|
||||
mla_decode_fwd(
|
||||
q,
|
||||
kv_buffer.view(-1, 1, 1, q.shape[-1]),
|
||||
@ -308,10 +340,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
max_seqlen_qo,
|
||||
sm_scale=sm_scale,
|
||||
logit_cap=logit_cap,
|
||||
q_scale=q_scale,
|
||||
kv_scale=kv_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -104,7 +104,8 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif (
|
||||
at_target
|
||||
hasattr(torch.ops.vllm, "flashinfer_trtllm_fused_allreduce_norm")
|
||||
and at_target
|
||||
== torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
|
||||
):
|
||||
mutated_args = {
|
||||
|
||||
@ -30,7 +30,7 @@ CacheDType = Literal[
|
||||
"fp8_ds_mla",
|
||||
]
|
||||
MambaDType = Literal["auto", "float32"]
|
||||
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
|
||||
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
|
||||
KVOffloadingBackend = Literal["native", "lmcache"]
|
||||
|
||||
|
||||
@ -77,9 +77,21 @@ class CacheConfig:
|
||||
"""Whether to enable prefix caching."""
|
||||
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
|
||||
"""Set the hash algorithm for prefix caching:\n
|
||||
- "sha256" uses Pickle for object serialization before hashing.\n
|
||||
- "sha256" uses Pickle for object serialization before hashing. This is the
|
||||
current default, as SHA256 is the most secure choice to avoid potential
|
||||
hash collisions.\n
|
||||
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
|
||||
serializes objects using canonical CBOR and hashes them with SHA-256."""
|
||||
serializes objects using canonical CBOR and hashes them with SHA-256.\n
|
||||
- "xxhash" uses Pickle serialization with xxHash (128-bit) for faster,
|
||||
non-cryptographic hashing. Requires the optional ``xxhash`` package.
|
||||
IMPORTANT: Use of a hashing algorithm that is not considered
|
||||
cryptographically secure theoretically increases the risk of hash collisions,
|
||||
which can cause undefined behavior or even leak private information in
|
||||
multi-tenant environments. Even if collisions are still very unlikely, it is
|
||||
important to consider your security risk tolerance against the performance
|
||||
benefits before turning this on.\n
|
||||
- "xxhash_cbor" combines canonical CBOR serialization with xxHash for
|
||||
reproducible hashing. Requires the optional ``xxhash`` package."""
|
||||
cpu_offload_gb: float = Field(default=0, ge=0)
|
||||
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
|
||||
no offloading. Intuitively, this argument can be seen as a virtual way to
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
import enum
|
||||
from collections import Counter
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict, field
|
||||
from dataclasses import field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
||||
|
||||
@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
from vllm.config.utils import config, handle_deprecated
|
||||
from vllm.config.utils import config, get_hash_factors, handle_deprecated, hash_factors
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
@ -196,7 +196,16 @@ class PassConfig:
|
||||
Any new fields that affect compilation should be added to the hash.
|
||||
Any future fields that don't affect compilation should be excluded.
|
||||
"""
|
||||
return InductorPass.hash_dict(asdict(self))
|
||||
|
||||
ignored_fields = [
|
||||
"enable_fusion",
|
||||
"enable_attn_fusion",
|
||||
"enable_noop",
|
||||
"enable_sequence_parallelism",
|
||||
"enable_async_tp",
|
||||
"enable_fi_allreduce_fusion",
|
||||
]
|
||||
return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields))
|
||||
|
||||
@field_validator(
|
||||
"fuse_norm_quant",
|
||||
@ -267,14 +276,6 @@ class PassConfig:
|
||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||
)
|
||||
|
||||
# Force old flags to None to ensure they are not used
|
||||
self.enable_fusion = None
|
||||
self.enable_attn_fusion = None
|
||||
self.enable_noop = None
|
||||
self.enable_sequence_parallelism = None
|
||||
self.enable_async_tp = None
|
||||
self.enable_fi_allreduce_fusion = None
|
||||
|
||||
if not self.eliminate_noops:
|
||||
if self.fuse_norm_quant or self.fuse_act_quant:
|
||||
logger.warning_once(
|
||||
|
||||
@ -84,7 +84,7 @@ TaskOption = Literal[
|
||||
"transcription",
|
||||
"draft",
|
||||
]
|
||||
TokenizerMode = Literal["auto", "hf", "slow", "mistral"]
|
||||
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
|
||||
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
|
||||
LogprobsMode = Literal[
|
||||
"raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs"
|
||||
@ -141,6 +141,7 @@ class ModelConfig:
|
||||
- "hf" will use the fast tokenizer if available.\n
|
||||
- "slow" will always use the slow tokenizer.\n
|
||||
- "mistral" will always use the tokenizer from `mistral_common`.\n
|
||||
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
|
||||
- Other custom values can be supported via plugins."""
|
||||
trust_remote_code: bool = False
|
||||
"""Trust remote code (e.g., from HuggingFace) when downloading the model
|
||||
@ -1779,20 +1780,22 @@ class ModelConfig:
|
||||
return False
|
||||
elif attn_type == "decoder":
|
||||
pooling_type = self.pooler_config.pooling_type.lower()
|
||||
if pooling_type in ["all", "mean", "step", "cls"]:
|
||||
if pooling_type in ["mean", "step", "cls"]:
|
||||
logger.debug(
|
||||
"Pooling models with %s pooling does not "
|
||||
"support chunked prefill.",
|
||||
pooling_type,
|
||||
)
|
||||
return False
|
||||
else:
|
||||
# pooling_type == "last"
|
||||
elif pooling_type in ["all", "last"]:
|
||||
logger.debug(
|
||||
"Pooling models with causal attn and last pooling support "
|
||||
"chunked prefill."
|
||||
"Pooling models with causal attn and %s pooling support "
|
||||
"chunked prefill.",
|
||||
pooling_type,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"{pooling_type=} not supported.")
|
||||
# vllm currently does not have pooling models using hybrid,
|
||||
# attention_free or encoder_decoder attn types.
|
||||
return attn_type != "encoder_decoder"
|
||||
@ -1816,20 +1819,22 @@ class ModelConfig:
|
||||
return False
|
||||
elif attn_type == "decoder":
|
||||
pooling_type = self.pooler_config.pooling_type.lower()
|
||||
if pooling_type in ["all", "mean", "step", "cls"]:
|
||||
if pooling_type in ["mean", "step", "cls"]:
|
||||
logger.debug(
|
||||
"Pooling models with %s pooling does not "
|
||||
"support prefix caching.",
|
||||
pooling_type,
|
||||
)
|
||||
return False
|
||||
else:
|
||||
# pooling_type == "last"
|
||||
elif pooling_type in ["all", "last"]:
|
||||
logger.debug(
|
||||
"Pooling models with causal attn and last pooling support "
|
||||
"prefix caching."
|
||||
"Pooling models with causal attn and %s pooling support "
|
||||
"prefix caching.",
|
||||
pooling_type,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"{pooling_type=} not supported.")
|
||||
# vllm currently does not have pooling models using hybrid,
|
||||
# attention_free or encoder_decoder attn types.
|
||||
return False
|
||||
|
||||
@ -593,10 +593,14 @@ class ParallelConfig:
|
||||
"max_parallel_loading_workers is currently "
|
||||
"not supported and will be ignored."
|
||||
)
|
||||
if self.distributed_executor_backend not in ("mp", "uni") and self.nnodes > 1:
|
||||
allowed_backends = ("mp", "uni", "external_launcher")
|
||||
if (
|
||||
self.distributed_executor_backend not in allowed_backends
|
||||
and self.nnodes > 1
|
||||
):
|
||||
raise ValueError(
|
||||
"nnodes > 1 can only be set when distributed executor "
|
||||
"backend is mp or uni."
|
||||
"backend is mp, uni or external_launcher."
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@ -190,3 +190,8 @@ KVConnectorFactory.register_connector(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
|
||||
"DecodeBenchConnector",
|
||||
)
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector",
|
||||
"MooncakeConnector",
|
||||
)
|
||||
|
||||
@ -4,13 +4,14 @@
|
||||
KV cache helper for store.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
@ -21,89 +22,6 @@ if TYPE_CHECKING:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class model_aware_kv_ops_helper:
    """Model-aware helpers for reading from and writing to a KV cache.

    The helper distinguishes DeepSeek MLA models running with the MLA
    optimization enabled (``VLLM_MLA_DISABLE`` unset/0) from every other
    configuration, because the two use different KV-cache layouts.
    """

    def __init__(self, config: VllmConfig):
        self.is_deepseek_mla = config.model_config.is_deepseek_mla
        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
        self.tp_size = config.parallel_config.tensor_parallel_size

    def get_model_args(self, model_executable: torch.nn.Module):
        """Return ``(num_heads, head_size)`` for the given model.

        Also stores ``model_executable`` on the helper as a side effect.
        """
        model_config = model_executable.model.config
        self.model_executable = model_executable
        num_heads = int(model_config.num_key_value_heads / self.tp_size)
        hidden_size = model_config.hidden_size
        num_attention_heads = model_config.num_attention_heads

        # Deepseek's MLA (Multi-head Latent Attention) uses two different
        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
        # kv_lora_rank + qk_rope_head_dim].
        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
        # to a kv_cache shape of [2, num_blks, blk_size,
        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
        # For more details, see vllm/v1/attention/backends/mla/common.py.
        if self.is_deepseek_mla and self.use_mla_opt:
            head_size = model_config.kv_lora_rank + model_config.qk_rope_head_dim
            num_heads = 1
        elif self.is_deepseek_mla and not self.use_mla_opt:
            head_size = model_config.qk_nope_head_dim + model_config.qk_rope_head_dim
        else:
            # Prefer an explicit head_dim; otherwise derive it from the
            # hidden size and the number of attention heads.
            head_size = getattr(model_config, "head_dim", None)
            if head_size is None:
                head_size = int(hidden_size // num_attention_heads)

        return num_heads, head_size

    def get_kv_from_cache(self, kv_cache, num_heads, head_size):
        """Return ``(key_cache, value_cache)`` reshaped to
        ``(-1, num_heads, head_size)``.

        With the MLA optimization both results are reshapes of the same
        ``kv_cache`` tensor; otherwise they come from ``kv_cache[0]`` and
        ``kv_cache[1]`` respectively.
        """
        if self.is_deepseek_mla and self.use_mla_opt:
            key_cache = kv_cache.reshape(-1, num_heads, head_size)
            value_cache = kv_cache.reshape(-1, num_heads, head_size)
        else:
            key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
            value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
        return key_cache, value_cache

    def put_kv_to_cache(
        self,
        model_executable: torch.nn.Module,
        keys,
        values,
        layer,
        kv_cache,
        slot_mapping,
        start_pos,
        end_pos,
    ):
        """Write ``keys``/``values`` into ``kv_cache`` at the slots
        ``slot_mapping[start_pos:end_pos]``.

        In the MLA-optimized path ``values`` is not used and
        ``layer.self_attn.attn`` is rebound to ``layer.self_attn.mla_attn``.
        """
        model_config = model_executable.model.config

        if self.is_deepseek_mla and self.use_mla_opt:
            # NOTE: this mutates the layer so the attribute reads below
            # (kv_cache_dtype, _k_scale) go to the MLA attention module.
            layer.self_attn.attn = layer.self_attn.mla_attn
            k_c_normed_k_pe = keys.squeeze(1)
            # The first kv_lora_rank columns go in as k_c_normed, the
            # remaining columns as k_pe.
            k_c_normed = k_c_normed_k_pe[:, : model_config.kv_lora_rank]
            k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank :]
            ops.concat_and_cache_mla(
                k_c_normed.to(kv_cache.device),
                k_pe.to(kv_cache.device),
                kv_cache,
                slot_mapping[start_pos:end_pos],
                layer.self_attn.attn.kv_cache_dtype,
                layer.self_attn.attn._k_scale,
            )
        else:
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            ops.reshape_and_cache_flash(
                keys.to(key_cache.device),
                values.to(value_cache.device),
                key_cache,
                value_cache,
                slot_mapping[start_pos:end_pos],
                layer.self_attn.attn.kv_cache_dtype,
                layer.self_attn.attn._k_scale,
                layer.self_attn.attn._v_scale,
            )
|
||||
|
||||
|
||||
def get_kv_connector_cache_layout():
|
||||
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
|
||||
# used for faster transfer.
|
||||
@ -266,3 +184,124 @@ def copy_kv_blocks(
|
||||
src_tensor = src_kv_caches[layer_name]
|
||||
dst_tensor = dst_kv_caches[layer_name]
|
||||
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
|
||||
|
||||
|
||||
@dataclass
class TpKVTopology:
    """
    Helper class for tensor parallel and KV topology information for
    mapping between local and remote TP workers.

    ``remote_tp_size`` and ``remote_block_size`` are keyed by engine id. The
    entries stored under ``engine_id`` are what the ``tp_size`` and
    ``block_size`` properties return, i.e. they describe this engine.
    """

    tp_rank: int
    remote_tp_size: dict[str, int]
    is_mla: bool
    total_num_kv_heads: int
    attn_backend: type[AttentionBackend]
    engine_id: str
    remote_block_size: dict[str, int]

    def __post_init__(self):
        # Figure out whether the first dimension of the cache is K/V
        # or num_blocks. This is used to register the memory regions correctly.
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
        )
        # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
        # we just mock num_blocks to 1 for the dimension check below.
        self._is_kv_layout_blocks_first = (
            len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
        )

        attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS

    @property
    def is_kv_layout_blocks_first(self) -> bool:
        """Whether the probed cache shape is 5-D with ``num_blocks`` first."""
        return self._is_kv_layout_blocks_first

    @property
    def split_k_and_v(self) -> bool:
        # Whether to register regions for K and V separately (when present).
        return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first)

    @property
    def tp_size(self) -> int:
        """TP size recorded for ``engine_id``."""
        return self.remote_tp_size[self.engine_id]

    @property
    def block_size(self) -> int:
        """Block size recorded for ``engine_id``."""
        return self.remote_block_size[self.engine_id]

    def tp_ratio(
        self,
        remote_tp_size: int,
    ) -> int:
        """
        Calculate the tensor parallel ratio between local and remote TP.
        We can think of it as the number of local TP workers-per-remote TP
        workers. Local workers will read from the same remote TP worker in
        groups of size `tp_ratio`.

        Fails an assertion when the local TP size is not a multiple of
        ``remote_tp_size``.
        """
        assert self.tp_size % remote_tp_size == 0, (
            f"Local tensor parallel size {self.tp_size} is not divisible "
            f"by remote tensor parallel size {remote_tp_size}."
        )
        return self.tp_size // remote_tp_size

    def block_size_ratio(
        self,
        remote_block_size: int,
    ) -> float:
        """
        Calculate the block size ratio between local and remote TP.

        Only the case where the local block size is a multiple of
        ``remote_block_size`` is accepted; the assertion fails otherwise
        (including when the remote block size is the larger one). The value
        returned is the integer quotient local // remote.
        """
        assert self.block_size % remote_block_size == 0, (
            f"Local block size {self.block_size} is not divisible "
            f"by remote block size {remote_block_size} or vice versa."
        )
        return self.block_size // remote_block_size

    def tp_ratio_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> int:
        """``tp_ratio`` using the TP size recorded for ``remote_engine_id``."""
        remote_tp_size = self.remote_tp_size[remote_engine_id]
        return self.tp_ratio(remote_tp_size)

    def block_size_ratio_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> float:
        """``block_size_ratio`` using the block size recorded for
        ``remote_engine_id``."""
        remote_block_size = self.remote_block_size[remote_engine_id]
        return self.block_size_ratio(remote_block_size)

    def is_kv_replicated(self, engine_id: str) -> bool:
        """
        Whether the KV cache is treated as replicated across the TP workers
        of ``engine_id``.

        True when the integer quotient of that engine's TP size by
        ``total_num_kv_heads`` is at least 1, i.e. when it has at least as
        many TP workers as there are KV heads.
        """
        tp_size = self.remote_tp_size[engine_id]
        return tp_size // self.total_num_kv_heads >= 1

    def replicates_kv_cache(self, remote_engine_id: str) -> bool:
        # MLA is always replicated as the hidden dim can't be split.
        return self.is_mla or self.is_kv_replicated(remote_engine_id)

    def get_target_remote_rank(
        self,
        remote_tp_size: int,
    ) -> int:
        """
        Get the remote TP rank (on P) that the current local TP rank
        (on D) will read from.
        """
        tp_ratio = self.tp_ratio(remote_tp_size)
        return self.tp_rank // tp_ratio

    def get_target_remote_rank_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> int:
        """``get_target_remote_rank`` using the TP size recorded for
        ``remote_engine_id``."""
        remote_tp_size = self.remote_tp_size[remote_engine_id]
        return self.get_target_remote_rank(remote_tp_size)
|
||||
|
||||
@ -0,0 +1,914 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import msgspec
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
|
||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
try:
|
||||
from mooncake.engine import TransferEngine
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||
"to run VLLM with MooncakeTransferEngine."
|
||||
) from e
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
EngineId = str
|
||||
ReqId = str
|
||||
|
||||
TRANS_DONE = b"trans_done"
|
||||
TRANS_ERROR = b"trans_error"
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MooncakeAgentMetadata(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):
    """msgspec struct describing a Mooncake agent and a batch of requests.

    NOTE(review): presumably the message exchanged between Mooncake connector
    workers to set up a KV transfer - confirm against the sender/receiver code.
    """

    # Host name and port of the remote side.
    remote_hostname: str
    remote_port: int
    # Request ids covered by this message.
    request_ids: list[ReqId]
    # Base addresses of the KV caches.
    kv_caches_base_addr: list[int]
    # One list of block ids per entry; presumably parallel to request_ids
    # (TODO confirm).
    block_ids: list[list[int]]
|
||||
|
||||
|
||||
@dataclass
class RecvReqMeta:
    """Per-request data for a request whose KV cache is loaded remotely."""

    # Local block ids associated with the request.
    local_block_ids: list[int]
    # Address of the remote side, taken from the request's kv_transfer_params.
    remote_host: str
    remote_port: int
|
||||
|
||||
|
||||
@dataclass
class SendBlockMeta:
    """Per-request data for blocks that are to be sent."""

    # Local block ids associated with the request.
    local_block_ids: list[int]
    # Event object; presumably set once the blocks may be sent (TODO confirm).
    ready: threading.Event
    # Defaults to +inf, i.e. no expiry unless a finite time is assigned.
    expire_time: float = float("inf")
|
||||
|
||||
|
||||
@dataclass
class SendReqMeta:
    """Mapping of request id to its ``SendBlockMeta``, paired with a lock."""

    reqs: dict[ReqId, SendBlockMeta]
    # Presumably guards access to ``reqs`` across threads (TODO confirm).
    lock: threading.Lock
|
||||
|
||||
|
||||
@dataclass
class FinishedSendReqSet:
    """Set of request ids paired with a ``threading.Lock``."""

    # NOTE: the field name shadows the builtin ``set`` inside the class body;
    # it is kept unchanged because callers access it by this name.
    set: set[ReqId]
    # Presumably guards access to ``set`` across threads (TODO confirm).
    lock: threading.Lock
|
||||
|
||||
|
||||
@dataclass
class FinishedReceiveReqSet:
    """Set of request ids paired with an ``asyncio.Lock``."""

    # NOTE: the field name shadows the builtin ``set`` inside the class body;
    # it is kept unchanged because callers access it by this name.
    set: set[ReqId]
    # Presumably guards access to ``set`` between coroutines (TODO confirm).
    lock: asyncio.Lock
|
||||
|
||||
|
||||
class MooncakeConnectorMetadata(KVConnectorMetadata):
    """Connector metadata holding the requests to receive and to send."""

    def __init__(self):
        # Requests whose KV cache is loaded from a remote host.
        self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {}
        # Requests to send, mapped to their local block ids.
        self.reqs_to_send: dict[ReqId, list[int]] = {}

    def add_new_req(
        self,
        request_id: ReqId,
        local_block_ids: list[int],
        kv_transfer_params: dict[str, Any],
        load_remote_cache: bool = True,
    ):
        """Record a request as either a receive or a send.

        Args:
            request_id: Id the request is stored under.
            local_block_ids: Local block ids for the request.
            kv_transfer_params: Read only when ``load_remote_cache`` is true;
                it must then contain ``"remote_host"`` and ``"remote_port"``.
            load_remote_cache: If true the request goes into ``reqs_to_recv``,
                otherwise into ``reqs_to_send``.

        Raises:
            KeyError: If ``load_remote_cache`` is true and one of the two
                required keys is missing from ``kv_transfer_params``.
        """
        if not load_remote_cache:
            self.reqs_to_send[request_id] = local_block_ids
            return

        self.reqs_to_recv[request_id] = RecvReqMeta(
            local_block_ids=local_block_ids,
            remote_host=kv_transfer_params["remote_host"],
            remote_port=kv_transfer_params["remote_port"],
        )
|
||||
|
||||
|
||||
class MooncakeConnector(KVConnectorBase_V1):
    """KV connector facade that forwards every call to a role-specific helper.

    Depending on ``role`` the instance owns either a
    ``MooncakeConnectorScheduler`` or a ``MooncakeConnectorWorker``. Each
    public method asserts that the helper it needs exists and delegates to it.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: Optional["KVCacheConfig"] = None,
    ):
        """Create the helper matching ``role``.

        Args:
            vllm_config: Engine configuration; ``kv_transfer_config`` and its
                ``engine_id`` must be set.
            role: ``KVConnectorRole.SCHEDULER`` or ``KVConnectorRole.WORKER``.
            kv_cache_config: Forwarded unchanged to the base class.
        """
        super().__init__(vllm_config, role, kv_cache_config)

        assert vllm_config.kv_transfer_config is not None
        assert vllm_config.kv_transfer_config.engine_id is not None
        self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id

        # Define both helpers up front so that an unexpected role leaves them
        # as None (tripping the asserts in the methods below) rather than
        # undefined, which would surface later as an AttributeError.
        self.connector_scheduler: MooncakeConnectorScheduler | None = None
        self.connector_worker: MooncakeConnectorWorker | None = None

        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = MooncakeConnectorScheduler(
                vllm_config, self.engine_id
            )
        elif role == KVConnectorRole.WORKER:
            self.connector_worker = MooncakeConnectorWorker(
                vllm_config, self.engine_id
            )

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int, bool]:
        """Delegate to the scheduler helper; requires the scheduler role."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens
        )

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """Delegate to the scheduler helper; requires the scheduler role."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens
        )

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Delegate to the scheduler helper; requires the scheduler role."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Delegate to the scheduler helper; requires the scheduler role."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Delegate to the worker helper; requires the worker role."""
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """Return the worker helper's ``get_finished()`` result.

        ``finished_req_ids`` is accepted for interface compatibility but is
        not forwarded to the worker.
        """
        assert self.connector_worker is not None
        return self.connector_worker.get_finished()

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        """Hand the current connector metadata to the worker helper.

        ``forward_context`` and ``kwargs`` are not used here.
        """
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
        self.connector_worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """No-op: MooncakeConnector does no per-layer work."""
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> None:
        """No-op: MooncakeConnector does not save explicitly."""
        pass

    def wait_for_save(self):
        """No-op: there is no explicit save to wait for."""
        pass
|
||||
|
||||
|
||||
class MooncakeConnectorScheduler:
    """Scheduler-side bookkeeping for the Mooncake connector.

    Collects the requests that need a KV receive or send and hands them to
    the worker side through ``build_connector_meta``.
    """

    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        """Store the configuration and initialise the pending-request maps."""
        self.vllm_config = vllm_config
        self.engine_id: EngineId = engine_id
        self.side_channel_host = get_ip()
        self.side_channel_port = get_mooncake_side_channel_port(vllm_config)

        assert vllm_config.kv_transfer_config
        self.kv_role = vllm_config.kv_transfer_config.kv_role
        logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)

        # Pending work, filled by update_state_after_alloc() and
        # request_finished() and drained by build_connector_meta().
        self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
        self._reqs_need_send: dict[ReqId, list[int]] = {}

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int, bool]:
        """Report how many prompt tokens should be loaded from remote.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            ``(count, True)`` when the request asks for remote prefill and
            ``count = len(prompt_token_ids) - num_computed_tokens`` is
            positive; ``(0, False)`` in every other case.
        """
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeConnector get_num_new_matched_tokens: "
            "num_computed_tokens=%s, kv_transfer_params=%s",
            num_computed_tokens,
            params,
        )

        # No remote prefill for this request.
        if params is None or not params.get("do_remote_prefill"):
            return 0, False

        # Remote prefill: everything not computed locally comes from remote.
        prompt_len = len(request.prompt_token_ids or [])
        remaining = prompt_len - num_computed_tokens
        if remaining > 0:
            return remaining, True
        return 0, False

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """Queue the request for a receive or a send after block allocation.

        A remote-prefill request is queued for receive (and its
        ``do_remote_prefill`` flag cleared); a remote-decode request is
        queued for send with an empty block list.
        """
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeConnector update_state_after_alloc: "
            "num_external_tokens=%s, kv_transfer_params=%s",
            num_external_tokens,
            params,
        )

        if not params:
            return

        if params.get("do_remote_prefill"):
            assert self.kv_role != "kv_producer"
            has_remote_endpoint = "remote_host" in params and "remote_port" in params
            if has_remote_endpoint:
                # With no external tokens the request is still queued, but
                # with an empty block list; otherwise the unhashed blocks are
                # the ones to pull from remote.
                if num_external_tokens > 0:
                    local_block_ids = blocks.get_unhashed_block_ids()
                else:
                    local_block_ids = []
                self._reqs_need_recv[request.request_id] = (request, local_block_ids)
            else:
                logger.warning(
                    "Got invalid KVTransferParams: %s. This "
                    "request will not utilize KVTransfer",
                    params,
                )
            # Only trigger 1 KV transfer per request.
            params["do_remote_prefill"] = False
            return

        if params.get("do_remote_decode"):
            # Add an empty list to worker to create event.
            self._reqs_need_send[request.request_id] = []

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Move all pending receives and sends into a fresh metadata object.

        Pending receives are drained unless the role is ``"kv_producer"``;
        pending sends are drained unless the role is ``"kv_consumer"``.
        ``scheduler_output`` is not read.
        """
        meta = MooncakeConnectorMetadata()

        if self.kv_role != "kv_producer":
            for recv_id, (recv_req, recv_blocks) in self._reqs_need_recv.items():
                assert recv_req.kv_transfer_params is not None
                meta.add_new_req(
                    request_id=recv_id,
                    local_block_ids=recv_blocks,
                    kv_transfer_params=recv_req.kv_transfer_params,
                )
            self._reqs_need_recv.clear()

        if self.kv_role != "kv_consumer":
            for send_id, send_blocks in self._reqs_need_send.items():
                meta.add_new_req(
                    request_id=send_id,
                    local_block_ids=send_blocks,
                    kv_transfer_params={},
                    load_remote_cache=False,
                )
            self._reqs_need_send.clear()

        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Decide whether a finished request's blocks must be kept for sending.

        Returns:
            ``(delay_free_blocks, kv_transfer_params)``. The params are
            ``None`` unless the request asked for remote decode and finished
            with ``FINISHED_LENGTH_CAPPED``; in that case they describe this
            instance's side channel and ``delay_free_blocks`` is true iff
            ``block_ids`` is non-empty.
        """
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeConnector request_finished, request_status=%s, "
            "kv_transfer_params=%s",
            request.status,
            params,
        )
        if not params:
            return False, None

        if params.get("do_remote_prefill"):
            # The flag is cleared by update_state_after_alloc(), so seeing it
            # here means that method never ran for this request. Queue an
            # empty receive so the worker side is still told about it.
            assert self.kv_role != "kv_producer"
            self._reqs_need_recv[request.request_id] = (request, [])
            params["do_remote_prefill"] = False
            return False, None

        if not params.get("do_remote_decode"):
            return False, None
        if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
            return False, None

        assert self.kv_role != "kv_consumer"

        # TODO: check whether block_ids actually ever be 0. If not we could
        # remove the conditional below
        delay_free_blocks = len(block_ids) > 0
        if delay_free_blocks:
            self._reqs_need_send[request.request_id] = block_ids

        remote_prefill_params = {
            "do_remote_prefill": True,
            "do_remote_decode": False,
            "remote_host": self.side_channel_host,
            "remote_port": self.side_channel_port,
        }
        return delay_free_blocks, remote_prefill_params
|
||||
|
||||
|
||||
class MooncakeConnectorWorker:
    """Worker-side implementation of the Mooncake connector.

    Owns the Mooncake ``TransferEngine``, registers the KV cache memory with
    it, serves send requests from a background thread plus a thread pool
    (unless the role is ``"kv_consumer"``), and issues receive requests from
    a dedicated asyncio loop thread (unless the role is ``"kv_producer"``).
    """

    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        """Initialise the transfer engine and the role-specific machinery.

        Raises:
            RuntimeError: If the Mooncake transfer engine fails to initialise.
        """
        logger.info("Initializing Mooncake Transfer Engine worker %s", engine_id)

        self.vllm_config = vllm_config

        self.engine = TransferEngine()
        self.hostname = get_ip()
        ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", "rdma", "")
        if ret_value != 0:
            raise RuntimeError("Mooncake Transfer Engine initialization failed.")

        self.rpc_port = self.engine.get_rpc_port()

        logger.debug(
            "Mooncake Transfer Engine initialized at %s:%d",
            self.hostname,
            self.rpc_port,
        )

        # Mooncake handshake port.
        self.side_channel_port: int = get_mooncake_side_channel_port(vllm_config)

        self.engine_id: EngineId = engine_id
        self.tp_rank = get_tensor_model_parallel_rank()
        self.world_size = get_tensor_model_parallel_world_size()
        self.tp_group = get_tp_group()
        self.num_blocks = 0

        assert vllm_config.kv_transfer_config
        self.kv_role = vllm_config.kv_transfer_config.kv_role
        self.num_workers = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
            "num_workers", 10
        )

        self.kv_caches_base_addr: list[int] = []
        self.device_kv_caches: dict[str, torch.Tensor] = {}
        self.reqs_need_send: SendReqMeta = SendReqMeta(reqs={}, lock=threading.Lock())

        # For kv_both, we will act both prefiller and decoder.
        if self.kv_role != "kv_consumer":
            # Background thread for sending kvcaches to D; started later by
            # register_kv_caches().
            self._mooncake_sender_t: threading.Thread | None = None
            # Thread pool that processes the individual send requests.
            self._sender_executor = ThreadPoolExecutor(
                max_workers=self.num_workers, thread_name_prefix="vllm-mooncake-sender"
            )
            logger.debug(
                "Mooncake Prefiller: use %d workers to send kvcaches", self.num_workers
            )
        if self.kv_role != "kv_producer":
            self.receiver_loop = asyncio.new_event_loop()
            self._mooncake_receiver_t = threading.Thread(
                target=self._receiver_loop, args=(self.receiver_loop,), daemon=True
            )
            self._mooncake_receiver_t.start()
            logger.debug("Mooncake Decoder: start receiver thread")

        self.finished_sending_reqs: FinishedSendReqSet = FinishedSendReqSet(
            set(), threading.Lock()
        )
        self.finished_recving_reqs: FinishedReceiveReqSet = FinishedReceiveReqSet(
            set(), asyncio.Lock()
        )

        self.block_size = vllm_config.cache_config.block_size
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.use_mla = self.model_config.use_mla

        backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            use_mla=self.use_mla,
        )
        self.backend_name = backend.get_name()
        self.kv_cache_layout = get_kv_cache_layout()
        logger.debug("Detected attention backend %s", self.backend_name)
        logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

        self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
        self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
        self.kv_topo = TpKVTopology(
            tp_rank=self.tp_rank,
            engine_id=self.engine_id,
            remote_tp_size=self._tp_size,  # shared state
            remote_block_size=self._block_size,  # shared state
            is_mla=self.use_mla,
            total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
            attn_backend=backend,
        )
        self._use_pallas = self.kv_topo._use_pallas

        self.zmq_ctx = zmq.Context()
        self.async_zmq_ctx = zmq.asyncio.Context()
        self._encoder = msgspec.msgpack.Encoder()
        self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)

    def __del__(self):
        self.shutdown()

    def shutdown(self):
        """Cleanup background threads on destruction.

        Attributes are looked up defensively because this also runs from
        ``__del__`` on an instance whose ``__init__`` raised part-way (for
        example when engine initialisation failed), in which case some of
        them were never assigned.
        """
        zmq_ctx = getattr(self, "zmq_ctx", None)
        if zmq_ctx is not None:
            zmq_ctx.term()
        async_zmq_ctx = getattr(self, "async_zmq_ctx", None)
        if async_zmq_ctx is not None:
            async_zmq_ctx.term()

        # The sender attributes only exist when the role is not kv_consumer.
        sender_executor = getattr(self, "_sender_executor", None)
        if sender_executor is not None:
            sender_executor.shutdown(wait=False)
        sender_thread = getattr(self, "_mooncake_sender_t", None)
        if sender_thread:
            sender_thread.join()

        # The receiver attributes only exist when the role is not kv_producer.
        receiver_loop = getattr(self, "receiver_loop", None)
        if receiver_loop is not None and receiver_loop.is_running():
            receiver_loop.call_soon_threadsafe(receiver_loop.stop)
            self._mooncake_receiver_t.join()

    def _receiver_loop(self, loop: asyncio.AbstractEventLoop):
        """Thread target: install ``loop`` as the current loop and run it."""
        asyncio.set_event_loop(loop)
        loop.run_forever()

    def _mooncake_sender(
        self, ready_event: threading.Event, base_port: int, tp_rank: int
    ):
        """
        Background thread that listens for Mooncake requests, dispatches them
        to a thread pool, and sends acknowledgments upon completion.

        Args:
            ready_event: Set once both sockets exist and are registered.
            base_port: The frontend listens on ``base_port + tp_rank``.
            tp_rank: Offset added to ``base_port``.
        """

        frontend_path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
        frontend = make_zmq_socket(self.zmq_ctx, frontend_path, zmq.ROUTER)
        logger.debug("Mooncake sender starting listening on path: %s", frontend_path)

        # Workers report their result back through this in-process channel.
        backend_path = make_zmq_path("inproc", str(uuid.uuid4()))
        backend = make_zmq_socket(self.zmq_ctx, backend_path, zmq.PULL)

        poller = zmq.Poller()
        poller.register(frontend, zmq.POLLIN)
        poller.register(backend, zmq.POLLIN)

        ready_event.set()

        try:
            while True:
                sockets = dict(poller.poll())

                if frontend in sockets:
                    identity, _, metadata_bytes = frontend.recv_multipart()
                    self._sender_executor.submit(
                        self._sender_worker,
                        identity,
                        metadata_bytes,
                        backend_path,
                    )

                if backend in sockets:
                    # Relay the worker's status to the peer that asked.
                    identity, status = backend.recv_multipart()
                    frontend.send_multipart((identity, b"", status))

        except zmq.ContextTerminated:
            logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
        except Exception as e:
            logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e))
        finally:
            frontend.close()
            backend.close()

    def _sender_worker(
        self, identity: bytes, metadata_bytes: bytes, worker_channel_path: str
    ):
        """Serve one send request and report its status.

        Decodes ``metadata_bytes``, runs ``send_kv_to_decode`` and pushes
        ``(identity, status)`` to ``worker_channel_path``. The status is
        ``TRANS_DONE`` only if both steps complete without raising, otherwise
        ``TRANS_ERROR``.
        """
        status = TRANS_ERROR

        try:
            metadata = self._decoder.decode(metadata_bytes)
            self.send_kv_to_decode(metadata)
            status = TRANS_DONE
        except Exception as e:
            logger.error("Error processing Mooncake handshake: %s", e)
        finally:
            pusher = make_zmq_socket(self.zmq_ctx, worker_channel_path, zmq.PUSH)
            try:
                pusher.send_multipart((identity, status))
            except zmq.ZMQError as e:
                logger.warning(
                    "Internal error, maybe the server is shutting down. Error: %s",
                    e,
                )
            finally:
                pusher.close()

    def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
        """Send the KV blocks for every request named in ``meta``.

        On success the requests are removed from ``reqs_need_send`` and added
        to ``finished_sending_reqs``.

        Raises:
            RuntimeError: If a request in ``meta.request_ids`` is not tracked
                in ``reqs_need_send`` (nothing is sent in that case), or if
                the transfer itself fails.
        """
        send_reqs: list[tuple[ReqId, SendBlockMeta]] = []
        with self.reqs_need_send.lock:
            # Resolve every request before touching any state, so that a
            # missing request cannot leave earlier ones marked as in flight.
            for req_id in meta.request_ids:
                send_meta = self.reqs_need_send.reqs.get(req_id)
                if send_meta is None:
                    logger.warning("Request %s not found in reqs_need_send", req_id)
                    # Raise instead of returning: a plain return would make
                    # the caller acknowledge a transfer that never happened.
                    raise RuntimeError(
                        f"Request {req_id} not found in reqs_need_send"
                    )
                send_reqs.append((req_id, send_meta))
            # Mark them as not expired. We will send them now.
            for _, send_meta in send_reqs:
                send_meta.expire_time = float("inf")

        try:
            self._send_blocks(send_reqs, meta)
        except Exception:
            # Give the requests a finite deadline again; with an infinite
            # expire_time the timeout sweep in get_finished() would never
            # remove them.
            deadline = time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
            with self.reqs_need_send.lock:
                for _, send_meta in send_reqs:
                    send_meta.expire_time = deadline
            raise

        with self.reqs_need_send.lock:
            for req_id in meta.request_ids:
                del self.reqs_need_send.reqs[req_id]

        with self.finished_sending_reqs.lock:
            self.finished_sending_reqs.set.update(meta.request_ids)

    def _send_blocks(
        self,
        send_reqs: list[tuple[ReqId, SendBlockMeta]],
        agent_meta: MooncakeAgentMetadata,
    ):
        """Write the local blocks of ``send_reqs`` to the remote addresses.

        Builds one (source, destination, length) triple per contiguous block
        run and per registered cache region, then issues a single batched
        synchronous write.

        Raises:
            RuntimeError: If ``batch_transfer_sync_write`` returns non-zero.
        """
        src_ptrs = []
        dst_ptrs = []
        lengths = []
        local_base_addr = self.kv_caches_base_addr
        remote_base_addr = agent_meta.kv_caches_base_addr
        block_len = self.block_len
        remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}"

        assert len(send_reqs) == len(agent_meta.block_ids)
        for (req_id, send_meta), remote_block_ids in zip(
            send_reqs, agent_meta.block_ids
        ):
            # Block until start_load_kv() has filled in the local block ids.
            send_meta.ready.wait()

            num_remote_blocks = len(remote_block_ids)
            if num_remote_blocks == 0:
                continue

            local_block_ids = send_meta.local_block_ids
            # Partial prefix cache hit: just read uncomputed blocks.
            num_local_blocks = len(local_block_ids)
            assert num_local_blocks >= num_remote_blocks
            if num_local_blocks > num_remote_blocks:
                local_block_ids = local_block_ids[-num_remote_blocks:]

            # Group by indices
            group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous(
                local_block_ids, remote_block_ids
            )

            for local_layer_addr, remote_layer_addr in zip(
                local_base_addr, remote_base_addr
            ):
                for group_local_block_id, group_remote_block_id in zip(
                    group_local_block_ids, group_remote_block_ids
                ):
                    # One transfer per contiguous run: start address of the
                    # first block and the byte length of the whole run.
                    src_ptrs.append(
                        local_layer_addr + group_local_block_id[0] * block_len
                    )
                    dst_ptrs.append(
                        remote_layer_addr + group_remote_block_id[0] * block_len
                    )
                    lengths.append(block_len * len(group_local_block_id))

            logger.debug(
                "Sending kv_caches for request %s (%d blocks) to %s",
                req_id,
                num_remote_blocks,
                remote_session,
            )

        start_time = time.perf_counter()
        ret_value = self.engine.batch_transfer_sync_write(
            remote_session, src_ptrs, dst_ptrs, lengths
        )
        if ret_value != 0:
            raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")

        logger.debug(
            "Sending to %s done, took %s",
            remote_session,
            time.perf_counter() - start_time,
        )

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register the KV Cache data in mooncake.

        Records the base address of every distinct cache tensor, registers
        the regions with the transfer engine, derives ``block_len`` and,
        unless the role is ``"kv_consumer"``, starts the sender thread.

        Raises:
            RuntimeError: If the batch memory registration fails.
        """

        logger.info("Registering KV_Caches. use_mla: %s", self.use_mla)

        kv_data_ptrs = []
        kv_data_lens = []
        seen_base_addresses = []

        split_k_and_v = self.kv_topo.split_k_and_v
        tensor_size_bytes = None
        for layer_name, cache_or_caches in kv_caches.items():
            logger.debug(
                "registering layer %s with shape %s", layer_name, cache_or_caches.shape
            )
            cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]

            for cache in cache_list:
                base_addr = cache.data_ptr()
                # Skip tensors whose memory was already recorded.
                if base_addr in seen_base_addresses:
                    continue

                seen_base_addresses.append(base_addr)
                curr_tensor_size_bytes = cache.nbytes

                if tensor_size_bytes is None:
                    # The first tensor fixes the expected size and block count.
                    tensor_size_bytes = curr_tensor_size_bytes
                    self.num_blocks = cache.shape[0]

                assert tensor_size_bytes == curr_tensor_size_bytes, (
                    "All kv cache tensors must have the same size"
                )
                kernel_block_size = cache.shape[-2 if self.use_mla else -3]
                assert self.block_size == kernel_block_size
                kv_data_ptrs.append(base_addr)
                kv_data_lens.append(tensor_size_bytes)

        self.kv_caches_base_addr = seen_base_addresses

        ret_value = self.engine.batch_register_memory(kv_data_ptrs, kv_data_lens)
        if ret_value != 0:
            raise RuntimeError("Mooncake batch memory registration failed.")

        assert tensor_size_bytes is not None
        assert self.num_blocks != 0
        assert tensor_size_bytes % self.num_blocks == 0
        self.block_len = tensor_size_bytes // self.num_blocks
        self.device_kv_caches = kv_caches
        logger.debug(
            "registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len
        )

        # No need to launch server for D node.
        if self.kv_role == "kv_consumer":
            return

        ready_event = threading.Event()
        self._mooncake_sender_t = threading.Thread(
            target=self._mooncake_sender,
            args=(ready_event, self.side_channel_port, self.tp_rank),
            daemon=True,
            name="mooncake_sender",
        )
        self._mooncake_sender_t.start()
        ready_event.wait()  # Wait for listener ZMQ socket to be ready.

    async def fetch_finished_recving_reqs(self) -> set[ReqId]:
        """Return the finished-receive set and replace it with an empty one."""
        async with self.finished_recving_reqs.lock:
            finished_recving_reqs = self.finished_recving_reqs.set
            self.finished_recving_reqs.set = set()
        return finished_recving_reqs

    def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
        """
        Get requests that are done sending or recving on this specific worker.
        The scheduler process (via the MultiprocExecutor) will use this output
        to track which workers are done.

        Requests in ``reqs_need_send`` whose ``expire_time`` has passed are
        removed and reported as finished sending as well.

        Returns:
            ``(finished_sending, finished_recving)``; an empty set is
            reported as ``None``.
        """
        fut = None
        if self.kv_role != "kv_producer":
            # The receive set is guarded by an asyncio lock, so it is
            # collected on the receiver loop.
            fut = asyncio.run_coroutine_threadsafe(
                self.fetch_finished_recving_reqs(), self.receiver_loop
            )

        if self.kv_role != "kv_consumer":
            with self.finished_sending_reqs.lock:
                finished_sending_reqs = self.finished_sending_reqs.set
                self.finished_sending_reqs.set = set()
        else:
            finished_sending_reqs = set()

        finished_recving_reqs = fut.result() if fut else set()

        if finished_sending_reqs or finished_recving_reqs:
            logger.debug(
                "Rank %s, get_finished: %s requests done sending "
                "and %s requests done recving",
                self.tp_rank,
                len(finished_sending_reqs),
                len(finished_recving_reqs),
            )

        # Handle timeout to avoid stranding blocks on remote.
        now = time.perf_counter()
        with self.reqs_need_send.lock:
            expired_reqs = [
                req_id
                for req_id, send_meta in self.reqs_need_send.reqs.items()
                if send_meta.expire_time < now
            ]
            for req_id in expired_reqs:
                logger.warning(
                    "Request %s timed out after %d seconds without "
                    "being sent. Freeing its blocks on the producer side.",
                    req_id,
                    envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
                )
                del self.reqs_need_send.reqs[req_id]
        if expired_reqs:
            finished_sending_reqs.update(expired_reqs)

        return finished_sending_reqs or None, finished_recving_reqs or None

    async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]):
        """Ask the peer at ``path`` to send the KV blocks for ``req_blocks``.

        Sends this worker's address information and the per-request block ids
        over a REQ socket and waits for the reply. The requests are added to
        ``finished_recving_reqs`` only when the reply is ``TRANS_DONE``;
        every failure path logs and returns without marking them.

        Args:
            path: ZMQ path of the peer.
            req_blocks: Non-empty list of ``(request_id, local_block_ids)``.
        """
        req_ids, block_ids = map(list, zip(*req_blocks))
        metadata = MooncakeAgentMetadata(
            remote_hostname=self.hostname,
            remote_port=self.rpc_port,
            request_ids=req_ids,
            kv_caches_base_addr=self.kv_caches_base_addr,
            block_ids=block_ids,
        )

        encoded_data = self._encoder.encode(metadata)
        logger.debug(
            "Size of encoded MooncakeAgentMetadata: %d bytes", len(encoded_data)
        )
        logger.debug("Sending kv transfer request for %s on path: %s", req_ids, path)

        # Send query for the request.
        sock: zmq.asyncio.Socket = make_zmq_socket(
            self.async_zmq_ctx, path, zmq.REQ, bind=False, linger=0
        )
        sock.setsockopt(zmq.RCVTIMEO, 60000)
        try:
            await sock.send(encoded_data)
            ret_msg = await sock.recv()
            if ret_msg != TRANS_DONE:
                logger.error(
                    "Error happens during tranfering kvcache for %s, see logs in prefiller.",  # noqa: E501
                    req_ids,
                )
                return
        except zmq.ContextTerminated:
            logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.")
            # No reply was received, so the requests must not be reported as
            # finished below.
            return
        except Exception as e:
            logger.error("MooncakeAgentMetadata transfer failed for %s: %s", req_ids, e)
            return
        finally:
            sock.close()

        async with self.finished_recving_reqs.lock:
            self.finished_recving_reqs.set.update(req_ids)

        logger.debug("pulling kv_caches for %s finished", req_ids)

    def group_kv_pull(self, metadata: MooncakeConnectorMetadata):
        """Group the pending receives by the ZMQ path of their remote peer.

        Returns:
            Mapping from path (built from ``remote_host`` and
            ``remote_port + tp_rank``) to a list of
            ``(request_id, local_block_ids)``.
        """
        kv_pulls = defaultdict(list)
        for req_id, meta in metadata.reqs_to_recv.items():
            logger.debug(
                "start_load_kv for request %s from remote engine. "
                "Num local_block_ids: %s.",
                req_id,
                len(meta.local_block_ids),
            )
            path = make_zmq_path(
                "tcp", meta.remote_host, meta.remote_port + self.tp_rank
            )
            kv_pulls[path].append((req_id, meta.local_block_ids))

        return kv_pulls

    def start_load_kv(self, metadata: MooncakeConnectorMetadata):
        """Kick off receives and update send bookkeeping from ``metadata``.

        Unless the role is ``"kv_producer"``, one ``receive_kv`` coroutine
        per remote path is scheduled on the receiver loop. Unless the role is
        ``"kv_consumer"``, each entry of ``reqs_to_send`` either creates a
        placeholder (empty block list) or completes an existing one.
        """
        if self.kv_role != "kv_producer":
            kv_pulls = self.group_kv_pull(metadata)
            for path, req_blocks in kv_pulls.items():
                asyncio.run_coroutine_threadsafe(
                    self.receive_kv(path, req_blocks), self.receiver_loop
                )

        if self.kv_role != "kv_consumer":
            with self.reqs_need_send.lock:
                for req_id, block_ids in metadata.reqs_to_send.items():
                    if block_ids:
                        # Already gone through request_finished()
                        send_meta = self.reqs_need_send.reqs[req_id]
                        send_meta.local_block_ids = block_ids
                        send_meta.ready.set()
                        # From now on the request may expire if it is never
                        # asked for.
                        send_meta.expire_time = (
                            time.perf_counter()
                            + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
                        )
                    else:
                        # From update_state_after_alloc(),
                        # but not reach request_finished() yet
                        self.reqs_need_send.reqs[req_id] = SendBlockMeta(
                            local_block_ids=[], ready=threading.Event()
                        )
|
||||
|
||||
|
||||
def group_concurrent_contiguous(
    src_indices: list[int], dst_indices: list[int]
) -> tuple[list[list[int]], list[list[int]]]:
    """Split paired index lists into runs that are contiguous on both sides.

    Position ``i`` of ``src_indices`` is paired with position ``i`` of
    ``dst_indices``. A new group starts wherever either sequence stops
    increasing by exactly 1, so every group can be described by its first
    index and its length on both sides.

    Args:
        src_indices: Source indices.
        dst_indices: Destination indices; must be as long as ``src_indices``.

    Returns:
        ``(src_groups, dst_groups)``, two parallel lists of index runs. Both
        are empty when the inputs are empty.

    Raises:
        ValueError: If the two lists differ in length.
    """
    # Without this check NumPy broadcasting silently accepts some mismatches
    # (e.g. lengths 2 and 1) and returns groups that do not pair up.
    if len(src_indices) != len(dst_indices):
        raise ValueError(
            "src_indices and dst_indices must have the same length, got "
            f"{len(src_indices)} and {len(dst_indices)}"
        )
    if len(src_indices) == 0:
        return [], []

    # A break sits before every position where either side is not previous+1.
    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
    src_groups = [g.tolist() for g in np.split(src_indices, brk)]
    dst_groups = [g.tolist() for g in np.split(dst_indices, brk)]

    return src_groups, dst_groups
|
||||
|
||||
|
||||
def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
    """Return the Mooncake side-channel base port for this configuration.

    The result is ``VLLM_MOONCAKE_BOOTSTRAP_PORT`` plus
    ``data_parallel_rank * tensor_parallel_size``.
    """
    bootstrap_port = envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
    parallel_config = vllm_config.parallel_config
    dp_rank = parallel_config.data_parallel_rank
    tp_size = parallel_config.tensor_parallel_size
    return bootstrap_port + dp_rank * tp_size
|
||||
@ -20,10 +20,10 @@ import torch
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
CopyBlocksOp,
|
||||
KVConnectorBase_V1,
|
||||
@ -668,128 +668,6 @@ class NixlConnectorScheduler:
|
||||
class NixlConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
@dataclass
|
||||
class TpKVTopology:
|
||||
"""
|
||||
Helper class for tensor parallel and KV topology information for
|
||||
mapping between local and remote TP workers.
|
||||
"""
|
||||
|
||||
tp_rank: int
|
||||
remote_tp_size: dict[EngineId, int]
|
||||
is_mla: bool
|
||||
total_num_kv_heads: int
|
||||
attn_backend: type[AttentionBackend]
|
||||
engine_id: EngineId
|
||||
remote_block_size: dict[EngineId, int]
|
||||
|
||||
def __post_init__(self):
|
||||
# Figure out whether the first dimension of the cache is K/V
|
||||
# or num_blocks. This is used to register the memory regions correctly.
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
|
||||
)
|
||||
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
|
||||
# we just mock num_blocks to 1 for the dimension check below.
|
||||
self._is_kv_layout_blocks_first = (
|
||||
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
|
||||
)
|
||||
|
||||
attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
|
||||
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
|
||||
|
||||
@property
|
||||
def is_kv_layout_blocks_first(self) -> bool:
|
||||
return self._is_kv_layout_blocks_first
|
||||
|
||||
@property
|
||||
def split_k_and_v(self) -> bool:
|
||||
# Whether to register regions for K and V separately (when present).
|
||||
return not (
|
||||
self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
|
||||
)
|
||||
|
||||
@property
|
||||
def tp_size(self) -> int:
|
||||
return self.remote_tp_size[self.engine_id]
|
||||
|
||||
@property
|
||||
def block_size(self) -> int:
|
||||
return self.remote_block_size[self.engine_id]
|
||||
|
||||
def tp_ratio(
|
||||
self,
|
||||
remote_tp_size: int,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the tensor parallel ratio between local and remote TP.
|
||||
We can think of it as the number of local TP workers-per-remote TP
|
||||
workers. Local workers will read from the same remote TP worker in
|
||||
groups of size `tp_ratio`.
|
||||
"""
|
||||
assert self.tp_size % remote_tp_size == 0, (
|
||||
f"Local tensor parallel size {self.tp_size} is not divisible "
|
||||
f"by remote tensor parallel size {remote_tp_size}."
|
||||
)
|
||||
return self.tp_size // remote_tp_size
|
||||
|
||||
def block_size_ratio(
|
||||
self,
|
||||
remote_block_size: int,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the block size ratio between local and remote TP.
|
||||
"""
|
||||
assert self.block_size % remote_block_size == 0, (
|
||||
f"Local block size {self.block_size} is not divisible "
|
||||
f"by remote block size {remote_block_size} or vice versa."
|
||||
)
|
||||
return self.block_size // remote_block_size
|
||||
|
||||
def tp_ratio_from_engine_id(
|
||||
self,
|
||||
remote_engine_id: EngineId,
|
||||
) -> int:
|
||||
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
||||
return self.tp_ratio(remote_tp_size)
|
||||
|
||||
def block_size_ratio_from_engine_id(
|
||||
self,
|
||||
remote_engine_id: EngineId,
|
||||
) -> float:
|
||||
remote_block_size = self.remote_block_size[remote_engine_id]
|
||||
return self.block_size_ratio(remote_block_size)
|
||||
|
||||
def is_kv_replicated(self, engine_id: EngineId) -> bool:
|
||||
"""
|
||||
Whether the KV cache is replicated across TP workers due to the
|
||||
number of TP workers being greater than the number of KV heads.
|
||||
"""
|
||||
tp_size = self.remote_tp_size[engine_id]
|
||||
return tp_size // self.total_num_kv_heads >= 1
|
||||
|
||||
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
|
||||
# MLA is always replicated as the hidden dim can't be split.
|
||||
return self.is_mla or self.is_kv_replicated(remote_engine_id)
|
||||
|
||||
def get_target_remote_rank(
    self,
    remote_tp_size: int,
) -> int:
    """
    Get the remote TP rank (on P) that the current local TP rank
    (on D) will read from.

    Computed as `self.tp_rank` integer-divided by
    `self.tp_ratio(remote_tp_size)`.
    """
    return self.tp_rank // self.tp_ratio(remote_tp_size)
|
||||
|
||||
def get_target_remote_rank_from_engine_id(
    self,
    remote_engine_id: EngineId,
) -> int:
    """
    Return `get_target_remote_rank` for the TP size stored under
    `remote_engine_id`.

    The TP size is read from `self.remote_tp_size`; an unknown engine id
    therefore raises `KeyError`.
    """
    return self.get_target_remote_rank(self.remote_tp_size[remote_engine_id])
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
if NixlWrapper is None:
|
||||
logger.error("NIXL is not available")
|
||||
@ -958,7 +836,7 @@ class NixlConnectorWorker:
|
||||
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
|
||||
self.xfer_stats = NixlKVConnectorStats()
|
||||
|
||||
self.kv_topo = self.TpKVTopology(
|
||||
self.kv_topo = TpKVTopology(
|
||||
tp_rank=self.tp_rank,
|
||||
engine_id=self.engine_id,
|
||||
remote_tp_size=self._tp_size, # shared state
|
||||
|
||||
@ -1169,17 +1169,13 @@ def init_distributed_environment(
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
config = get_current_vllm_config()
|
||||
if config is not None and config.parallel_config.nnodes > 1:
|
||||
parallel_config = config.parallel_config
|
||||
ip = parallel_config.master_addr
|
||||
rank = parallel_config.data_parallel_rank * world_size + rank
|
||||
world_size = parallel_config.world_size_across_dp
|
||||
port = parallel_config.master_port
|
||||
distributed_init_method = get_distributed_init_method(ip, port)
|
||||
elif (
|
||||
if (
|
||||
config is not None
|
||||
and config.parallel_config.data_parallel_size > 1
|
||||
and config.parallel_config.distributed_executor_backend != "external_launcher"
|
||||
and (
|
||||
config.parallel_config.nnodes > 1
|
||||
or config.parallel_config.data_parallel_size > 1
|
||||
)
|
||||
):
|
||||
parallel_config = config.parallel_config
|
||||
# adjust to take into account data parallelism
|
||||
@ -1187,15 +1183,22 @@ def init_distributed_environment(
|
||||
rank = parallel_config.data_parallel_rank * world_size + rank
|
||||
# adjust the world size to take into account data parallelism
|
||||
world_size = parallel_config.world_size_across_dp
|
||||
ip = parallel_config.data_parallel_master_ip
|
||||
port = parallel_config.get_next_dp_init_port()
|
||||
distributed_init_method = get_distributed_init_method(ip, port)
|
||||
logger.debug(
|
||||
"Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
|
||||
world_size,
|
||||
rank,
|
||||
distributed_init_method,
|
||||
)
|
||||
|
||||
# Use appropriate IP and port based on configuration
|
||||
if parallel_config.nnodes > 1:
|
||||
ip = parallel_config.master_addr
|
||||
port = parallel_config.master_port
|
||||
distributed_init_method = get_distributed_init_method(ip, port)
|
||||
else:
|
||||
ip = parallel_config.data_parallel_master_ip
|
||||
port = parallel_config.get_next_dp_init_port()
|
||||
distributed_init_method = get_distributed_init_method(ip, port)
|
||||
logger.debug(
|
||||
"Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
|
||||
world_size,
|
||||
rank,
|
||||
distributed_init_method,
|
||||
)
|
||||
if not torch.distributed.is_initialized():
|
||||
logger.info(
|
||||
"world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
|
||||
|
||||
@ -183,7 +183,9 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
|
||||
if anthropic_request.stream:
|
||||
req.stream = anthropic_request.stream
|
||||
req.stream_options = StreamOptions.validate({"include_usage": True})
|
||||
req.stream_options = StreamOptions.validate(
|
||||
{"include_usage": True, "continuous_usage_stats": True}
|
||||
)
|
||||
|
||||
if anthropic_request.tool_choice is None:
|
||||
req.tool_choice = None
|
||||
@ -323,6 +325,12 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
content=[],
|
||||
model=origin_chunk.model,
|
||||
),
|
||||
usage=AnthropicUsage(
|
||||
input_tokens=origin_chunk.usage.prompt_tokens
|
||||
if origin_chunk.usage
|
||||
else 0,
|
||||
output_tokens=0,
|
||||
),
|
||||
)
|
||||
first_item = False
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
|
||||
@ -536,7 +536,7 @@ def resolve_hf_chat_template(
|
||||
def _resolve_chat_template_content_format(
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
tokenizer: TokenizerLike,
|
||||
tokenizer: TokenizerLike | None,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
) -> _ChatTemplateContentFormat:
|
||||
@ -593,7 +593,7 @@ def resolve_chat_template_content_format(
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
given_format: ChatTemplateContentFormatOption,
|
||||
tokenizer: TokenizerLike,
|
||||
tokenizer: TokenizerLike | None,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
) -> _ChatTemplateContentFormat:
|
||||
@ -627,11 +627,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
maximum per prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: ModelConfig, tokenizer: TokenizerLike):
|
||||
def __init__(self, model_config: ModelConfig):
|
||||
super().__init__()
|
||||
|
||||
self._model_config = model_config
|
||||
self._tokenizer = tokenizer
|
||||
|
||||
self._items_by_modality = defaultdict[str, list[_T | None]](list)
|
||||
self._uuids_by_modality = defaultdict[str, list[str | None]](list)
|
||||
@ -1139,11 +1138,19 @@ def validate_chat_template(chat_template: Path | str | None):
|
||||
not any(c in chat_template for c in JINJA_CHARS)
|
||||
and not Path(chat_template).exists()
|
||||
):
|
||||
raise ValueError(
|
||||
f"The supplied chat template string ({chat_template}) "
|
||||
f"appears path-like, but doesn't exist!"
|
||||
# Try to find the template in the built-in templates directory
|
||||
from vllm.transformers_utils.chat_templates.registry import (
|
||||
CHAT_TEMPLATES_DIR,
|
||||
)
|
||||
|
||||
builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
|
||||
if not builtin_template_path.exists():
|
||||
raise ValueError(
|
||||
f"The supplied chat template string ({chat_template}) "
|
||||
f"appears path-like, but doesn't exist! "
|
||||
f"Tried: {chat_template} and {builtin_template_path}"
|
||||
)
|
||||
|
||||
else:
|
||||
raise TypeError(f"{type(chat_template)} is not a valid chat template type")
|
||||
|
||||
@ -1173,12 +1180,23 @@ def _load_chat_template(
|
||||
|
||||
JINJA_CHARS = "{}\n"
|
||||
if not any(c in chat_template for c in JINJA_CHARS):
|
||||
msg = (
|
||||
f"The supplied chat template ({chat_template}) "
|
||||
f"looks like a file path, but it failed to be "
|
||||
f"opened. Reason: {e}"
|
||||
# Try to load from the built-in templates directory
|
||||
from vllm.transformers_utils.chat_templates.registry import (
|
||||
CHAT_TEMPLATES_DIR,
|
||||
)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
|
||||
try:
|
||||
with open(builtin_template_path) as f:
|
||||
return f.read()
|
||||
except OSError:
|
||||
msg = (
|
||||
f"The supplied chat template ({chat_template}) "
|
||||
f"looks like a file path, but it failed to be opened. "
|
||||
f"Tried: {chat_template} and {builtin_template_path}. "
|
||||
f"Reason: {e}"
|
||||
)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
# If opening a file fails, set chat template to be args to
|
||||
# ensure we decode so our escape are interpreted correctly
|
||||
@ -1593,7 +1611,6 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
|
||||
def parse_chat_messages(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
tokenizer: TokenizerLike,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
@ -1601,7 +1618,7 @@ def parse_chat_messages(
|
||||
MultiModalUUIDDict | None,
|
||||
]:
|
||||
conversation: list[ConversationMessage] = []
|
||||
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
|
||||
mm_tracker = MultiModalItemTracker(model_config)
|
||||
|
||||
for msg in messages:
|
||||
sub_messages = _parse_chat_message_content(
|
||||
@ -1625,7 +1642,6 @@ def parse_chat_messages(
|
||||
def parse_chat_messages_futures(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
tokenizer: TokenizerLike,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
@ -1633,7 +1649,7 @@ def parse_chat_messages_futures(
|
||||
MultiModalUUIDDict | None,
|
||||
]:
|
||||
conversation: list[ConversationMessage] = []
|
||||
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
|
||||
mm_tracker = AsyncMultiModalItemTracker(model_config)
|
||||
|
||||
for msg in messages:
|
||||
sub_messages = _parse_chat_message_content(
|
||||
|
||||
@ -328,6 +328,105 @@ def render_for_completion(messages: list[Message]) -> list[int]:
|
||||
return token_ids
|
||||
|
||||
|
||||
def _parse_browser_tool_call(message: Message, recipient: str) -> ResponseOutputItem:
|
||||
"""Parse browser tool calls (search, open, find) into web search items."""
|
||||
if len(message.content) != 1:
|
||||
raise ValueError("Invalid number of contents in browser message")
|
||||
content = message.content[0]
|
||||
|
||||
# Parse JSON args (with retry detection)
|
||||
try:
|
||||
browser_call = json.loads(content.text)
|
||||
except json.JSONDecodeError:
|
||||
json_retry_output_message = (
|
||||
f"Invalid JSON args, caught and retried: {content.text}"
|
||||
)
|
||||
browser_call = {
|
||||
"query": json_retry_output_message,
|
||||
"url": json_retry_output_message,
|
||||
"pattern": json_retry_output_message,
|
||||
}
|
||||
|
||||
# Create appropriate action based on recipient
|
||||
if recipient == "browser.search":
|
||||
action = ActionSearch(
|
||||
query=f"cursor:{browser_call.get('query', '')}", type="search"
|
||||
)
|
||||
elif recipient == "browser.open":
|
||||
action = ActionOpenPage(
|
||||
url=f"cursor:{browser_call.get('url', '')}", type="open_page"
|
||||
)
|
||||
elif recipient == "browser.find":
|
||||
action = ActionFind(
|
||||
pattern=browser_call.get("pattern", ""),
|
||||
url=f"cursor:{browser_call.get('url', '')}",
|
||||
type="find",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown browser action: {recipient}")
|
||||
|
||||
return ResponseFunctionWebSearch(
|
||||
id=f"ws_{random_uuid()}",
|
||||
action=action,
|
||||
status="completed",
|
||||
type="web_search_call",
|
||||
)
|
||||
|
||||
|
||||
def _parse_function_call(message: Message, recipient: str) -> list[ResponseOutputItem]:
|
||||
"""Parse function calls into function tool call items."""
|
||||
function_name = recipient.split(".")[-1]
|
||||
output_items = []
|
||||
for content in message.content:
|
||||
random_id = random_uuid()
|
||||
response_item = ResponseFunctionToolCall(
|
||||
arguments=content.text,
|
||||
call_id=f"call_{random_id}",
|
||||
type="function_call",
|
||||
name=function_name,
|
||||
id=f"fc_{random_id}",
|
||||
)
|
||||
output_items.append(response_item)
|
||||
return output_items
|
||||
|
||||
|
||||
def _parse_reasoning_content(message: Message) -> list[ResponseOutputItem]:
|
||||
"""Parse reasoning/analysis content into reasoning items."""
|
||||
output_items = []
|
||||
for content in message.content:
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(text=content.text, type="reasoning_text")
|
||||
],
|
||||
status=None,
|
||||
)
|
||||
output_items.append(reasoning_item)
|
||||
return output_items
|
||||
|
||||
|
||||
def _parse_final_message(message: Message) -> ResponseOutputItem:
    """Wrap the content entries of a final-channel message in one output message.

    Every entry becomes an output-text part; the resulting message copies
    the role of the message author and is marked as completed.
    """
    text_parts = [
        ResponseOutputText(
            text=part.text,
            annotations=[],  # TODO
            type="output_text",
            logprobs=None,  # TODO
        )
        for part in message.content
    ]
    message_id = f"msg_{random_uuid()}"
    return ResponseOutputMessage(
        id=message_id,
        content=text_parts,
        role=message.author.role,
        status="completed",
        type="message",
    )
|
||||
|
||||
|
||||
def parse_output_message(message: Message) -> list[ResponseOutputItem]:
|
||||
"""
|
||||
Parse a Harmony message into a list of output response items.
|
||||
@ -340,119 +439,38 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]:
|
||||
|
||||
output_items: list[ResponseOutputItem] = []
|
||||
recipient = message.recipient
|
||||
|
||||
# Browser tool calls
|
||||
if recipient is not None and recipient.startswith("browser."):
|
||||
if len(message.content) != 1:
|
||||
raise ValueError("Invalid number of contents in browser message")
|
||||
content = message.content[0]
|
||||
# We do not need to check the VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY
|
||||
# env variable since if it is not set, we are certain the json is valid
|
||||
# The use of Actions for web search will be removed entirely in
|
||||
# the future, so this is only necessary temporarily
|
||||
try:
|
||||
browser_call = json.loads(content.text)
|
||||
except json.JSONDecodeError:
|
||||
# If the content is not valid JSON, then it was
|
||||
# caught and retried by vLLM, which means we
|
||||
# need to make note of that so the user is aware
|
||||
json_retry_output_message = (
|
||||
f"Invalid JSON args, caught and retried: {content.text}"
|
||||
)
|
||||
browser_call = {
|
||||
"query": json_retry_output_message,
|
||||
"url": json_retry_output_message,
|
||||
"pattern": json_retry_output_message,
|
||||
}
|
||||
# TODO: translate to url properly!
|
||||
if recipient == "browser.search":
|
||||
action = ActionSearch(
|
||||
query=f"cursor:{browser_call.get('query', '')}", type="search"
|
||||
)
|
||||
elif recipient == "browser.open":
|
||||
action = ActionOpenPage(
|
||||
url=f"cursor:{browser_call.get('url', '')}", type="open_page"
|
||||
)
|
||||
elif recipient == "browser.find":
|
||||
action = ActionFind(
|
||||
pattern=browser_call["pattern"],
|
||||
url=f"cursor:{browser_call.get('url', '')}",
|
||||
type="find",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown browser action: {recipient}")
|
||||
web_search_item = ResponseFunctionWebSearch(
|
||||
id=f"ws_{random_uuid()}",
|
||||
action=action,
|
||||
status="completed",
|
||||
type="web_search_call",
|
||||
)
|
||||
output_items.append(web_search_item)
|
||||
output_items.append(_parse_browser_tool_call(message, recipient))
|
||||
|
||||
# Analysis channel (reasoning/chain-of-thought)
|
||||
elif message.channel == "analysis":
|
||||
for content in message.content:
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(
|
||||
text=content.text, type="reasoning_text"
|
||||
)
|
||||
],
|
||||
status=None,
|
||||
)
|
||||
output_items.append(reasoning_item)
|
||||
output_items.extend(_parse_reasoning_content(message))
|
||||
|
||||
# Commentary channel
|
||||
elif message.channel == "commentary":
|
||||
# Function calls
|
||||
if recipient is not None and recipient.startswith("functions."):
|
||||
function_name = recipient.split(".")[-1]
|
||||
for content in message.content:
|
||||
random_id = random_uuid()
|
||||
response_item = ResponseFunctionToolCall(
|
||||
arguments=content.text,
|
||||
call_id=f"call_{random_id}",
|
||||
type="function_call",
|
||||
name=function_name,
|
||||
id=f"fc_{random_id}",
|
||||
)
|
||||
output_items.append(response_item)
|
||||
output_items.extend(_parse_function_call(message, recipient))
|
||||
|
||||
# Built-in tools on commentary channel are treated as reasoning for now
|
||||
elif recipient is not None and (
|
||||
recipient.startswith("python")
|
||||
or recipient.startswith("browser")
|
||||
or recipient.startswith("container")
|
||||
):
|
||||
for content in message.content:
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(
|
||||
text=content.text, type="reasoning_text"
|
||||
)
|
||||
],
|
||||
status=None,
|
||||
)
|
||||
output_items.append(reasoning_item)
|
||||
output_items.extend(_parse_reasoning_content(message))
|
||||
else:
|
||||
raise ValueError(f"Unknown recipient: {recipient}")
|
||||
|
||||
# Final output message
|
||||
elif message.channel == "final":
|
||||
contents = []
|
||||
for content in message.content:
|
||||
output_text = ResponseOutputText(
|
||||
text=content.text,
|
||||
annotations=[], # TODO
|
||||
type="output_text",
|
||||
logprobs=None, # TODO
|
||||
)
|
||||
contents.append(output_text)
|
||||
text_item = ResponseOutputMessage(
|
||||
id=f"msg_{random_uuid()}",
|
||||
content=contents,
|
||||
role=message.author.role,
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
output_items.append(text_item)
|
||||
output_items.append(_parse_final_message(message))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown channel: {message.channel}")
|
||||
|
||||
return output_items
|
||||
|
||||
|
||||
|
||||
@ -834,7 +834,6 @@ class LLM:
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
msgs,
|
||||
model_config,
|
||||
tokenizer,
|
||||
content_format=resolved_content_format,
|
||||
)
|
||||
|
||||
|
||||
@ -105,7 +105,7 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||
from vllm.tokenizers import DeepseekV32Tokenizer, MistralTokenizer, TokenizerLike
|
||||
from vllm.tracing import (
|
||||
contains_trace_headers,
|
||||
extract_trace_headers,
|
||||
@ -1088,11 +1088,6 @@ class OpenAIServing:
|
||||
Sequence[RequestPrompt],
|
||||
list[EngineTokensPrompt],
|
||||
]:
|
||||
if tokenizer is None:
|
||||
raise ValueError(
|
||||
"Unable to get tokenizer because `skip_tokenizer_init=True`"
|
||||
)
|
||||
|
||||
model_config = self.model_config
|
||||
|
||||
resolved_content_format = resolve_chat_template_content_format(
|
||||
@ -1105,7 +1100,6 @@ class OpenAIServing:
|
||||
conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
|
||||
messages,
|
||||
model_config,
|
||||
tokenizer,
|
||||
content_format=resolved_content_format,
|
||||
)
|
||||
|
||||
@ -1128,6 +1122,13 @@ class OpenAIServing:
|
||||
messages=messages,
|
||||
**_chat_template_kwargs,
|
||||
)
|
||||
elif isinstance(tokenizer, DeepseekV32Tokenizer):
|
||||
request_prompt = tokenizer.apply_chat_template(
|
||||
conversation=conversation,
|
||||
messages=messages,
|
||||
model_config=model_config,
|
||||
**_chat_template_kwargs,
|
||||
)
|
||||
else:
|
||||
request_prompt = apply_hf_chat_template(
|
||||
tokenizer=tokenizer,
|
||||
|
||||
@ -30,6 +30,10 @@ _TOOL_PARSERS_TO_REGISTER = {
|
||||
"deepseekv31_tool_parser",
|
||||
"DeepSeekV31ToolParser",
|
||||
),
|
||||
"deepseek_v32": (
|
||||
"deepseekv32_tool_parser",
|
||||
"DeepSeekV32ToolParser",
|
||||
),
|
||||
"ernie45": (
|
||||
"ernie45_tool_parser",
|
||||
"Ernie45ToolParser",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user