From f081c3ce4b020fb094e33575d178345c477ab0c6 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sat, 1 Jun 2024 14:16:07 +0530 Subject: [PATCH 01/18] [Kernel] Update Cutlass fp8 configs (#5144) Co-authored-by: Varun Sundar Rabindranath Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> --- .../cutlass_benchmarks/w8a8_benchmarks.py | 352 ++++++++++++++++++ .../cutlass_benchmarks/weight_shapes.py | 37 ++ .../cutlass_w8a8/scaled_mm_dq_c3x.cu | 104 +++++- tests/kernels/test_cutlass.py | 2 +- 4 files changed, 480 insertions(+), 15 deletions(-) create mode 100644 benchmarks/cutlass_benchmarks/w8a8_benchmarks.py create mode 100644 benchmarks/cutlass_benchmarks/weight_shapes.py diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py new file mode 100644 index 0000000000000..6de56f618700d --- /dev/null +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -0,0 +1,352 @@ +import argparse +import copy +import itertools +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + +# helpers + + +def to_fp8(tensor: torch.tensor) -> torch.tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.tensor) -> torch.tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def make_rand_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.tensor, torch.tensor]: + + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + if dtype == torch.int8: + return to_int8(a), to_int8(b) + if dtype == torch.float8_e4m3fn: + return to_fp8(a), to_fp8(b) + + raise ValueError("unsupported dtype") + + +# impl + + +def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, + scale_b: torch.tensor, + out_dtype: torch.dtype) -> torch.tensor: + return torch.mm(a, b) + + +def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, + scale_b: torch.tensor, + out_dtype: torch.dtype) -> torch.tensor: + return torch._scaled_mm(a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype) + + +def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor, + scale_a: torch.tensor, scale_b: torch.tensor, + out_dtype: torch.dtype) -> torch.tensor: + return torch._scaled_mm(a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype, + use_fast_accum=True) + + +def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, + scale_b: torch.tensor, + out_dtype: torch.dtype) -> torch.tensor: + return ops.cutlass_scaled_mm_dq(a, + b, + scale_a, + scale_b, + out_dtype=out_dtype) + + +# bench +def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, + scale_b: torch.tensor, out_dtype: torch.dtype, label: str, + sub_label: str, fn: Callable, description: str) -> TMeasurement: + + min_run_time = 1 + + globals = { + "a": a, + "b": b, + "scale_a": scale_a, + "scale_b": scale_b, + "out_dtype": out_dtype, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(a, b, scale_a, scale_b, out_dtype)", + 
globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + a, b = make_rand_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + + timers = [] + # pytorch impl + timers.append( + bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, + torch.bfloat16, label, sub_label, pytorch_i8_impl, + "pytorch_bf16_bf16_bf16_matmul-no-scales")) + + # cutlass impl + timers.append( + bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"), + torch.bfloat16, label, sub_label, cutlass_impl, + "cutlass_i8_i8_bf16_scaled_mm")) + + return timers + + +def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + + timers = [] + + # pytorch impl: bf16 output, without fp8 fast accum + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, + pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm")) + + # pytorch impl: bf16 output, with fp8 fast accum + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, + pytorch_fp8_impl_fast_accum, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum")) + + # pytorch impl: fp16 output, without fp8 fast accum + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, + pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm")) + + # pytorch impl: fp16 output, with fp8 fast accum + timers.append( + bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, + pytorch_fp8_impl_fast_accum, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum")) + + # cutlass impl: bf16 output + timers.append( + bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"), + torch.bfloat16, label, sub_label, cutlass_impl, + "cutlass_fp8_fp8_bf16_scaled_mm")) + # cutlass impl: fp16 output + timers.append( + bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"), + torch.float16, label, sub_label, cutlass_impl, + "cutlass_fp8_fp8_fp16_scaled_mm")) + return timers + + +def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + raise ValueError("unsupported type") + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + + results = [] + for m, k, n in MKNs: + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None): + + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = 
int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = argparse.ArgumentParser( + description=""" +Benchmark Cutlass GEMM. + + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. 
+ """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']") + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py new file mode 100644 index 0000000000000..7ad4a53d376b6 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -0,0 +1,37 @@ +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "mistralai/Mistral-7B-v0.1": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-7b-hf": [ + ([4096, 12288], 1), + ([4096, 4096], 0), + ([4096, 22016], 1), + ([11008, 4096], 0), + ], + "meta-llama/Llama-2-13b-hf": [ + ([5120, 15360], 1), + ([5120, 5120], 0), + ([5120, 27648], 1), + ([13824, 5120], 0), + ], + "meta-llama/Llama-2-70b-hf": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu index 2383760abcdb0..4c1aec03a3caa 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -51,6 +51,11 @@ using namespace cute; namespace { +uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} + template @@ -188,8 +193,89 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, cutlass::Status status = gemm_op.run(args, workspace.get(), stream); CUTLASS_CHECK(status); } + +template +struct sm90_fp8_config { + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename 
cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template +struct sm90_fp8_config { + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template +struct sm90_fp8_config { + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _8, _1>; + + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + } // namespace +template +void cutlass_scaled_mm_dq_sm90_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + using Cutlass3xGemmDefault = + typename sm90_fp8_config::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_fp8_config::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_fp8_config::Cutlass3xGemm; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(64), next_pow_2(m)); // next power of 2 + + if (mp2 <= 64) { + // m in [1, 64] + return cutlass_scaled_mm_dq_dispatcher( + out, a, b, a_scales, b_scales); + } else if (mp2 <= 128) { + // m in (64, 128] + return cutlass_scaled_mm_dq_dispatcher( + out, a, b, a_scales, b_scales); + } else { + // m in (128, inf) + return cutlass_scaled_mm_dq_dispatcher( + out, a, b, a_scales, b_scales); + } +} + void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, @@ -223,24 +309,14 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_1, _2, _1>; - using KernelSchedule = - typename cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative; - using EpilogueSchedule = - typename cutlass::epilogue::TmaWarpSpecializedCooperative; - if (out.dtype() == torch::kBFloat16) { - return cutlass_scaled_mm_dq_dispatcher< - cutlass_3x_gemm>( + return cutlass_scaled_mm_dq_sm90_fp8_dispatch( out, a, b, a_scales, b_scales); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - - return cutlass_scaled_mm_dq_dispatcher< - cutlass_3x_gemm>( + return cutlass_scaled_mm_dq_sm90_fp8_dispatch( out, a, b, a_scales, b_scales); } } diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 5a18dd5c1e3b3..079d9650c7af5 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -82,7 +82,7 @@ def cutlass_int8_gemm_helper(m: int, assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) -@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("m", [512, 222, 100, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 496, 1024]) 
@pytest.mark.parametrize("per_act_token", [True, False]) From c35407282878cb3a42860d584a4d9eb6aed82299 Mon Sep 17 00:00:00 2001 From: Ye Cao <952129620@qq.com> Date: Sun, 2 Jun 2024 01:11:22 +0800 Subject: [PATCH 02/18] [Minor] Fix the path typo in loader.py: save_sharded_states.py -> save_sharded_state.py (#5151) Signed-off-by: Ye Cao --- vllm/model_executor/model_loader/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index b7b5b5e7695f4..e20da0e15fb93 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -386,7 +386,7 @@ class ShardedStateLoader(BaseModelLoader): Model loader that directly loads each worker's model state dict, which enables a fast load path for large tensor-parallel models where each worker only needs to read its own shard rather than the entire checkpoint. See - `examples/save_sharded_states.py` for creating a sharded checkpoint. + `examples/save_sharded_state.py` for creating a sharded checkpoint. """ DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" From 37464a0f745a0204da7443d2a6ef4b8f65e5af12 Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits <45605409+NadavShmayo@users.noreply.github.com> Date: Sat, 1 Jun 2024 20:18:50 +0300 Subject: [PATCH 03/18] [Bugfix] Fix call to init_logger in openai server (#4765) --- vllm/entrypoints/openai/api_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 97b35262329ee..95417718b51fe 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -36,7 +36,7 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding -logger = init_logger(__name__) +logger = init_logger('vllm.entrypoints.openai.api_server') _running_tasks: Set[asyncio.Task] = set() From b9c0605a8e7d558f595bd59ba6e6c95578dc0f1e Mon Sep 17 00:00:00 2001 From: chenqianfzh <51831990+chenqianfzh@users.noreply.github.com> Date: Sat, 1 Jun 2024 13:51:10 -0700 Subject: [PATCH 04/18] [Feature][Kernel] Support bitsandbytes quantization and QLoRA (#4776) --- examples/lora_with_quantization_inference.py | 140 ++++++++++ requirements-dev.txt | 3 + tests/quantization/test_bitsandbytes.py | 80 ++++++ vllm/config.py | 9 +- vllm/engine/arg_utils.py | 38 ++- vllm/model_executor/layers/linear.py | 41 ++- .../layers/quantization/__init__.py | 3 + .../layers/quantization/bitsandbytes.py | 175 +++++++++++++ vllm/model_executor/model_loader/loader.py | 247 +++++++++++++++++- .../model_loader/weight_utils.py | 16 +- vllm/model_executor/models/llama.py | 8 + 11 files changed, 752 insertions(+), 8 deletions(-) create mode 100644 examples/lora_with_quantization_inference.py create mode 100644 tests/quantization/test_bitsandbytes.py create mode 100644 vllm/model_executor/layers/quantization/bitsandbytes.py diff --git a/examples/lora_with_quantization_inference.py b/examples/lora_with_quantization_inference.py new file mode 100644 index 0000000000000..3b2347c1115e1 --- /dev/null +++ b/examples/lora_with_quantization_inference.py @@ -0,0 +1,140 @@ +""" +This example shows how to use LoRA with different quantization techniques +for offline inference. + +Requires HuggingFace credentials for access. 
+""" + +import gc +from typing import List, Optional, Tuple + +import torch +from huggingface_hub import snapshot_download + +from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams +from vllm.lora.request import LoRARequest + + +def create_test_prompts( + lora_path: str +) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]: + return [ + # this is an example of using quantization without LoRA + ("My name is", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128), None), + # the next three examples use quantization with LoRA + ("my name is", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128), + LoRARequest("lora-test-1", 1, lora_path)), + ("The capital of USA is", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128), + LoRARequest("lora-test-2", 1, lora_path)), + ("The capital of France is", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128), + LoRARequest("lora-test-3", 1, lora_path)), + ] + + +def process_requests(engine: LLMEngine, + test_prompts: List[Tuple[str, SamplingParams, + Optional[LoRARequest]]]): + """Continuously process a list of prompts and handle the outputs.""" + request_id = 0 + + while test_prompts or engine.has_unfinished_requests(): + if test_prompts: + prompt, sampling_params, lora_request = test_prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + lora_request=lora_request) + request_id += 1 + + request_outputs: List[RequestOutput] = engine.step() + for request_output in request_outputs: + if request_output.finished: + print("----------------------------------------------------") + print(f"Prompt: {request_output.prompt}") + print(f"Output: {request_output.outputs[0].text}") + + +def initialize_engine(model: str, quantization: str, + lora_repo: Optional[str]) -> LLMEngine: + """Initialize the LLMEngine.""" + + if quantization == "bitsandbytes": + # QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique. + # It quantizes the model when loading, with some config info from the + # LoRA adapter repo. So need to set the parameter of load_format and + # qlora_adapter_name_or_path as below. 
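+        # Note: EngineArgs validates this pairing and raises a ValueError if
+        # quantization="bitsandbytes" is used without
+        # load_format="bitsandbytes" (or vice versa).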
+ engine_args = EngineArgs( + model=model, + quantization=quantization, + qlora_adapter_name_or_path=lora_repo, + load_format="bitsandbytes", + enable_lora=True, + max_lora_rank=64, + # set it only in GPUs of limited memory + enforce_eager=True) + else: + engine_args = EngineArgs( + model=model, + quantization=quantization, + enable_lora=True, + max_loras=4, + # set it only in GPUs of limited memory + enforce_eager=True) + return LLMEngine.from_engine_args(engine_args) + + +def main(): + """Main function that sets up and runs the prompt processing.""" + + test_configs = [{ + "name": "qlora_inference_example", + 'model': "huggyllama/llama-7b", + 'quantization': "bitsandbytes", + 'lora_repo': 'timdettmers/qlora-flan-7b' + }, { + "name": "AWQ_inference_with_lora_example", + 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ', + 'quantization': "awq", + 'lora_repo': 'jashing/tinyllama-colorist-lora' + }, { + "name": "GPTQ_inference_with_lora_example", + 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ', + 'quantization': "gptq", + 'lora_repo': 'jashing/tinyllama-colorist-lora' + }] + + for test_config in test_configs: + print( + f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~" + ) + engine = initialize_engine(test_config['model'], + test_config['quantization'], + test_config['lora_repo']) + lora_path = snapshot_download(repo_id=test_config['lora_repo']) + test_prompts = create_test_prompts(lora_path) + process_requests(engine, test_prompts) + + # Clean up the GPU memory for the next test + del engine + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == '__main__': + main() diff --git a/requirements-dev.txt b/requirements-dev.txt index cf2bb9bef22d9..2c6b33ea813a2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -35,3 +35,6 @@ aiohttp # Multimodal pillow + +# quantization +bitsandbytes==0.42.0 diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py new file mode 100644 index 0000000000000..4e9feb3c48148 --- /dev/null +++ b/tests/quantization/test_bitsandbytes.py @@ -0,0 +1,80 @@ +'''Tests whether bitsandbytes computation is enabled correctly. + +Run `pytest tests/quantization/test_bitsandbytes.py`. 
+'''
+import pytest
+import torch
+
+from vllm import SamplingParams
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+
+capability = torch.cuda.get_device_capability()
+capability = capability[0] * 10 + capability[1]
+
+
+@pytest.mark.skipif(
+    capability < QUANTIZATION_METHODS['bitsandbytes'].get_min_capability(),
+    reason='bitsandbytes is not supported on this GPU type.')
+def test_load_bnb_model(vllm_runner) -> None:
+    llm = vllm_runner('huggyllama/llama-7b',
+                      quantization='bitsandbytes',
+                      load_format='bitsandbytes',
+                      enforce_eager=True)
+
+    model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
+
+    # check the weights in MLP & SelfAttention are quantized to torch.uint8
+    qweight = model.model.layers[0].mlp.gate_up_proj.qweight
+    assert qweight.dtype == torch.uint8, (
+        f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}')
+
+    qweight = model.model.layers[0].mlp.down_proj.qweight
+    assert qweight.dtype == torch.uint8, (
+        f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}')
+
+    qweight = model.model.layers[0].self_attn.o_proj.qweight
+    assert qweight.dtype == torch.uint8, (
+        f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}')
+
+    qweight = model.model.layers[0].self_attn.qkv_proj.qweight
+    assert qweight.dtype == torch.uint8, (
+        f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}')
+
+    # some weights should not be quantized
+    weight = model.lm_head.weight
+    assert weight.dtype != torch.uint8, (
+        'lm_head weight dtype should not be torch.uint8')
+
+    weight = model.model.embed_tokens.weight
+    assert weight.dtype != torch.uint8, (
+        'embed_tokens weight dtype should not be torch.uint8')
+
+    weight = model.model.layers[0].input_layernorm.weight
+    assert weight.dtype != torch.uint8, (
+        'input_layernorm weight dtype should not be torch.uint8')
+
+    weight = model.model.layers[0].post_attention_layernorm.weight
+    assert weight.dtype != torch.uint8, (
+        'post_attention_layernorm weight dtype should not be torch.uint8')
+
+    # check the output of the model is expected
+    sampling_params = SamplingParams(temperature=0.0,
+                                     logprobs=1,
+                                     prompt_logprobs=1,
+                                     max_tokens=8)
+
+    prompts = ['That which does not kill us', 'To be or not to be,']
+    expected_outputs = [
+        'That which does not kill us makes us stronger.',
+        'To be or not to be, that is the question.'
+    ]
+    outputs = llm.generate(prompts, sampling_params=sampling_params)
+
+    assert len(outputs) == len(prompts)
+
+    for index in range(len(outputs)):
+        # compare the first line of the output
+        actual_output = outputs[index][1][0].split('\n', 1)[0]
+        expected_output = expected_outputs[index].split('\n', 1)[0]
+        assert actual_output == expected_output, (
+            f'Expected: {expected_output}, but got: {actual_output}')
diff --git a/vllm/config.py b/vllm/config.py
index 4d05b4ea36d5c..ba4361ffb98b4 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -241,6 +241,12 @@ class ModelConfig:
                 "must be divisible by pipeline parallel size "
                 f"({pipeline_parallel_size}).")
 
+        if self.quantization == "bitsandbytes" and (
+                parallel_config.tensor_parallel_size > 1
+                or parallel_config.pipeline_parallel_size > 1):
+            raise ValueError(
+                "BitsAndBytes quantization with TP or PP is not supported yet.")
+
     def get_hf_config_sliding_window(self) -> Optional[int]:
         """Get the sliding window size, or None if disabled.
""" @@ -327,7 +333,7 @@ class ModelConfig: def get_num_attention_heads(self, parallel_config: "ParallelConfig") -> int: return self.hf_text_config.num_attention_heads // \ - parallel_config.tensor_parallel_size + parallel_config.tensor_parallel_size def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers @@ -487,6 +493,7 @@ class LoadFormat(str, enum.Enum): DUMMY = "dummy" TENSORIZER = "tensorizer" SHARDED_STATE = "sharded_state" + BITSANDBYTES = "bitsandbytes" @dataclass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 11485aa2438c0..8a73fc931a95a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -92,6 +92,8 @@ class EngineArgs: ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None + qlora_adapter_name_or_path: Optional[str] = None + def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model @@ -159,7 +161,8 @@ class EngineArgs: type=str, default=EngineArgs.load_format, choices=[ - 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer' + 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', + 'bitsandbytes' ], help='The format of the model weights to load.\n\n' '* "auto" will try to load the weights in the safetensors format ' @@ -173,7 +176,9 @@ class EngineArgs: 'which is mainly for profiling.\n' '* "tensorizer" will load the weights using tensorizer from ' 'CoreWeave. See the Tensorize vLLM Model script in the Examples' - 'section for more information.\n') + 'section for more information.\n' + '* "bitsandbytes" will load the weights using bitsandbytes ' + 'quantization.\n') parser.add_argument( '--dtype', type=str, @@ -543,7 +548,10 @@ class EngineArgs: "will also be used in `model_name` tag content of " "prometheus metrics, if multiple names provided, metrics" "tag will take the first one.") - + parser.add_argument('--qlora-adapter-name-or-path', + type=str, + default=None, + help='Name or path of the QLoRA adapter.') return parser @classmethod @@ -555,6 +563,23 @@ class EngineArgs: return engine_args def create_engine_config(self, ) -> EngineConfig: + + # bitsandbytes quantization needs a specific model loader + # so we make sure the quant method and the load format are consistent + if (self.quantization == "bitsandbytes" or + self.qlora_adapter_name_or_path is not None) and \ + self.load_format != "bitsandbytes": + raise ValueError( + "BitsAndBytes quantization and QLoRA adapter only support " + f"'bitsandbytes' load format, but got {self.load_format}") + + if (self.load_format == "bitsandbytes" or + self.qlora_adapter_name_or_path is not None) and \ + self.quantization != "bitsandbytes": + raise ValueError( + "BitsAndBytes load format and QLoRA adapter only support " + f"'bitsandbytes' quantization, but got {self.quantization}") + device_config = DeviceConfig(self.device) model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, @@ -622,6 +647,13 @@ class EngineArgs: max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None + if self.qlora_adapter_name_or_path is not None and \ + self.qlora_adapter_name_or_path != "": + if self.model_loader_extra_config is None: + self.model_loader_extra_config = {} + self.model_loader_extra_config[ + "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path + load_config = LoadConfig( load_format=self.load_format, download_dir=self.download_dir, diff --git 
a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 34fbfa8e33ef9..f5b6bdd9f7fd7 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import List, Optional +from typing import Dict, List, Optional, Tuple import torch import torch.nn.functional as F @@ -26,6 +26,21 @@ def adjust_marlin_shard(param, shard_size, shard_offset): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size +def adjust_bitsandbytes_shard(param: Parameter, + qkv_offsets: Dict[str, Tuple[int, int]], + loaded_shard_id: str) -> Tuple[int, int]: + """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" + + total, _ = qkv_offsets["total"] + orig_offset, orig_size = qkv_offsets[loaded_shard_id] + + quantized_total = param.data.shape[0] + quantized_offset = orig_offset * quantized_total // total + quantized_size = orig_size * quantized_total // total + + return quantized_size, quantized_offset + + class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @@ -37,7 +52,7 @@ class LinearMethodBase(QuantizeMethodBase): **extra_weight_attrs): """Create weights for a linear layer. The weights will be set as attributes of the layer. - + Args: layer: The layer that is using the LinearMethodBase factory. input_size_per_partition: Size of the weight input dim on rank X. @@ -416,6 +431,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear): shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + use_bitsandbytes = getattr(param, "use_bitsandbytes", False) + if use_bitsandbytes: + shard_size = loaded_weight.shape[output_dim] + shard_offset = loaded_weight.shape[output_dim] * \ + loaded_shard_id + param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = tp_rank * shard_size @@ -615,6 +636,22 @@ class QKVParallelLinear(ColumnParallelLinear): shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + use_bitsandbytes = getattr(param, "use_bitsandbytes", False) + if use_bitsandbytes: + orig_qkv_offsets = { + "q": (0, self.num_heads * self.head_size), + "k": (self.num_heads * self.head_size, + self.num_kv_heads * self.head_size), + "v": + ((self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.head_size), + "total": + ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, + 0) + } + shard_size, shard_offset = adjust_bitsandbytes_shard( + param, orig_qkv_offsets, loaded_shard_id) + param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 7b9abe1b629a1..0bc42beb66257 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,6 +4,8 @@ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.bitsandbytes import ( + BitsAndBytesConfig) from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsConfig) from vllm.model_executor.layers.quantization.deepspeedfp import ( @@ -30,6 +32,7 @@ QUANTIZATION_METHODS: Dict[str, 
Type[QuantizationConfig]] = { "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, "sparseml": CompressedTensorsConfig, + "bitsandbytes": BitsAndBytesConfig, } diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py new file mode 100644 index 0000000000000..969958d9b5448 --- /dev/null +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -0,0 +1,175 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + + +class BitsAndBytesConfig(QuantizationConfig): + """Config class for BitsAndBytes Quantization. + + Reference: https://arxiv.org/abs/2305.14314 + """ + + def __init__( + self, + adapter_name_or_path: str, + target_modules: List[str], + ) -> None: + + self.adapter_name_or_path = adapter_name_or_path + self.target_modules = target_modules + + def __repr__(self) -> str: + return ( + f"BitsAndBytesConfig(adapter_name_or_path={self.adapter_name_or_path}" + ) + + @classmethod + def get_name(self) -> str: + return "bitsandbytes" + + @classmethod + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.float32, torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(self) -> int: + return 70 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "adapter_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig": + adapter_name = cls.get_from_keys(config, ["adapter_name_or_path"]) + default_target_modules = [ + "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", + "o_proj" + ] + if adapter_name == "": + target_modules = default_target_modules + else: + target_modules = cls.get_from_keys(config, ["target_modules"]) + return cls(adapter_name, target_modules) + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["BitsAndBytesLinearMethod"]: + if isinstance(layer, LinearBase): + return BitsAndBytesLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] + + +class BitsAndBytesLinearMethod(LinearMethodBase): + """Linear method for BitsAndBytes. + + Args: + quant_config: The BitsAndBytes quantization config. + """ + + def __init__(self, quant_config: BitsAndBytesConfig): + try: + import bitsandbytes + if bitsandbytes.__version__ < "0.42.0": + raise ImportError("bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.42.0.") + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.42.0 via " + "`pip install bitsandbytes>=0.42.0` to use " + "bitsandbytes quantizer.") from err + + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + quant_ratio = 0 + if params_dtype.is_floating_point: + quant_ratio = torch.finfo(params_dtype).bits // torch.iinfo( + torch.uint8).bits + else: + quant_ratio = torch.iinfo(params_dtype).bits // torch.iinfo( + torch.uint8).bits + + if input_size_per_partition * sum( + output_partition_sizes) % quant_ratio != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. 
") + qweight = Parameter( + torch.empty( + input_size_per_partition * sum(output_partition_sizes) // + quant_ratio, + 1, + dtype=torch.uint8, + ), + requires_grad=False, + ) + + set_weight_attrs( + qweight, + { + "input_dim": 0, + # In bitsandbytes, a tensor of shape [n,m] is quantized to + #[n*m/pack_ratio, 1],so the output_dim is 0 + "output_dim": 0, + "pack_factor": quant_ratio, + "use_bitsandbytes": True, + }) + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + # only load the bitsandbytes module when needed + from bitsandbytes import matmul_4bit + + original_type = x.dtype + bf_x = x.to(torch.bfloat16) + + qweight = layer.qweight + quant_states = qweight.bnb_quant_state + offsets = qweight.bnb_shard_offsets + + out_dim_0 = x.shape[0] + out_dim_1 = sum( + [quant_state[1].shape[0] for quant_state in quant_states.items()]) + out = torch.empty(out_dim_0, + out_dim_1, + dtype=torch.bfloat16, + device=x.device) + + current_index = 0 + for i in range(len(quant_states)): + output_size = quant_states[i].shape[0] + # It is more efficient to use out kwarg like + # matmul_4bit(..., out = ...). Infeasible now due to the bug + # https://github.com/TimDettmers/bitsandbytes/issues/1235. + # Need to change after the bug is fixed. + out[:, current_index:current_index + output_size] = matmul_4bit( + bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i]) + + current_index += output_size + + out = out.to(original_type) + + if bias is not None: + out += bias + + return out diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index e20da0e15fb93..9c2eaee2eda55 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1,13 +1,18 @@ # ruff: noqa: SIM117 import collections import copy +import fnmatch import glob +import json +import math import os from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Tuple, Type import huggingface_hub +import numpy as np import torch +from huggingface_hub import HfApi, hf_hub_download from torch import nn from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, @@ -28,6 +33,7 @@ from vllm.model_executor.model_loader.weight_utils import ( get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.models.vlm_base import VisionLanguageModelBase +from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -125,7 +131,7 @@ class DefaultModelLoader(BaseModelLoader): def _maybe_download_from_modelscope( self, model: str, revision: Optional[str]) -> Optional[str]: """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. 
-
+
         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
         if VLLM_USE_MODELSCOPE:
@@ -247,6 +253,7 @@ class DefaultModelLoader(BaseModelLoader):
                                               model,
                                               "fall_back_to_pt_during_load",
                                               True)), )
+
         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
@@ -539,6 +546,241 @@ class ShardedStateLoader(BaseModelLoader):
         )
 
 
+class BitsAndBytesModelLoader(BaseModelLoader):
+    """Model loader to load model weights with BitsAndBytes quantization."""
+
+    default_target_modules = [
+        "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
+        "o_proj"
+    ]
+
+    possible_config_file_names = ["adapter_config.json"]
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+
+        # we don't need to quantize the whole model, only the target modules
+        # that are specified in the adapter config file. If the adapter config
+        # file is not provided, we will quantize the default modules.
+        if (not load_config.model_loader_extra_config
+                or "qlora_adapter_name_or_path"
+                not in load_config.model_loader_extra_config):
+            self.target_modules = self.default_target_modules
+            return
+
+        qlora_adapter = load_config.model_loader_extra_config[
+            "qlora_adapter_name_or_path"]
+
+        config_file_path = self._get_config_file(qlora_adapter)
+
+        with open(config_file_path, "r") as f:
+            config = json.load(f)
+        self.target_modules = config["target_modules"]
+
+    def _get_config_file(self, qlora_adapter: str) -> str:
+        is_local = os.path.isdir(qlora_adapter)
+        config_file_path = None
+        if is_local:
+            for file in self.possible_config_file_names:
+                config_file_path = os.path.join(qlora_adapter, file)
+                if os.path.exists(config_file_path):
+                    break
+        else:
+            hf_api = HfApi()
+            repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
+            for file in self.possible_config_file_names:
+                if file in repo_files:
+                    config_file_path = hf_hub_download(repo_id=qlora_adapter,
+                                                       filename=file)
+                    break
+
+        if not config_file_path:
+            raise ValueError(
+                f"Cannot find adapter config file in {qlora_adapter}")
+
+        return config_file_path
+
+    def _get_weight_files(
+            self,
+            model_name_or_path: str,
+            allowed_patterns: List[str],
+            revision: Optional[str] = None) -> Tuple[List[str], str]:
+        """Retrieve weight files. Download the files if necessary.
+ + Return the weight files and the file pattern.""" + is_local = os.path.isdir(model_name_or_path) + + if is_local: + for pattern in allowed_patterns: + weight_files = glob.glob( + os.path.join(model_name_or_path, pattern)) + if weight_files: + return weight_files, pattern + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) + for pattern in allowed_patterns: + matching_files = fnmatch.filter(repo_files, pattern) + if matching_files: + hf_folder = download_weights_from_hf( + model_name_or_path, self.load_config.download_dir, + [pattern], revision) + return glob.glob(os.path.join(hf_folder, pattern)), pattern + + raise RuntimeError( + f"No model weights found in: `{model_name_or_path}`") + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]) -> Tuple[List[str], bool]: + """Prepare weight files for the model.""" + + allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] + + hf_weights_files, matched_pattern = self._get_weight_files( + model_name_or_path, allowed_patterns, revision) + + if matched_pattern != "*.safetensors": + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_weights_files, matched_pattern == "*.safetensors" + + def _get_quantized_weights_iterator( + self, model_name_or_path: str, revision: Optional[str] + ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, + Any]]: + """Get an iterator to the model weights with bitsandbytes quantization, + as well as the quantization state dictionary.""" + + # only load the bitsandbytes module when needed + try: + import bitsandbytes + if bitsandbytes.__version__ < "0.42.0": + raise ImportError("bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.42.0.") + from bitsandbytes.functional import quantize_4bit + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.42.0 via " + "`pip install bitsandbytes>=0.42.0` to use " + "bitsandbytes quantizer.") from err + + hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision) + + quant_state_dict = {} + if use_safetensors: + weight_iterator = safetensors_weights_iterator(hf_weights_files) + else: + weight_iterator = pt_weights_iterator(hf_weights_files) + + def generator(): + for weight_name, weight_tensor in weight_iterator: + if any(target_module in weight_name + for target_module in self.target_modules): + weight_name = weight_name.replace(".weight", ".qweight") + # bitsandbytes requires data in GPU + loaded_weight = weight_tensor.cuda().data + with set_default_torch_dtype(torch.float32): + processed_weight, quant_state = quantize_4bit( + loaded_weight, + compress_statistics=True, + quant_type="nf4") + + quant_state_dict[weight_name] = quant_state + else: + processed_weight = weight_tensor + + yield weight_name, processed_weight + + return generator(), quant_state_dict + + def _load_weights(self, model_config: ModelConfig, + model: nn.Module) -> None: + if not hasattr(model, 'load_weights'): + raise AttributeError( + "The required method 'load_weights' is not defined in class" + f" {type(self).__name__}.") + + if not hasattr(model, 'bitsandbytes_stacked_params_mapping'): + raise AttributeError( + f"Model {type(self).__name__} does not support BitsAndBytes " + "quantization yet.") + + logger.info("Loading weights with BitsAndBytes quantization. 
" + " May take a while ...") + + qweight_iterator, quant_state_dict = ( + self._get_quantized_weights_iterator(model_config.model, + model_config.revision)) + + model.load_weights(qweight_iterator) + + param_dict = dict(model.named_parameters()) + stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} + for quant_param_name in quant_state_dict: + non_stacked_param_name = quant_param_name + + shard_index = 0 + for shard_name, ( + weight_name, index + ) in model.bitsandbytes_stacked_params_mapping.items(): + if shard_name in quant_param_name: + shard_index = index + quant_param_name = quant_param_name.replace( + shard_name, weight_name) + break + + if quant_param_name not in param_dict: + raise ValueError( + f"Parameter {quant_param_name} not found in the model.") + + if quant_param_name not in stacked_quant_state_dict: + stacked_quant_state_dict[quant_param_name] = {} + + stacked_quant_state_dict[quant_param_name][shard_index] = ( + quant_state_dict[non_stacked_param_name]) + + # save quant_states and offsets as the attributes of the parameters + for param_name, param in param_dict.items(): + if param_name in stacked_quant_state_dict: + quant_states = stacked_quant_state_dict[param_name] + set_weight_attrs(param, {"bnb_quant_state": quant_states}) + + pack_ratio = getattr(param, "pack_factor", -1) + if pack_ratio == -1: + raise ValueError( + f"pack_factor not set for parameter {param_name}.") + + num_elements = [0] * len(quant_states) + for seq, quant_state in enumerate(quant_states.items()): + num_elements[seq] = math.prod( + quant_state[1].shape) // pack_ratio + + offsets = np.concatenate(([0], np.cumsum(num_elements))) + set_weight_attrs(param, {"bnb_shard_offsets": offsets}) + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config, + cache_config) + + self._load_weights(model_config, model) + + return model.eval() + + def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: """Get a model loader based on the load format.""" @@ -554,4 +796,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.SHARDED_STATE: return ShardedStateLoader(load_config) + if load_config.load_format == LoadFormat.BITSANDBYTES: + return BitsAndBytesModelLoader(load_config) + return DefaultModelLoader(load_config) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 53e21eba8fae3..6174f0a974712 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -130,7 +130,17 @@ def get_quant_config(model_config: ModelConfig, if hf_quant_config is not None: return quant_cls.from_config(hf_quant_config) - model_name_or_path = model_config.model + # In case of bitsandbytes/QLoRA, get quant config from the adapter model. 
+ if model_config.quantization == "bitsandbytes": + if (not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" + not in load_config.model_loader_extra_config): + return quant_cls.from_config({"adapter_name_or_path": ""}) + model_name_or_path = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path"] + + else: + model_name_or_path = model_config.model is_local = os.path.isdir(model_name_or_path) if not is_local: # Download the config files. @@ -169,6 +179,10 @@ def get_quant_config(model_config: ModelConfig, quant_config_file = quant_config_files[0] with open(quant_config_file, "r") as f: config = json.load(f) + + if model_config.quantization == "bitsandbytes": + config["adapter_name_or_path"] = model_name_or_path + return quant_cls.from_config(config) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2ca55f9270fc7..d83ee9a201c0b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -319,6 +319,14 @@ class LlamaForCausalLM(nn.Module): "lm_head": "output_embeddings", } embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } def __init__( self, From 8279078e218833b357f7c5076850e3688714d570 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 1 Jun 2024 15:40:25 -0700 Subject: [PATCH 05/18] [Bugfix] Remove deprecated @abstractproperty (#5174) --- vllm/core/evictor_v1.py | 5 +++-- vllm/core/evictor_v2.py | 5 +++-- vllm/lora/worker_manager.py | 5 +++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/core/evictor_v1.py b/vllm/core/evictor_v1.py index aa51dd6938872..5db5a08a5bb67 100644 --- a/vllm/core/evictor_v1.py +++ b/vllm/core/evictor_v1.py @@ -1,5 +1,5 @@ import enum -from abc import ABC, abstractmethod, abstractproperty +from abc import ABC, abstractmethod from typing import OrderedDict from vllm.block import PhysicalTokenBlock @@ -44,7 +44,8 @@ class Evictor(ABC): """ pass - @abstractproperty + @property + @abstractmethod def num_blocks(self) -> int: pass diff --git a/vllm/core/evictor_v2.py b/vllm/core/evictor_v2.py index 57759b29347f4..3dd12e2e25131 100644 --- a/vllm/core/evictor_v2.py +++ b/vllm/core/evictor_v2.py @@ -1,5 +1,5 @@ import enum -from abc import ABC, abstractmethod, abstractproperty +from abc import ABC, abstractmethod from typing import OrderedDict, Tuple @@ -46,7 +46,8 @@ class Evictor(ABC): """Remove a given block id from the cache.""" pass - @abstractproperty + @property + @abstractmethod def num_blocks(self) -> int: pass diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index d67ce67172e30..4657757bd484b 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod, abstractproperty +from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Dict, List, Literal, Optional, Set, Type, Union @@ -42,7 +42,8 @@ class AbstractWorkerLoRAManager(ABC): yield self._cached_dummy_lora = False - @abstractproperty + @property + @abstractmethod def is_enabled(self) -> bool: ... 
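
A quick aside on the pattern used in the commit above: `abc.abstractproperty`
has been deprecated since Python 3.3, and stacking `@property` on top of
`@abstractmethod` is the documented equivalent. A minimal, self-contained
sketch (the class names are illustrative, not the actual vLLM classes):

    from abc import ABC, abstractmethod

    class Evictor(ABC):

        @property
        @abstractmethod
        def num_blocks(self) -> int:
            """Concrete subclasses must expose this as a property."""
            ...

    class LRUEvictor(Evictor):

        def __init__(self) -> None:
            # Track evictable blocks; empty at construction.
            self.free_table: dict = {}

        @property
        def num_blocks(self) -> int:
            return len(self.free_table)

    # The stacked decorators behave exactly like the old abstractproperty:
    # Evictor() itself cannot be instantiated, while the subclass works.
    assert LRUEvictor().num_blocks == 0
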
From c2d6d2f960176491e0499656409f30b947ee8027 Mon Sep 17 00:00:00 2001 From: Daniil Arapov <59310708+Delviet@users.noreply.github.com> Date: Sun, 2 Jun 2024 01:53:52 +0300 Subject: [PATCH 06/18] [Bugfix]: Fix issues related to prefix caching example (#5177) (#5180) --- examples/offline_inference_with_prefix.py | 47 ++++++++++++++++++----- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/examples/offline_inference_with_prefix.py b/examples/offline_inference_with_prefix.py index 7ed0563f14e0e..166e98549b536 100644 --- a/examples/offline_inference_with_prefix.py +++ b/examples/offline_inference_with_prefix.py @@ -1,5 +1,8 @@ +from time import time + from vllm import LLM, SamplingParams +# Common prefix. prefix = ( "You are an expert school principal, skilled in effectively managing " "faculty and staff. Draft 10-15 questions for a potential first grade " @@ -18,36 +21,60 @@ prompts = [ "The capital of France is", "The future of AI is", ] + +generating_prompts = [prefix + prompt for prompt in prompts] + # Create a sampling params object. sampling_params = SamplingParams(temperature=0.0) # Create an LLM. -llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True) +regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4) -generating_prompts = [prefix + prompt for prompt in prompts] +prefix_cached_llm = LLM(model="facebook/opt-125m", + enable_prefix_caching=True, + gpu_memory_utilization=0.4) +print("Results without `enable_prefix_caching`") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. -outputs = llm.generate(generating_prompts, sampling_params) +start_time_regular = time() +outputs = regular_llm.generate(generating_prompts, sampling_params) +duration_regular = time() - start_time_regular + +regular_generated_texts = [] # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text + regular_generated_texts.append(generated_text) print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print("-" * 80) # The llm.generate call will batch all prompts and send the batch at once -# if resources allow. The prefix will only be cached after the first batch -# is processed, so we need to call generate once to calculate the prefix -# and cache it. -outputs = llm.generate(generating_prompts[0], sampling_params) +# if resources allow. +start_time_cached = time() +outputs = prefix_cached_llm.generate(generating_prompts, sampling_params) +duration_cached = time() - start_time_cached -# Subsequent batches can leverage the cached prefix -outputs = llm.generate(generating_prompts, sampling_params) +print("Results with `enable_prefix_caching`") -# Print the outputs. You should see the same outputs as before +cached_generated_texts = [] +# Print the outputs. You should see the same outputs as before. 
for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text + cached_generated_texts.append(generated_text) print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +print("-" * 80) + +# Compare the results and display the speedup +generated_same = all([ + regular_generated_texts[i] == cached_generated_texts[i] + for i in range(len(prompts)) +]) +print(f"Generated answers are the same: {generated_same}") + +speedup = round(duration_regular / duration_cached, 2) +print(f"Speed up of cached generation compared to the regular is: {speedup}") From 044793d8df6aeb5326b5992d0e60aa4457760e8a Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Sat, 1 Jun 2024 19:35:41 -0400 Subject: [PATCH 07/18] [BugFix] Prevent `LLM.encode` for non-generation Models (#5184) Co-authored-by: mgoin --- vllm/entrypoints/llm.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 6e971ae73f5d0..beee16d188eb5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -276,6 +276,11 @@ class LLM: considered legacy and may be deprecated in the future. You should instead pass them via the ``inputs`` parameter. """ + if self.llm_engine.model_config.embedding_mode: + raise ValueError( + "LLM.generate() is only supported for generation models " + "(XForCausalLM).") + if prompt_token_ids is not None or multi_modal_data is not None: inputs = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), @@ -420,6 +425,11 @@ class LLM: considered legacy and may be deprecated in the future. You should instead pass them via the ``inputs`` parameter. """ + if not self.llm_engine.model_config.embedding_mode: + raise ValueError( + "LLM.encode() is only supported for embedding models (XModel)." + ) + if prompt_token_ids is not None or multi_modal_data is not None: inputs = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), From ed59a7ed23c6e91096ea82b03037e40b14b5375c Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Sat, 1 Jun 2024 21:21:53 -0500 Subject: [PATCH 08/18] Update test_ignore_eos (#4898) --- tests/samplers/test_ignore_eos.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/samplers/test_ignore_eos.py b/tests/samplers/test_ignore_eos.py index 864657a3c2b28..67b5168bea0e6 100644 --- a/tests/samplers/test_ignore_eos.py +++ b/tests/samplers/test_ignore_eos.py @@ -7,25 +7,26 @@ import pytest from vllm import SamplingParams -MODELS = ["facebook/opt-125m"] +# We also test with llama because it has generation_config to specify EOS +# (past regression). 
+MODELS = ["facebook/opt-125m", "meta-llama/Llama-2-7b-hf"] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [1024]) -def test_beam_search_single_input( +@pytest.mark.parametrize("max_tokens", [512]) +def test_ignore_eos( vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int, ) -> None: - example_prompts = "1 + 1 is" - vllm_model = vllm_runner(model, dtype=dtype) sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) - ignore_eos_output = vllm_model.model.generate( - example_prompts, sampling_params=sampling_params) - print(len(ignore_eos_output[0].outputs[0].token_ids)) - assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) < 10 - assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) >= 0 + + for prompt in example_prompts: + ignore_eos_output = vllm_model.model.generate( + prompt, sampling_params=sampling_params) + output_length = len(ignore_eos_output[0].outputs[0].token_ids) + assert output_length == max_tokens From f790ad3c50f050778af1fd31170746b7c68ca2fc Mon Sep 17 00:00:00 2001 From: Avinash Raj Date: Sun, 2 Jun 2024 13:36:13 +0530 Subject: [PATCH 09/18] [Frontend][OpenAI] Support for returning max_model_len on /v1/models response (#4643) --- vllm/entrypoints/openai/protocol.py | 1 + vllm/entrypoints/openai/serving_engine.py | 1 + 2 files changed, 2 insertions(+) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index e380212a4d76b..bbd61a2c5dd59 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -82,6 +82,7 @@ class ModelCard(OpenAIBaseModel): owned_by: str = "vllm" root: Optional[str] = None parent: Optional[str] = None + max_model_len: Optional[int] = None permission: List[ModelPermission] = Field(default_factory=list) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 066acdf1c019a..ae659d19c878b 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -62,6 +62,7 @@ class OpenAIServing: """Show available models. Right now we only have one model.""" model_cards = [ ModelCard(id=served_model_name, + max_model_len=self.max_model_len, root=self.served_model_names[0], permission=[ModelPermission()]) for served_model_name in self.served_model_names From a66cf40b205d57ac1b5dc96b6bb6f8e813b18316 Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Sun, 2 Jun 2024 16:13:26 -0500 Subject: [PATCH 10/18] [Kernel][ROCm][AMD] enable fused topk_softmax kernel for moe layer (#4927) This PR enables the fused topk_softmax kernel used in moe layer for HIP --- CMakeLists.txt | 8 ++-- Dockerfile.rocm | 1 + csrc/cuda_compat.h | 4 ++ csrc/moe/topk_softmax_kernels.cu | 27 +++++++---- setup.py | 2 +- .../layers/fused_moe/fused_moe.py | 46 ++++++++----------- 6 files changed, 45 insertions(+), 43 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f991af61d9bd..a197063f33601 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -311,6 +311,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") add_dependencies(default _C) + message(STATUS "Enabling moe extension.") + add_dependencies(default _moe_C) + # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and # there are supported target arches. 
@@ -320,8 +323,3 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") add_dependencies(default _punica_C) endif() endif() - -if(VLLM_GPU_LANG STREQUAL "CUDA") - message(STATUS "Enabling moe extension.") - add_dependencies(default _moe_C) -endif() diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 9bfe8446a519d..e30a2aaf30209 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -108,6 +108,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ && python3 setup.py install \ && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \ && cd .. diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index 5909e5eaf5e60..82e55613d915a 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -19,8 +19,12 @@ #ifndef USE_ROCM #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) #else #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor(var, lane_mask, width) #endif #ifndef USE_ROCM diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 8c65f40fe836a..6ba4fcdb3a3f2 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -19,15 +19,22 @@ #include #include #include +#include "../cuda_compat.h" -#include -#include +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) namespace vllm { namespace moe { -static constexpr int WARP_SIZE = 32; - /// Aligned array type template < typename T, @@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); + thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW)); } // From this point, thread max in all the threads have the max within the row. 
@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); + row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW); } // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables @@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); - int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); + float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); + int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); // We want lower indices to "win" in every thread so we break ties this way if (other_max > max_val || (other_max == max_val && other_expert < expert)) @@ -383,7 +390,7 @@ struct TopkConstants { static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); - static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; static constexpr int THREADS_PER_ROW = EXPERTS / VPT; static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; @@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f { static constexpr std::size_t MAX_BYTES_PER_LDG = 16; - static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; diff --git a/setup.py b/setup.py index d99fc050f6d84..f7d465b60c153 100644 --- a/setup.py +++ b/setup.py @@ -382,7 +382,7 @@ def get_requirements() -> List[str]: ext_modules = [] -if _is_cuda(): +if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) if not _is_neuron(): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bb7938b3715be..20a3c9f6f893f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -8,9 +8,9 @@ import torch import triton import triton.language as tl +import vllm._moe_C as moe_kernels from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.utils import is_hip logger = init_logger(__name__) @@ -319,34 +319,26 @@ def fused_topk( M, _ = hidden_states.shape - if is_hip(): - # The MoE kernels are not yet supported on ROCm. 
- routing_weights = torch.softmax(gating_output, - dim=-1, - dtype=torch.float32) - topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) - else: - import vllm._moe_C as moe_kernels - - topk_weights = torch.empty(M, - topk, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, + topk_weights = torch.empty(M, topk, - dtype=torch.int32, + dtype=torch.float32, device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - moe_kernels.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), # TODO(woosuk): Optimize this. - ) - del token_expert_indicies # Not used. Will be used in the future. + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + moe_kernels.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + ) + del token_expert_indicies # Not used. Will be used in the future. + if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids From dfbe60dc62409f03aa9eebc70ab2582ae64f0e1f Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 3 Jun 2024 07:05:50 +0800 Subject: [PATCH 11/18] [Misc] Simplify code and fix type annotations in `conftest.py` (#5118) --- tests/conftest.py | 92 ++++++++++++++++++++++------------------------- 1 file changed, 42 insertions(+), 50 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index af04cfbbb9902..d904058dc369c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple import pytest import torch +import torch.nn.functional as F from PIL import Image from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer, LlavaConfig, LlavaForConditionalGeneration) @@ -12,9 +13,9 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer, from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.distributed import destroy_model_parallel -from vllm.inputs import PromptInputs +from vllm.inputs import TextPrompt from vllm.logger import init_logger -from vllm.sequence import MultiModalData +from vllm.sequence import MultiModalData, SampleLogprobs logger = init_logger(__name__) @@ -188,10 +189,11 @@ class HfRunner: prompts: List[str], images: Optional[List[Image.Image]] = None, **kwargs, - ) -> List[Tuple[List[int], str]]: - outputs: List[Tuple[List[int], str]] = [] + ) -> List[Tuple[List[List[int]], List[str]]]: if images: assert len(prompts) == len(images) + + outputs: List[Tuple[List[List[int]], List[str]]] = [] for i, prompt in enumerate(prompts): processor_kwargs: Dict[str, Any] = { "text": prompt, @@ -201,17 +203,13 @@ class HfRunner: processor_kwargs["images"] = images[i] inputs = self.processor(**processor_kwargs) - inputs = { - key: value.cuda() if value is not None else None - for key, value in inputs.items() - } output_ids = self.model.generate( - **inputs, + **inputs.to("cuda"), use_cache=True, **kwargs, ) - output_str = self.tokenizer.batch_decode( + output_str = self.processor.batch_decode( output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False, @@ -224,23 +222,22 @@ class HfRunner: self, prompts: List[str], max_tokens: int, - images: Optional["torch.Tensor"] = None, + 
images: Optional[List[Image.Image]] = None, ) -> List[Tuple[List[int], str]]: outputs = self.generate(prompts, do_sample=False, max_new_tokens=max_tokens, images=images) - for i in range(len(outputs)): - output_ids, output_str = outputs[i] - outputs[i] = (output_ids[0], output_str[0]) - return outputs + + return [(output_ids[0], output_str[0]) + for output_ids, output_str in outputs] def generate_beam_search( self, prompts: List[str], beam_width: int, max_tokens: int, - ) -> List[Tuple[List[int], str]]: + ) -> List[Tuple[List[List[int]], List[str]]]: outputs = self.generate(prompts, do_sample=False, max_new_tokens=max_tokens, @@ -282,9 +279,7 @@ class HfRunner: if self.model.get_output_embeddings().bias is not None: logits += self.model.get_output_embeddings( ).bias.unsqueeze(0) - logprobs = torch.nn.functional.log_softmax(logits, - dim=-1, - dtype=torch.float32) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) seq_logprobs.append(logprobs) all_logprobs.append(seq_logprobs) return all_logprobs @@ -294,10 +289,10 @@ class HfRunner: prompts: List[str], max_tokens: int, num_logprobs: int, - ) -> List[Tuple[List[int], str]]: - all_logprobs = [] - all_output_ids = [] - all_output_strs = [] + ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + all_logprobs: List[List[Dict[int, float]]] = [] + all_output_ids: List[List[int]] = [] + all_output_strs: List[str] = [] for prompt in prompts: input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids @@ -310,7 +305,7 @@ class HfRunner: return_dict_in_generate=True, ) - seq_logprobs = [] + seq_logprobs: List[torch.Tensor] = [] for _, hidden_states in enumerate(output.hidden_states): last_hidden_states = hidden_states[-1][0] logits = torch.matmul( @@ -321,13 +316,11 @@ class HfRunner: None) is not None: logits += self.model.get_output_embeddings( ).bias.unsqueeze(0) - logprobs = torch.nn.functional.log_softmax(logits, - dim=-1, - dtype=torch.float32) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) seq_logprobs.append(logprobs) # convert to dict - seq_logprobs_lst = [] + seq_logprobs_lst: List[Dict[int, float]] = [] for tok_idx, tok_logprobs in enumerate(seq_logprobs): # drop prompt logprobs if tok_idx == 0: @@ -372,13 +365,13 @@ class VllmRunner: tokenizer_name: Optional[str] = None, # Use smaller max model length, otherwise bigger model cannot run due # to kv cache size limit. 
- max_model_len=1024, + max_model_len: int = 1024, dtype: str = "half", disable_log_stats: bool = True, tensor_parallel_size: int = 1, block_size: int = 16, enable_chunked_prefill: bool = False, - swap_space=4, + swap_space: int = 4, **kwargs, ) -> None: self.model = LLM( @@ -399,32 +392,31 @@ class VllmRunner: self, prompts: List[str], sampling_params: SamplingParams, - images: Optional["torch.Tensor"] = None, - ) -> List[Tuple[List[int], str]]: + images: Optional[torch.Tensor] = None, + ) -> List[Tuple[List[List[int]], List[str]]]: if images is not None: - assert len(prompts) == images.shape[0] + assert len(prompts) == len(images) - prompt_inputs: List[PromptInputs] = [] + prompt_inputs: List[TextPrompt] = [] for i, prompt in enumerate(prompts): - image = None if images is None else images[i:i + 1] - mm_data = None if image is None else MultiModalData( - type=MultiModalData.Type.IMAGE, - data=image, - ) + prompt = TextPrompt(prompt=prompt) + if images is not None: + prompt["multi_modal_data"] = MultiModalData( + type=MultiModalData.Type.IMAGE, + data=images[i:i + 1], + ) - prompt_inputs.append({ - "prompt": prompt, - "multi_modal_data": mm_data, - }) + prompt_inputs.append(prompt) req_outputs = self.model.generate(prompt_inputs, sampling_params=sampling_params) - outputs = [] + + outputs: List[Tuple[List[List[int]], List[str]]] = [] for req_output in req_outputs: prompt_str = req_output.prompt prompt_ids = req_output.prompt_token_ids - req_sample_output_ids = [] - req_sample_output_strs = [] + req_sample_output_ids: List[List[int]] = [] + req_sample_output_strs: List[str] = [] for sample in req_output.outputs: output_str = sample.text output_ids = sample.token_ids @@ -437,12 +429,12 @@ class VllmRunner: self, prompts: List[str], sampling_params: SamplingParams, - ) -> List[Tuple[List[int], str]]: + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: assert sampling_params.logprobs is not None req_outputs = self.model.generate(prompts, sampling_params=sampling_params) - outputs = [] + outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] for req_output in req_outputs: for sample in req_output.outputs: output_str = sample.text @@ -467,7 +459,7 @@ class VllmRunner: prompts: List[str], max_tokens: int, num_logprobs: int, - ) -> List[Tuple[List[int], str]]: + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: greedy_logprobs_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) @@ -481,7 +473,7 @@ class VllmRunner: prompts: List[str], beam_width: int, max_tokens: int, - ) -> List[Tuple[List[int], str]]: + ) -> List[Tuple[List[List[int]], List[str]]]: beam_search_params = SamplingParams(n=beam_width, use_beam_search=True, temperature=0.0, From 7a64d24aad69e4d2548aa0bf528d9fe63428ab01 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 3 Jun 2024 13:56:41 +0800 Subject: [PATCH 12/18] [Core] Support image processor (#4197) --- .github/workflows/mypy.yaml | 1 + docs/source/conf.py | 14 +- .../dev/multimodal/multimodal_index.rst | 51 ++++++ docs/source/index.rst | 6 +- docs/source/models/supported_models.rst | 4 + docs/source/models/vlm.rst | 56 +++++++ examples/llava_example.py | 29 ++-- format.sh | 1 + requirements-common.txt | 1 + requirements-dev.txt | 3 - tests/conftest.py | 45 ++--- tests/models/test_llava.py | 60 ++++--- tests/multimodal/__init__.py | 0 tests/multimodal/test_processor.py | 98 +++++++++++ tests/spec_decode/e2e/conftest.py | 3 +- tests/tokenization/test_image_processor.py | 20 +++ vllm/config.py | 6 +- 
vllm/engine/arg_utils.py | 108 ++++++++---- vllm/entrypoints/llm.py | 25 +-- vllm/model_executor/models/llava.py | 73 +++++--- vllm/multimodal/__init__.py | 7 + vllm/multimodal/base.py | 126 ++++++++++++++ vllm/multimodal/image.py | 141 ++++++++++++++++ vllm/multimodal/registry.py | 156 ++++++++++++++++++ vllm/sequence.py | 32 +--- vllm/transformers_utils/image_processor.py | 45 +++++ vllm/worker/cpu_model_runner.py | 57 ++++--- vllm/worker/embedding_model_runner.py | 10 +- vllm/worker/model_runner.py | 120 +++++++------- 29 files changed, 1042 insertions(+), 256 deletions(-) create mode 100644 docs/source/dev/multimodal/multimodal_index.rst create mode 100644 docs/source/models/vlm.rst create mode 100644 tests/multimodal/__init__.py create mode 100644 tests/multimodal/test_processor.py create mode 100644 tests/tokenization/test_image_processor.py create mode 100644 vllm/multimodal/__init__.py create mode 100644 vllm/multimodal/base.py create mode 100644 vllm/multimodal/image.py create mode 100644 vllm/multimodal/registry.py create mode 100644 vllm/transformers_utils/image_processor.py diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index a20753d8a7702..22e6c2ef0101e 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -37,6 +37,7 @@ jobs: mypy vllm/distributed --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml + mypy vllm/multimodal --config-file pyproject.toml mypy vllm/usage --config-file pyproject.toml mypy vllm/*.py --config-file pyproject.toml mypy vllm/transformers_utils --config-file pyproject.toml diff --git a/docs/source/conf.py b/docs/source/conf.py index cfebc2ff9bb33..f1a7013edd332 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -90,6 +90,7 @@ autodoc_mock_imports = [ "sentencepiece", "vllm.cuda_utils", "vllm._C", + "PIL", "numpy", "tqdm", "tensorizer", @@ -116,12 +117,13 @@ class MockedClassDocumenter(autodoc.ClassDocumenter): autodoc.ClassDocumenter = MockedClassDocumenter intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'typing_extensions': - ('https://typing-extensions.readthedocs.io/en/latest', None), - 'numpy': ('https://numpy.org/doc/stable', None), - 'torch': ('https://pytorch.org/docs/stable', None), - 'psutil': ('https://psutil.readthedocs.io/en/stable', None), + "python": ("https://docs.python.org/3", None), + "typing_extensions": + ("https://typing-extensions.readthedocs.io/en/latest", None), + "pillow": ("https://pillow.readthedocs.io/en/stable", None), + "numpy": ("https://numpy.org/doc/stable", None), + "torch": ("https://pytorch.org/docs/stable", None), + "psutil": ("https://psutil.readthedocs.io/en/stable", None), } autodoc_preserve_defaults = True diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst new file mode 100644 index 0000000000000..a25eceecc276b --- /dev/null +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -0,0 +1,51 @@ +Multi-Modality +============== + +.. currentmodule:: vllm.multimodal + +vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. + +:class:`vllm.inputs.PromptStrictInputs` accepts an additional attribute ``multi_modal_data`` +which allows you to pass in multi-modal input alongside text and token prompts. + +By default, vLLM models do not support multi-modal inputs. 
To enable multi-modal support for a model, +you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data `, +as well as :meth:`MULTIMODAL_REGISTRY.register_input ` for each modality type to support. + +.. contents:: + :local: + :backlinks: none + +Module Contents ++++++++++++++++ + +.. automodule:: vllm.multimodal + +Registry +-------- + +.. data:: vllm.multimodal.MULTIMODAL_REGISTRY + + The global :class:`MultiModalRegistry` which is used by model runners. + +.. autoclass:: vllm.multimodal.MultiModalRegistry + :members: + :show-inheritance: + +Base Classes +------------ + +.. autoclass:: vllm.multimodal.MultiModalData + :members: + :show-inheritance: + +.. autoclass:: vllm.multimodal.MultiModalPlugin + :members: + :show-inheritance: + +Image Classes +------------- + +.. automodule:: vllm.multimodal.image + :members: + :show-inheritance: diff --git a/docs/source/index.rst b/docs/source/index.rst index 5f18fe9ae0a73..fad3c3b05b0c0 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -88,6 +88,7 @@ Documentation models/adding_model models/engine_args models/lora + models/vlm models/performance .. toctree:: @@ -99,17 +100,18 @@ Documentation quantization/fp8_e4m3_kvcache .. toctree:: - :maxdepth: 2 + :maxdepth: 1 :caption: Developer Documentation dev/sampling_params dev/offline_inference/offline_index dev/engine/engine_index dev/kernel/paged_attention + dev/multimodal/multimodal_index dev/dockerfile/dockerfile .. toctree:: - :maxdepth: 2 + :maxdepth: 1 :caption: Community community/meetups diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 82e71e61975c8..24fa83df7d751 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -87,6 +87,10 @@ Alongside each architecture, we include some popular models that use it. - LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi - :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. - ✅︎ + * - :code:`LlavaForConditionalGeneration` + - LLaVA-1.5 + - :code:`llava-hf/llava-1.5-7b-hf`\*, :code:`llava-hf/llava-1.5-13b-hf`\*, etc. + - * - :code:`MiniCPMForCausalLM` - MiniCPM - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst new file mode 100644 index 0000000000000..52afda747aab8 --- /dev/null +++ b/docs/source/models/vlm.rst @@ -0,0 +1,56 @@ +.. _vlm: + +Using VLMs +========== + +This document shows you how to run and serve Vision Language Models (VLMs) using vLLM. + +Engine Arguments +---------------- + +The following :ref:`engine arguments ` are specific to VLMs: + +.. argparse:: + :module: vllm.engine.arg_utils + :func: _vlm_engine_args_parser + :prog: -m vllm.entrypoints.openai.api_server + :nodefaultconst: + +Offline Batched Inference +------------------------- + +To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` class for instantiating the engine. + +.. code-block:: python + + llm = LLM( + model="llava-hf/llava-1.5-7b-hf", + image_input_type="pixel_values", + image_token_id=32000, + image_input_shape="1,3,336,336", + image_feature_size=576, + ) + +For now, we only support a single image per text prompt. 
To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`: + +* ``prompt``: The prompt should have a number of ```` tokens equal to ``image_feature_size``. +* ``multi_modal_data``: This should be an instance of :class:`~vllm.multimodal.image.ImagePixelData` or :class:`~vllm.multimodal.image.ImageFeatureData`. + +.. code-block:: python + + prompt = "" * 576 + ( + "\nUSER: What is the content of this image?\nASSISTANT:") + + # Load the image using PIL.Image + image = ... + + outputs = llm.generate({ + "prompt": prompt, + "multi_modal_data": ImagePixelData(image), + }) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + +A code example can be found in `examples/llava_example.py `_. diff --git a/examples/llava_example.py b/examples/llava_example.py index 60250c4303fbf..980d7bf9f8a3c 100644 --- a/examples/llava_example.py +++ b/examples/llava_example.py @@ -3,33 +3,36 @@ import os import subprocess import torch +from PIL import Image from vllm import LLM -from vllm.sequence import MultiModalData +from vllm.multimodal.image import ImageFeatureData, ImagePixelData # The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`. +# You can use `.buildkite/download-images.sh` to download them -def run_llava_pixel_values(): +def run_llava_pixel_values(*, disable_image_processor: bool = False): llm = LLM( model="llava-hf/llava-1.5-7b-hf", image_input_type="pixel_values", image_token_id=32000, image_input_shape="1,3,336,336", image_feature_size=576, + disable_image_processor=disable_image_processor, ) prompt = "" * 576 + ( "\nUSER: What is the content of this image?\nASSISTANT:") - # This should be provided by another online or offline component. - image = torch.load("images/stop_sign_pixel_values.pt") + if disable_image_processor: + image = torch.load("images/stop_sign_pixel_values.pt") + else: + image = Image.open("images/stop_sign.jpg") outputs = llm.generate({ - "prompt": - prompt, - "multi_modal_data": - MultiModalData(type=MultiModalData.Type.IMAGE, data=image), + "prompt": prompt, + "multi_modal_data": ImagePixelData(image), }) for o in outputs: @@ -49,15 +52,13 @@ def run_llava_image_features(): prompt = "" * 576 + ( "\nUSER: What is the content of this image?\nASSISTANT:") - # This should be provided by another online or offline component. - image = torch.load("images/stop_sign_image_features.pt") + image: torch.Tensor = torch.load("images/stop_sign_image_features.pt") outputs = llm.generate({ - "prompt": - prompt, - "multi_modal_data": - MultiModalData(type=MultiModalData.Type.IMAGE, data=image), + "prompt": prompt, + "multi_modal_data": ImageFeatureData(image), }) + for o in outputs: generated_text = o.outputs[0].text print(generated_text) diff --git a/format.sh b/format.sh index d110855f8c273..ca828457f9999 100755 --- a/format.sh +++ b/format.sh @@ -101,6 +101,7 @@ mypy vllm/core --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml +mypy vllm/multimodal --config-file pyproject.toml mypy vllm/usage --config-file pyproject.toml mypy vllm/*.py --config-file pyproject.toml mypy vllm/transformers_utils --config-file pyproject.toml diff --git a/requirements-common.txt b/requirements-common.txt index 3ea22276f63f4..f41873570aa67 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -12,6 +12,7 @@ aiohttp openai uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. 
+pillow # Required for image processing prometheus_client >= 0.18.0 prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer diff --git a/requirements-dev.txt b/requirements-dev.txt index 2c6b33ea813a2..12b22a61ea162 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -33,8 +33,5 @@ sentence-transformers # required for embedding # Benchmarking aiohttp -# Multimodal -pillow - # quantization bitsandbytes==0.42.0 diff --git a/tests/conftest.py b/tests/conftest.py index d904058dc369c..e749338e1095a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,9 @@ from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.distributed import destroy_model_parallel from vllm.inputs import TextPrompt from vllm.logger import init_logger -from vllm.sequence import MultiModalData, SampleLogprobs +from vllm.multimodal import MultiModalData +from vllm.multimodal.image import ImageFeatureData, ImagePixelData +from vllm.sequence import SampleLogprobs logger = init_logger(__name__) @@ -24,6 +26,7 @@ _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] # Multi modal related +# You can use `.buildkite/download-images.sh` to download the assets _PIXEL_VALUES_FILES = [ os.path.join(_TEST_DIR, "images", filename) for filename in ["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"] @@ -89,17 +92,23 @@ def hf_images() -> List[Image.Image]: @pytest.fixture() -def vllm_images(request) -> "torch.Tensor": +def vllm_images(request) -> List[MultiModalData]: vision_language_config = request.getfixturevalue("model_and_config")[1] - all_images = [] if vision_language_config.image_input_type == ( VisionLanguageConfig.ImageInputType.IMAGE_FEATURES): - filenames = _IMAGE_FEATURES_FILES + return [ + ImageFeatureData(torch.load(filename)) + for filename in _IMAGE_FEATURES_FILES + ] else: - filenames = _PIXEL_VALUES_FILES - for filename in filenames: - all_images.append(torch.load(filename)) - return torch.concat(all_images, dim=0) + return [ + ImagePixelData(Image.open(filename)) for filename in _IMAGE_FILES + ] + + +@pytest.fixture() +def vllm_image_tensors(request) -> List[torch.Tensor]: + return [torch.load(filename) for filename in _PIXEL_VALUES_FILES] @pytest.fixture() @@ -392,23 +401,17 @@ class VllmRunner: self, prompts: List[str], sampling_params: SamplingParams, - images: Optional[torch.Tensor] = None, + images: Optional[List[MultiModalData]] = None, ) -> List[Tuple[List[List[int]], List[str]]]: if images is not None: assert len(prompts) == len(images) - prompt_inputs: List[TextPrompt] = [] - for i, prompt in enumerate(prompts): - prompt = TextPrompt(prompt=prompt) - if images is not None: - prompt["multi_modal_data"] = MultiModalData( - type=MultiModalData.Type.IMAGE, - data=images[i:i + 1], - ) + inputs = [TextPrompt(prompt=prompt) for prompt in prompts] + if images is not None: + for i, image in enumerate(images): + inputs[i]["multi_modal_data"] = image - prompt_inputs.append(prompt) - - req_outputs = self.model.generate(prompt_inputs, + req_outputs = self.model.generate(inputs, sampling_params=sampling_params) outputs: List[Tuple[List[List[int]], List[str]]] = [] @@ -447,7 +450,7 @@ class VllmRunner: self, prompts: List[str], max_tokens: int, - images: Optional[torch.Tensor] = None, + images: Optional[List[MultiModalData]] = None, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) outputs = 
self.generate(prompts, greedy_params, images=images) diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index f86cd3fa88f5d..cc0685ca9c5eb 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -1,7 +1,7 @@ import gc from dataclasses import fields from enum import Enum -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple import pytest import torch @@ -9,36 +9,50 @@ from transformers import AutoTokenizer from vllm.config import VisionLanguageConfig + +def iter_llava_configs(model_name: str): + image_hw_to_feature_size = { + (336, 336): 576, + } + + for (h, w), f in image_hw_to_feature_size.items(): + for input_type, input_shape in [ + (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)), + (VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, f, 1024)), + ]: + yield (model_name, + VisionLanguageConfig(image_input_type=input_type, + image_feature_size=f, + image_token_id=32000, + image_input_shape=input_shape, + image_processor=model_name, + image_processor_revision=None)) + + model_and_vl_config = [ - ("llava-hf/llava-1.5-7b-hf", - VisionLanguageConfig( - image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, - image_feature_size=576, - image_token_id=32000, - image_input_shape=(1, 3, 336, 336))), - ("llava-hf/llava-1.5-7b-hf", - VisionLanguageConfig( - image_input_type=VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, - image_feature_size=576, - image_token_id=32000, - image_input_shape=(1, 576, 1024))) + *iter_llava_configs("llava-hf/llava-1.5-7b-hf"), + # Not enough memory + # *iter_llava_configs("llava-hf/llava-1.5-13b-hf"), ] -def as_dict(vision_language_config: VisionLanguageConfig) -> Dict: +def as_dict(vlm_config: VisionLanguageConfig) -> Dict[str, Any]: """Flatten vision language config to pure args. Compatible with what llm entrypoint expects. """ result = {} - for field in fields(vision_language_config): - value = getattr(vision_language_config, field.name) + for field in fields(vlm_config): + value = getattr(vlm_config, field.name) if isinstance(value, Enum): result[field.name] = value.name.lower() elif isinstance(value, tuple): result[field.name] = ",".join([str(item) for item in value]) else: result[field.name] = value + + result["disable_image_processor"] = vlm_config.image_processor is None + return result @@ -67,18 +81,19 @@ def sanitize_vllm_output(vllm_output: Tuple[List[int], str], @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images, - vllm_image_prompts, vllm_images, model_and_config: tuple, - dtype: str, max_tokens: int, worker_use_ray: bool) -> None: + vllm_image_prompts, vllm_images, model_and_config, dtype: str, + max_tokens: int, worker_use_ray: bool) -> None: """Inference result should be the same between hf and vllm. All the image fixtures for the test is under tests/images. - For huggingface runner, we provide the raw images as input. - For vllm runner, we provide image tensors and corresponding + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalData objects and corresponding vision language config as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. 
""" model_id, vision_language_config = model_and_config + hf_model = hf_runner(model_id, dtype=dtype) hf_outputs = hf_model.generate_greedy(hf_image_prompts, max_tokens, @@ -88,6 +103,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images, vllm_model = vllm_runner(model_id, dtype=dtype, worker_use_ray=worker_use_ray, + enforce_eager=True, **as_dict(vision_language_config)) vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, max_tokens, @@ -105,3 +121,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images, f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") assert hf_output_ids == vllm_output_ids, ( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +# TODO: Add test for `tensor_parallel_size` [ref: PR #3883] +# (Requires multiple GPUs) diff --git a/tests/multimodal/__init__.py b/tests/multimodal/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py new file mode 100644 index 0000000000000..4aeae633d07f7 --- /dev/null +++ b/tests/multimodal/test_processor.py @@ -0,0 +1,98 @@ +import numpy as np +import pytest +from transformers import CLIPImageProcessor + +from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.image import ImagePixelData + + +@pytest.mark.parametrize("dtype", ["half", "bfloat16", "float"]) +def test_clip_image_processor(hf_images, dtype): + MODEL_NAME = "llava-hf/llava-1.5-7b-hf" + IMAGE_HEIGHT = IMAGE_WIDTH = 33 + + hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME) + assert isinstance(hf_processor, CLIPImageProcessor) + + model_config = ModelConfig( + model=MODEL_NAME, + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype=dtype, + revision=None, + ) + vlm_config = VisionLanguageConfig( + image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, + image_token_id=32000, + image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), + image_feature_size=576, + image_processor=MODEL_NAME, + image_processor_revision=None, + ) + + for image in hf_images: + hf_result = hf_processor.preprocess( + image, + return_tensors="np", + ) + vllm_result = MULTIMODAL_REGISTRY.process_input( + ImagePixelData(image), + model_config=model_config, + vlm_config=vlm_config, + ) + + assert hf_result.keys() == vllm_result.keys() + for key, hf_arr in hf_result.items(): + vllm_arr: np.ndarray = vllm_result[key].numpy() + + assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}" + assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" + + +@pytest.mark.parametrize("dtype", ["float"]) +def test_image_pixel_types(hf_images, vllm_image_tensors, dtype): + MODEL_NAME = "llava-hf/llava-1.5-7b-hf" + IMAGE_HEIGHT = IMAGE_WIDTH = 33 + + model_config = ModelConfig( + model=MODEL_NAME, + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype=dtype, + revision=None, + ) + vlm_config = VisionLanguageConfig( + image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, + image_token_id=32000, + image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), + image_feature_size=576, + image_processor=MODEL_NAME, + image_processor_revision=None, + ) + + for image, tensor in zip(hf_images, vllm_image_tensors): + image_result = MULTIMODAL_REGISTRY.process_input( + ImagePixelData(image), + model_config=model_config, + vlm_config=vlm_config, + ) + tensor_result = 
MULTIMODAL_REGISTRY.process_input( + ImagePixelData(tensor), + model_config=model_config, + vlm_config=vlm_config, + ) + + assert image_result.keys() == tensor_result.keys() + for key, image_arr in image_result.items(): + tensor_arr: np.ndarray = tensor_result[key].numpy() + + assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}" + + # The examples in PR#3042 have slightly different preprocessing from + # HuggingFace's LlavaProcessor, causing the test to fail. + # assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}" diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 7c5840baf3593..1d060e265848a 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -18,9 +18,10 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.lora.request import LoRARequest from vllm.model_executor.utils import set_random_seed +from vllm.multimodal import MultiModalData from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import Logprob, MultiModalData +from vllm.sequence import Logprob from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, random_uuid diff --git a/tests/tokenization/test_image_processor.py b/tests/tokenization/test_image_processor.py new file mode 100644 index 0000000000000..5ba2323367414 --- /dev/null +++ b/tests/tokenization/test_image_processor.py @@ -0,0 +1,20 @@ +import pytest +from transformers.image_processing_utils import BaseImageProcessor + +from vllm.transformers_utils.image_processor import get_image_processor + +IMAGE_PROCESSOR_NAMES = [ + "llava-hf/llava-1.5-7b-hf", + "llava-hf/llava-v1.6-34b-hf", +] + + +@pytest.mark.parametrize("processor_name", IMAGE_PROCESSOR_NAMES) +def test_image_processor_revision(processor_name: str): + # Assume that "main" branch always exists + image_processor = get_image_processor(processor_name, revision="main") + assert isinstance(image_processor, BaseImageProcessor) + + # Assume that "never" branch always does not exist + with pytest.raises(OSError, match='not a valid git identifier'): + get_image_processor(processor_name, revision="never") diff --git a/vllm/config.py b/vllm/config.py index ba4361ffb98b4..eee62d2683835 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1094,10 +1094,12 @@ class VisionLanguageConfig: # worst case scenario (biggest supported resolution). 
image_input_shape: tuple image_feature_size: int + # The image processor to load from HuggingFace + image_processor: Optional[str] + image_processor_revision: Optional[str] @classmethod - def get_image_input_enum_type( - cls, value: str) -> "VisionLanguageConfig.ImageInputType": + def get_image_input_enum_type(cls, value: str) -> ImageInputType: """Get the image input type from a string.""" try: return cls.ImageInputType[value.upper()] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8a73fc931a95a..b315d4d2ece29 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,6 +1,7 @@ import argparse import dataclasses import json +import warnings from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -80,6 +81,10 @@ class EngineArgs: image_token_id: Optional[int] = None image_input_shape: Optional[str] = None image_feature_size: Optional[int] = None + image_processor: Optional[str] = None + image_processor_revision: Optional[str] = None + disable_image_processor: bool = False + scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False @@ -98,6 +103,53 @@ class EngineArgs: if self.tokenizer is None: self.tokenizer = self.model + @staticmethod + def add_cli_args_for_vlm( + parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser.add_argument('--image-input-type', + type=nullable_str, + default=None, + choices=[ + t.name.lower() + for t in VisionLanguageConfig.ImageInputType + ], + help=('The image input type passed into vLLM.')) + parser.add_argument('--image-token-id', + type=int, + default=None, + help=('Input id for image token.')) + parser.add_argument( + '--image-input-shape', + type=nullable_str, + default=None, + help=('The biggest image input shape (worst for memory footprint) ' + 'given an input type. Only used for vLLM\'s profile_run.')) + parser.add_argument( + '--image-feature-size', + type=int, + default=None, + help=('The image feature size along the context dimension.')) + parser.add_argument( + '--image-processor', + type=str, + default=EngineArgs.image_processor, + help='Name or path of the huggingface image processor to use. ' + 'If unspecified, model name or path will be used.') + parser.add_argument( + '--image-processor-revision', + type=str, + default=None, + help='Revision of the huggingface image processor version to use. ' + 'It can be a branch name, a tag name, or a commit id. ' + 'If unspecified, will use the default version.') + parser.add_argument( + '--disable-image-processor', + action='store_true', + help='Disables the use of image processor, even if one is defined ' + 'for the model on huggingface.') + + return parser + @staticmethod def add_cli_args( parser: argparse.ArgumentParser) -> argparse.ArgumentParser: @@ -113,7 +165,8 @@ class EngineArgs: '--tokenizer', type=nullable_str, default=EngineArgs.tokenizer, - help='Name or path of the huggingface tokenizer to use.') + help='Name or path of the huggingface tokenizer to use. ' + 'If unspecified, model name or path will be used.') parser.add_argument( '--skip-tokenizer-init', action='store_true', @@ -136,9 +189,9 @@ class EngineArgs: '--tokenizer-revision', type=nullable_str, default=None, - help='The specific tokenizer version to use. It can be a branch ' - 'name, a tag name, or a commit id. If unspecified, will use ' - 'the default version.') + help='Revision of the huggingface tokenizer to use. ' + 'It can be a branch name, a tag name, or a commit id. 
' + 'If unspecified, will use the default version.') parser.add_argument( '--tokenizer-mode', type=str, @@ -445,31 +498,10 @@ class EngineArgs: default=EngineArgs.device, choices=["auto", "cuda", "neuron", "cpu"], help='Device type for vLLM execution.') + # Related to Vision-language models such as llava - parser.add_argument( - '--image-input-type', - type=nullable_str, - default=None, - choices=[ - t.name.lower() for t in VisionLanguageConfig.ImageInputType - ], - help=('The image input type passed into vLLM. ' - 'Should be one of "pixel_values" or "image_features".')) - parser.add_argument('--image-token-id', - type=int, - default=None, - help=('Input id for image token.')) - parser.add_argument( - '--image-input-shape', - type=nullable_str, - default=None, - help=('The biggest image input shape (worst for memory footprint) ' - 'given an input type. Only used for vLLM\'s profile_run.')) - parser.add_argument( - '--image-feature-size', - type=int, - default=None, - help=('The image feature size along the context dimension.')) + parser = EngineArgs.add_cli_args_for_vlm(parser) + parser.add_argument( '--scheduler-delay-factor', type=float, @@ -488,7 +520,6 @@ class EngineArgs: default=EngineArgs.speculative_model, help= 'The name of the draft model to be used in speculative decoding.') - parser.add_argument( '--num-speculative-tokens', type=int, @@ -666,12 +697,27 @@ class EngineArgs: raise ValueError( 'Specify `image_token_id`, `image_input_shape` and ' '`image_feature_size` together with `image_input_type`.') + + if self.image_processor is None: + self.image_processor = self.model + if self.disable_image_processor: + if self.image_processor != self.model: + warnings.warn( + "You've specified an image processor " + f"({self.image_processor}) but also disabled " + "it via `--disable-image-processor`.", + stacklevel=2) + + self.image_processor = None + vision_language_config = VisionLanguageConfig( image_input_type=VisionLanguageConfig. get_image_input_enum_type(self.image_input_type), image_token_id=self.image_token_id, image_input_shape=str_to_int_tuple(self.image_input_shape), image_feature_size=self.image_feature_size, + image_processor=self.image_processor, + image_processor_revision=self.image_processor_revision, ) else: vision_language_config = None @@ -734,3 +780,7 @@ def _engine_args_parser(): def _async_engine_args_parser(): return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(), async_args_only=True) + + +def _vlm_engine_args_parser(): + return EngineArgs.add_cli_args_for_vlm(argparse.ArgumentParser()) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index beee16d188eb5..d4a4c16f2a7d5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -14,7 +14,6 @@ from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, deprecate_kwargs @@ -164,7 +163,6 @@ class LLM: prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: ... @@ -177,7 +175,6 @@ class LLM: prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: ... 
@@ -191,7 +188,6 @@ class LLM: prompt_token_ids: List[int], use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: ... @@ -205,7 +201,6 @@ class LLM: prompt_token_ids: List[List[int]], use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: ... @@ -217,7 +212,6 @@ class LLM: prompt_token_ids: Union[List[int], List[List[int]]], use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: ... @@ -236,7 +230,6 @@ class LLM: @deprecate_kwargs("prompts", "prompt_token_ids", - "multi_modal_data", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, additional_message="Please use the 'inputs' parameter " "instead.") @@ -249,7 +242,6 @@ class LLM: prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -281,11 +273,10 @@ class LLM: "LLM.generate() is only supported for generation models " "(XForCausalLM).") - if prompt_token_ids is not None or multi_modal_data is not None: + if prompt_token_ids is not None: inputs = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, ) else: inputs = cast( @@ -314,7 +305,6 @@ class LLM: prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: ... @@ -327,7 +317,6 @@ class LLM: prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: ... @@ -341,7 +330,6 @@ class LLM: prompt_token_ids: List[int], use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: ... @@ -355,7 +343,6 @@ class LLM: prompt_token_ids: List[List[int]], use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: ... @@ -367,7 +354,6 @@ class LLM: prompt_token_ids: Union[List[int], List[List[int]]], use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: ... @@ -386,7 +372,6 @@ class LLM: @deprecate_kwargs("prompts", "prompt_token_ids", - "multi_modal_data", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, additional_message="Please use the 'inputs' parameter " "instead.") @@ -399,7 +384,6 @@ class LLM: prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: """Generates the completions for the input prompts. @@ -430,11 +414,10 @@ class LLM: "LLM.encode() is only supported for embedding models (XModel)." 
) - if prompt_token_ids is not None or multi_modal_data is not None: + if prompt_token_ids is not None: inputs = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, ) else: inputs = cast( @@ -459,7 +442,6 @@ class LLM: self, prompts: Optional[Union[str, List[str]]], prompt_token_ids: Optional[Union[List[int], List[List[int]]]], - multi_modal_data: Optional[MultiModalData], ): # skip_tokenizer_init is now checked in engine @@ -499,9 +481,6 @@ class LLM: else: raise AssertionError - if multi_modal_data is not None: - item["multi_modal_data"] = multi_modal_data - inputs.append(item) return inputs diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index fbd7638097286..3332bcc578460 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -17,6 +17,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.image import get_dummy_image_data from vllm.sequence import SamplerOutput from .vlm_base import VisionLanguageModelBase @@ -82,6 +84,9 @@ class LlavaImageFeatureInputs(TypedDict): LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs] +@MULTIMODAL_REGISTRY.register_image_feature_input() +@MULTIMODAL_REGISTRY.register_image_pixel_input() +@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data) class LlavaForConditionalGeneration(VisionLanguageModelBase): def __init__(self, @@ -131,30 +136,41 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase): return data def _parse_and_validate_image_input( - self, data: object) -> Optional[LlavaImageInputs]: + self, **kwargs: object) -> Optional[LlavaImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_features = kwargs.pop("image_features", None) + expected_input_type = self.vision_language_config.image_input_type ImageInputType = VisionLanguageConfig.ImageInputType - if data is None: - return None - if expected_input_type == ImageInputType.PIXEL_VALUES: - if not isinstance(data, torch.Tensor): - raise TypeError("Image pixel vector should be a tensor, " - f"but received type: {type(data)}") + if image_features is not None: + raise ValueError( + "Expected pixel values but got image features") + if pixel_values is None: + return None + + if not isinstance(pixel_values, torch.Tensor): + raise ValueError("Incorrect type of pixel values") return LlavaImagePixelInputs( type="pixel_values", - data=self._validate_image_data(data), + data=self._validate_image_data(pixel_values), ) - elif expected_input_type == ImageInputType.IMAGE_FEATURES: - if not isinstance(data, torch.Tensor): - raise TypeError("Image feature vector should be a tensor, " - f"but received type: {type(data)}") + + if expected_input_type == ImageInputType.IMAGE_FEATURES: + if pixel_values is not None: + raise ValueError( + "Expected image features but got pixel values") + if image_features is None: + return None + + if not isinstance(image_features, torch.Tensor): + raise ValueError("Incorrect type of image features") return LlavaImageFeatureInputs( type="image_features", - data=self._validate_image_data(data), + data=self._validate_image_data(image_features), ) return None @@ 
-201,12 +217,14 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase): return self.multi_modal_projector(image_features) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - image_input: Optional[torch.Tensor] = None) -> SamplerOutput: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs: object, + ) -> SamplerOutput: """Run forward pass for Llava 1.5. One key thing to understand is the `input_ids` already accounts for the @@ -227,10 +245,10 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase): This way, the `positions` and `attn_metadata` are consistent with the `input_ids`. - The model takes two types of image inputs: + The model takes two types of image inputs: PIXEL_VALUES and IMAGE_FEATURES. The following shows how each maps to huggingface implementation. - PIXEL_VALUES: + PIXEL_VALUES: - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353 IMAGE_FEATURES: - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430 @@ -239,14 +257,15 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase): Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - image_input: A batch of image inputs. - For PIXEL_VALUES, expecting [1, 3, 336, 336]. - For IMAGE_FEATURES, expecting [1, 576, 1024]. + pixel_values: For PIXEL_VALUES, expects a batch with shape + [1, 3, 336, 336]. + image_features: For IMAGE_FEATURES, expects a batch with shape + [1, 576, 1024]. """ - parsed_image_input = self._parse_and_validate_image_input(image_input) + image_input = self._parse_and_validate_image_input(**kwargs) - if parsed_image_input is not None: - vision_embeddings = self._process_image_input(parsed_image_input) + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = _merge_vision_embeddings( diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py new file mode 100644 index 0000000000000..270012e7d1c3b --- /dev/null +++ b/vllm/multimodal/__init__.py @@ -0,0 +1,7 @@ +from .base import MultiModalData, MultiModalPlugin +from .registry import MULTIMODAL_REGISTRY, MultiModalRegistry + +__all__ = [ + "MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY", + "MultiModalRegistry" +] diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py new file mode 100644 index 0000000000000..847752449ba80 --- /dev/null +++ b/vllm/multimodal/base.py @@ -0,0 +1,126 @@ +from abc import ABC, abstractmethod +from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type, + TypeVar) + +from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.logger import init_logger + +if TYPE_CHECKING: + import torch + from torch import nn + +logger = init_logger(__name__) + + +class MultiModalData: + """ + Base class that contains multi-modal data. + + To add a new modality, add a new file under ``multimodal`` directory. + + In this new file, subclass :class:`~MultiModalData` and + :class:`~MultiModalPlugin`. + + Finally, register the new plugin to + :const:`vllm.multimodal.MULTIMODAL_REGISTRY`. + This enables models to call :meth:`MultiModalRegistry.register_input` for + the new modality. 
+ """ + pass + + +D = TypeVar("D", bound=MultiModalData) +N = TypeVar("N", bound=Type["nn.Module"]) + +MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig], + Dict[str, "torch.Tensor"]] +"""Return a dictionary to be passed as keyword arguments to +:meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers +and processors in HuggingFace Transformers.""" + + +class MultiModalPlugin(ABC, Generic[D]): + """ + Base class that defines data processing logic for a specific modality. + + In particular, we adopt a registry pattern to dispatch data processing + according to the model being used (considering that different models may + process the same data differently). This registry is in turn used by + :class:`~MultiModalRegistry` which acts at a higher level + (i.e., the modality of the data). + """ + + @classmethod + def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]: + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + return get_model_architecture(model_config)[0] + + def __init__(self) -> None: + self._input_processors: Dict[Type["nn.Module"], + MultiModalInputProcessor[D]] = {} + + @abstractmethod + def get_data_type(self) -> Type[D]: + """ + Get the modality (subclass of :class:`~MultiModalData`) served by + this plugin. + """ + raise NotImplementedError + + @abstractmethod + def _default_input_processor( + self, data: D, model_config: ModelConfig, + vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]: + """Return a dictionary to be passed as keyword arguments to + :meth:`torch.nn.Module.forward`. This is similar in concept to + tokenizers and processors in HuggingFace Transformers. + """ + raise NotImplementedError + + def register_input_processor(self, + processor: Optional[ + MultiModalInputProcessor[D]] = None): + """ + Register an input processor to a model class. + + When the model receives input data that matches the modality served by + this plugin (see :meth:`get_data_type`), the provided input processor is + applied to preprocess the data. If `None` is provided, then the default + input processor is applied instead. + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._input_processors: + logger.warning( + "Model class %s already has an input processor " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._input_processors[model_cls] = processor \ + or self._default_input_processor + + return model_cls + + return wrapper + + def process_input( + self, data: D, model_config: ModelConfig, + vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]: + """ + Apply an input processor to a :class:`~MultiModalData` instance passed + to the model. + + The model is identified by ``model_config``. ``vlm_config`` is + for compatibility purposes and may be merged into ``model_config`` + in the near future. 
+ """ + model_cls = self.get_model_cls(model_config) + + processor = self._input_processors.get(model_cls) + if processor is None: + raise KeyError(f"No input processor in {self} is registered for " + f"model class {model_cls.__name__}.") + + return processor(data, model_config, vlm_config) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py new file mode 100644 index 0000000000000..b964e9ee42624 --- /dev/null +++ b/vllm/multimodal/image.py @@ -0,0 +1,141 @@ +from typing import Dict, Tuple, Type, Union + +import torch +from PIL import Image + +from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.logger import init_logger +from vllm.sequence import SequenceData +from vllm.transformers_utils.image_processor import cached_get_image_processor + +from .base import MultiModalData, MultiModalPlugin + +logger = init_logger(__name__) + + +def _get_dummy_seq_data(seq_len: int, + vlm_config: VisionLanguageConfig) -> SequenceData: + # NOTE: We assume that token is repeated `image_feature_size` times + # and then concatenated with the text prompt + # TODO: Enable other ways of inserting the image into the prompt + + token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size + token_ids += [0] * (seq_len - vlm_config.image_feature_size) + + return SequenceData(token_ids) + + +def _get_dummy_values(vlm_config: VisionLanguageConfig) -> torch.Tensor: + if vlm_config.image_processor is None: + values_dtype = torch.float16 + else: + values_dtype = torch.uint8 + + return torch.zeros(vlm_config.image_input_shape, dtype=values_dtype) + + +def get_dummy_image_data( + seq_len: int, + model_config: ModelConfig, + vlm_config: VisionLanguageConfig, +) -> Tuple[SequenceData, MultiModalData]: + """Standard dummy data factory for image data (to be used in + :meth:`vlm.multimodal.MultiModalRegistry.register_dummy_data`).""" + seq_data = _get_dummy_seq_data(seq_len, vlm_config) + values = _get_dummy_values(vlm_config) + + config_input_type = vlm_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + + fake_mm_data: MultiModalData + if config_input_type == ImageInputType.PIXEL_VALUES: + fake_mm_data = ImagePixelData(values) + elif config_input_type == ImageInputType.IMAGE_FEATURES: + fake_mm_data = ImageFeatureData(values) + else: + raise NotImplementedError + + return seq_data, fake_mm_data + + +class ImagePixelData(MultiModalData): + """ + The pixel data of an image. Can be one of: + + - :class:``PIL.Image``: An image object. Requires that a HuggingFace + processor is available to the model. + - :class:``torch.Tensor``: The raw pixel data which is passed to the model + without additional pre-processing. 
+ """ + + def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None: + if isinstance(image, Image.Image): + # So that this class can be created inside the Image context manager + image.load() + + self.image = image + + +class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]): + + def get_data_type(self) -> Type[ImagePixelData]: + return ImagePixelData + + def _get_hf_image_processor(self, model_config: ModelConfig, + vlm_config: VisionLanguageConfig): + if vlm_config is None or vlm_config.image_processor is None: + return None + + return cached_get_image_processor( + vlm_config.image_processor, + trust_remote_code=model_config.trust_remote_code, + revision=vlm_config.image_processor_revision, + ) + + def _default_input_processor( + self, data: ImagePixelData, model_config: ModelConfig, + vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]: + image = data.image + image_processor = self._get_hf_image_processor(model_config, + vlm_config) + + if isinstance(image, Image.Image): + if image_processor is None: + raise RuntimeError("No HuggingFace processor is available" + "to process the image object") + try: + return image_processor.preprocess(image, return_tensors="pt") \ + .to(model_config.dtype).data + except Exception: + logger.error("Failed to process image (%s)", image) + raise + elif isinstance(image, torch.Tensor): + pixel_values = image.to(model_config.dtype) + + return {"pixel_values": pixel_values} + + raise TypeError(f"Invalid image type: {type(image)}") + + +class ImageFeatureData(MultiModalData): + """ + The feature vector of an image, passed directly to the model. + + This should be the output of the vision tower. + """ + + def __init__(self, image_features: torch.Tensor) -> None: + self.image_features = image_features + + +class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]): + + def get_data_type(self) -> Type[ImageFeatureData]: + return ImageFeatureData + + def _default_input_processor( + self, data: ImageFeatureData, model_config: ModelConfig, + vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]: + image_features = data.image_features.to(model_config.dtype) + + return {"image_features": image_features} diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py new file mode 100644 index 0000000000000..4789ce5ce4cfe --- /dev/null +++ b/vllm/multimodal/registry.py @@ -0,0 +1,156 @@ +import functools +from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, + Tuple, Type, TypeVar) + +from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.logger import init_logger + +from .base import MultiModalData, MultiModalPlugin +from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData, + ImagePixelPlugin) + +if TYPE_CHECKING: + import torch + from torch import nn + + from vllm.sequence import SequenceData + +logger = init_logger(__name__) + +D = TypeVar("D", bound=MultiModalData) +N = TypeVar("N", bound=Type["nn.Module"]) + +MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig], + Dict[str, "torch.Tensor"]] +MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig], + Tuple["SequenceData", MultiModalData]] + + +class MultiModalRegistry: + """ + This registry is used by model runners to dispatch data processing + according to its modality and the target model. 
+ """ + + DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin()) + + def __init__(self, + *, + plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS + ) -> None: + self._plugins_by_data_type = {p.get_data_type(): p for p in plugins} + self._dummy_factories_by_model_type: Dict[Type["nn.Module"], + MultiModalDummyFactory] = {} + + def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None: + data_type = plugin.get_data_type() + + if data_type in self._plugins_by_data_type: + logger.warning( + "A plugin is already registered for data type %s, " + "and will be overwritten by the new plugin %s.", data_type, + plugin) + + self._plugins_by_data_type[data_type] = plugin + + def _get_plugin_for_data_type(self, data_type: Type[MultiModalData]): + for typ in data_type.mro(): + plugin = self._plugins_by_data_type.get(typ) + if plugin is not None: + return plugin + + msg = f"Unknown multi-modal data type: {data_type}" + raise NotImplementedError(msg) + + def register_dummy_data(self, factory: MultiModalDummyFactory): + """ + Register a dummy data factory to a model class. + + During memory profiling, the provided function is invoked to create + dummy data to be inputted into the model. The modality and shape of + the dummy data should be an upper bound of what the model would receive + at inference time. + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._dummy_factories_by_model_type: + logger.warning( + "Model class %s already has dummy data " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._dummy_factories_by_model_type[model_cls] = factory + + return model_cls + + return wrapper + + def dummy_data_for_profiling(self, seq_len: int, model_config: ModelConfig, + vlm_config: VisionLanguageConfig): + """Create dummy data for memory profiling.""" + model_cls = MultiModalPlugin.get_model_cls(model_config) + dummy_factory = self._dummy_factories_by_model_type.get(model_cls) + if dummy_factory is None: + msg = f"No dummy data defined for model class: {model_cls}" + raise NotImplementedError(msg) + + return dummy_factory(seq_len, model_config, vlm_config) + + def register_input( + self, + data_type: Type[D], + processor: Optional[MultiModalInputProcessor[D]] = None): + """ + Register an input processor for a specific modality to a model class. + + See :meth:`MultiModalPlugin.register_input_processor` for more details. + """ + return self._get_plugin_for_data_type(data_type) \ + .register_input_processor(processor) + + def register_image_pixel_input( + self, + processor: Optional[ + MultiModalInputProcessor[ImagePixelData]] = None): + """ + Register an input processor for image pixel data to a model class. + + See :meth:`MultiModalPlugin.register_input_processor` for more details. + """ + return self.register_input(ImagePixelData, processor) + + def register_image_feature_input( + self, + processor: Optional[ + MultiModalInputProcessor[ImageFeatureData]] = None): + """ + Register an input processor for image feature data to a model class. + + See :meth:`MultiModalPlugin.register_input_processor` for more details. + """ + return self.register_input(ImageFeatureData, processor) + + def process_input(self, data: MultiModalData, model_config: ModelConfig, + vlm_config: VisionLanguageConfig): + """ + Apply an input processor to a :class:`~MultiModalData` instance passed + to the model. + + See :meth:`MultiModalPlugin.process_input` for more details. 
+ """ + return self._get_plugin_for_data_type(type(data)) \ + .process_input(data, model_config, vlm_config) + + def create_input_processor(self, model_config: ModelConfig, + vlm_config: VisionLanguageConfig): + """ + Create an input processor (see :meth:`process_input`) for a + specific model. + """ + return functools.partial(self.process_input, + model_config=model_config, + vlm_config=vlm_config) + + +MULTIMODAL_REGISTRY = MultiModalRegistry() +"""The global :class:`~MultiModalRegistry` which is used by model runners.""" diff --git a/vllm/sequence.py b/vllm/sequence.py index ac5c234d052bd..2f27bf33b166e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,6 +5,8 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +import torch + from vllm.block import LogicalTokenBlock from vllm.inputs import LLMInputs from vllm.lora.request import LoRARequest @@ -12,8 +14,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams if TYPE_CHECKING: - import torch - + from vllm.multimodal import MultiModalData from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -398,25 +399,6 @@ class SequenceGroupState: generator: Optional = None # type: ignore -class MultiModalData: - """Multi modal request. - - Args: - type: The data type. - data: The actual data. - The required shape and semantic meaning of it depends on the vision - language config of the hosted model. - See `VisionLanguageConfig` in `config.py`. - """ - - class Type(enum.Enum): - IMAGE = enum.auto() - - def __init__(self, type: Type, data: "torch.Tensor"): - self.type = type - self.data = data - - class SequenceGroup: """A group of sequences that are generated from the same prompt. @@ -473,7 +455,7 @@ class SequenceGroup: return next(iter(self.seqs_dict.values())).prompt_token_ids @property - def multi_modal_data(self) -> Optional[MultiModalData]: + def multi_modal_data(self) -> Optional["MultiModalData"]: # All sequences in the group should have the same multi-modal data. # We use the multi-modal data of an arbitrary sequence. return next(iter(self.seqs_dict.values())).multi_modal_data @@ -655,7 +637,7 @@ class SequenceGroupMetadata: lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, - multi_modal_data: Optional[MultiModalData] = None, + multi_modal_data: Optional["MultiModalData"] = None, encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, ) -> None: @@ -798,13 +780,13 @@ class SamplerOutput: outputs: List[CompletionSequenceGroupOutput] # On-device tensor containing probabilities of each token. - sampled_token_probs: Optional["torch.Tensor"] = None + sampled_token_probs: Optional[torch.Tensor] = None # On-device tensor containing the logprobs of each token. logprobs: Optional["torch.Tensor"] = None # On-device tensor containing the sampled token ids. - sampled_token_ids: Optional["torch.Tensor"] = None + sampled_token_ids: Optional[torch.Tensor] = None # Spec decode metrics populated by workers. 
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py new file mode 100644 index 0000000000000..3239b1d0cfa2f --- /dev/null +++ b/vllm/transformers_utils/image_processor.py @@ -0,0 +1,45 @@ +from functools import lru_cache +from typing import Optional + +from transformers import AutoImageProcessor +from transformers.image_processing_utils import BaseImageProcessor + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def get_image_processor( + processor_name: str, + *args, + trust_remote_code: bool = False, + revision: Optional[str] = None, + **kwargs, +) -> BaseImageProcessor: + """Gets an image processor for the given model name via HuggingFace.""" + try: + processor: BaseImageProcessor = AutoImageProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the image processor. If the image processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + return processor + + +cached_get_image_processor = lru_cache(get_image_processor) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index bc88f2c5bed6c..eaf43247d4fc5 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,4 +1,5 @@ -from typing import List, Optional, Tuple +from collections import defaultdict +from typing import Dict, List, Optional, Tuple import torch from torch import nn @@ -11,6 +12,7 @@ from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad @@ -63,6 +65,16 @@ class CPUModelRunner: self.block_size, ) + # Create processor for multi-modal data + if self.vision_language_config is not None: + self.multi_modal_input_processor = MULTIMODAL_REGISTRY \ + .create_input_processor( + self.model_config, + self.vision_language_config, + ) + else: + self.multi_modal_input_processor = None + # Lazy initialization. 
self.model: nn.Module # Set after init_Model @@ -80,14 +92,15 @@ class CPUModelRunner: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], Dict[ + str, torch.Tensor]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] - multi_modal_input_list: List[torch.Tensor] = [] + multi_modal_kwargs_list: Dict[str, + List[torch.Tensor]] = defaultdict(list) for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt @@ -108,9 +121,17 @@ class CPUModelRunner: # is always the first token in the sequence. input_positions.extend(list(range(computed_len, seq_len))) - if seq_group_metadata.multi_modal_data: - multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data) + mm_data = seq_group_metadata.multi_modal_data + if mm_data is not None: + # Process multi-modal data + if self.multi_modal_input_processor is None: + raise ValueError( + "Multi-modal inputs are only supported by " + "vision language models.") + + mm_kwargs = self.multi_modal_input_processor(mm_data) + for k, v in mm_kwargs.items(): + multi_modal_kwargs_list[k].append(v) # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] @@ -134,14 +155,10 @@ class CPUModelRunner: slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - if multi_modal_input_list: - assert self.vision_language_config, ( - "Multi-modal inputs are only supported by " - "vision language models.") - multi_modal_input = torch.cat(multi_modal_input_list, - dim=0).to(self.device) - else: - multi_modal_input = None + multi_modal_kwargs = { + k: torch.cat(v, dim=0).to(self.device) + for k, v in multi_modal_kwargs_list.items() + } num_prompt_tokens = len(input_tokens) @@ -167,7 +184,7 @@ class CPUModelRunner: slot_mapping=slot_mapping, ) return (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_input) + multi_modal_kwargs) def _prepare_decode( self, @@ -257,8 +274,8 @@ class CPUModelRunner: self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Optional[torch.Tensor]]: - multi_modal_input = None + Optional[Dict[str, torch.Tensor]]]: + multi_modal_kwargs = None if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. @@ -266,7 +283,7 @@ class CPUModelRunner: # Prepare input tensors. 
if is_prompt: (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_input + multi_modal_kwargs ) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, @@ -307,7 +324,7 @@ class CPUModelRunner: ) return (input_tokens, input_positions, attn_metadata, - sampling_metadata, multi_modal_input) + sampling_metadata, multi_modal_kwargs) @torch.inference_mode() def execute_model( diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 0ba1200696cab..465130d10e2f9 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -90,7 +90,7 @@ class EmbeddingModelRunner(ModelRunner): self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, - Set[LoRARequest], LoRAMapping, torch.Tensor]: + Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: if self.is_driver_worker: assert seq_group_metadata_list is not None # Prepare input tensors. @@ -102,7 +102,7 @@ class EmbeddingModelRunner(ModelRunner): _, lora_mapping, lora_requests, - multi_modal_input, + multi_modal_kwargs, slot_mapping, num_prefill_tokens, num_decode_tokens, @@ -117,7 +117,7 @@ class EmbeddingModelRunner(ModelRunner): "input_positions": input_positions, "lora_requests": lora_requests, "lora_mapping": lora_mapping, - "multi_modal_input": multi_modal_input, + "multi_modal_kwargs": multi_modal_kwargs, "num_prefill_tokens": num_prefill_tokens, "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, @@ -132,7 +132,7 @@ class EmbeddingModelRunner(ModelRunner): input_positions = metadata_dict.pop("input_positions") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") - multi_modal_input = metadata_dict.pop("multi_modal_input") + multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") if metadata_dict: attn_metadata = self.attn_backend.make_metadata( **metadata_dict) @@ -143,7 +143,7 @@ class EmbeddingModelRunner(ModelRunner): prompt_lens=None) return (input_tokens, input_positions, attn_metadata, pooling_metadata, - lora_requests, lora_mapping, multi_modal_input) + lora_requests, lora_mapping, multi_modal_kwargs) def _prepare_pooling( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 47aa70dc617af..63ec22d79694f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,5 +1,6 @@ import time import warnings +from collections import defaultdict from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union import numpy as np @@ -18,9 +19,9 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams -from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, - SequenceGroupMetadata) +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) @@ -44,7 +45,7 @@ class ModelInput(NamedTuple): query_lens: List[int] lora_mapping: Optional[LoRAMapping] lora_requests: Set[LoRARequest] - multi_modal_input: Optional[torch.Tensor] + multi_modal_kwargs: Dict[str, torch.Tensor] slot_mapping: torch.Tensor num_prefill_tokens: int 
num_decode_tokens: int @@ -60,7 +61,7 @@ class ModelInput(NamedTuple): query_lens=[], lora_mapping=None, lora_requests=set(), - multi_modal_input=None, + multi_modal_kwargs={}, slot_mapping=torch.empty(0, device=device), num_prefill_tokens=0, num_decode_tokens=0, @@ -122,6 +123,16 @@ class ModelRunner: self.block_size, ) + # Create processor for multi-modal data + if self.vision_language_config is not None: + self.multi_modal_input_processor = MULTIMODAL_REGISTRY \ + .create_input_processor( + self.model_config, + self.vision_language_config, + ) + else: + self.multi_modal_input_processor = None + # Lazy initialization self.model: nn.Module # Set after load_model # Set if the backend is flashinfer. @@ -242,7 +253,8 @@ class ModelRunner: context_lens: List[int] = [] query_lens: List[int] = [] block_tables: List[List[int]] = [] - multi_modal_input_list: List[torch.Tensor] = [] + multi_modal_kwargs_list: Dict[str, + List[torch.Tensor]] = defaultdict(list) decode_only = True num_prefills = 0 num_prefill_tokens = 0 @@ -417,9 +429,17 @@ class ModelRunner: and seq_group_metadata.sampling_params.prompt_logprobs else 1)) - if seq_group_metadata.multi_modal_data: - multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data) + mm_data = seq_group_metadata.multi_modal_data + if mm_data is not None: + # Process multi-modal data + if self.multi_modal_input_processor is None: + raise ValueError( + "Multi-modal inputs are only supported by " + "vision language models.") + + mm_kwargs = self.multi_modal_input_processor(mm_data) + for k, v in mm_kwargs.items(): + multi_modal_kwargs_list[k].append(v) if _is_block_tables_empty(seq_group_metadata.block_tables): # During memory profiling, the block tables are not @@ -508,16 +528,6 @@ class ModelRunner: context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) - - if multi_modal_input_list: - assert self.vision_language_config, ( - "Multi-modal inputs are only supported by " - "vision language models.") - multi_modal_input = torch.cat(multi_modal_input_list, - dim=0).to(self.device) - else: - multi_modal_input = None - query_lens_tensor = torch.tensor(query_lens, dtype=torch.long, device=self.device) @@ -614,6 +624,11 @@ class ModelRunner: else: lora_mapping = None + multi_modal_kwargs = { + k: torch.cat(v, dim=0).to(self.device) + for k, v in multi_modal_kwargs_list.items() + } + return ModelInput( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, @@ -622,7 +637,7 @@ class ModelRunner: query_lens=query_lens, lora_mapping=lora_mapping, lora_requests=lora_requests, - multi_modal_input=multi_modal_input, + multi_modal_kwargs=multi_modal_kwargs, slot_mapping=slot_mapping_tensor, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, @@ -633,7 +648,7 @@ class ModelRunner: self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[LoRARequest], LoRAMapping, torch.Tensor]: + Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: if self.is_driver_worker: assert seq_group_metadata_list is not None # Prepare input tensors. 
@@ -645,7 +660,7 @@ class ModelRunner: query_lens, lora_mapping, lora_requests, - multi_modal_input, + multi_modal_kwargs, slot_mapping, num_prefill_tokens, num_decode_tokens, @@ -662,7 +677,7 @@ class ModelRunner: sampling_metadata.selected_token_indices, "lora_requests": lora_requests, "lora_mapping": lora_mapping, - "multi_modal_input": multi_modal_input, + "multi_modal_kwargs": multi_modal_kwargs, "num_prefill_tokens": num_prefill_tokens, "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, @@ -679,7 +694,7 @@ class ModelRunner: "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") - multi_modal_input = metadata_dict.pop("multi_modal_input") + multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") if metadata_dict: attn_metadata = self.attn_backend.make_metadata( **metadata_dict) @@ -694,7 +709,7 @@ class ModelRunner: return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, - multi_modal_input) + multi_modal_kwargs) @torch.inference_mode() def execute_model( @@ -703,7 +718,7 @@ class ModelRunner: kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata, - lora_requests, lora_mapping, multi_modal_input + lora_requests, lora_mapping, multi_modal_kwargs ) = self.prepare_input_tensors(seq_group_metadata_list) if self.lora_config: @@ -717,15 +732,14 @@ class ModelRunner: model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model - execute_model_kwargs = { - "input_ids": input_tokens, - "positions": input_positions, - "kv_caches": kv_caches, - "attn_metadata": attn_metadata, - } - if self.vision_language_config: - execute_model_kwargs.update({"image_input": multi_modal_input}) - hidden_states = model_executable(**execute_model_kwargs) + + hidden_states = model_executable( + input_ids=input_tokens, + positions=input_positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + **multi_modal_kwargs, + ) # Compute the logits. logits = self.model.compute_logits(hidden_states, sampling_metadata) @@ -781,16 +795,24 @@ class ModelRunner: # To exercise the worst scenario for GPU memory consumption, # the number of seqs (batch_size) is chosen to maximize the number # of images processed. 
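# For a concrete sense of the numbers (illustrative values only): with
# max_num_batched_tokens = 2048 and LLaVA-1.5's image_feature_size = 576,
# max_num_seqs is capped at int(2048 / 576) = 3, and the per-sequence dummy
# prompt lengths become 683, 683 and 682 tokens (2048 // 3 = 682, remainder 2),
# so the profiling batch still covers the full token budget.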
- if self.vision_language_config: + model_config = self.model_config + vlm_config = self.vision_language_config + + if vlm_config: max_num_seqs = min( max_num_seqs, - int(max_num_batched_tokens / - self.vision_language_config.image_feature_size)) + int(max_num_batched_tokens / vlm_config.image_feature_size)) for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) - seq_data, fake_multi_modal_input = _prepare_fake_inputs( - seq_len, self.vision_language_config) + + if vlm_config is None: + seq_data = SequenceData([0] * seq_len) + dummy_multi_modal_data = None + else: + seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \ + .dummy_data_for_profiling(seq_len, model_config, vlm_config) + seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, @@ -799,7 +821,7 @@ class ModelRunner: block_tables=None, lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None, - multi_modal_data=fake_multi_modal_input, + multi_modal_data=dummy_multi_modal_data, ) seqs.append(seq) @@ -1034,24 +1056,6 @@ def _get_graph_batch_size(batch_size: int) -> int: _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) -def _prepare_fake_inputs( - seq_len: int, vision_language_config: Optional[VisionLanguageConfig]): - """Prepare fake inputs for profile run.""" - if vision_language_config: - prompt_tokens = [ - vision_language_config.image_token_id - ] * vision_language_config.image_feature_size + [0] * ( - seq_len - vision_language_config.image_feature_size) - fake_image_input = MultiModalData( - type=MultiModalData.Type.IMAGE, - data=torch.zeros(vision_language_config.image_input_shape, - dtype=torch.float16)) - else: - prompt_tokens = [0] * seq_len - fake_image_input = None - return SequenceData(prompt_tokens), fake_image_input - - def _is_block_tables_empty(block_tables: Union[None, Dict]): """ Check if block_tables is None or a dictionary with all None values. From 0ab278ca31028a7623098b3c7d615ad350663d05 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 3 Jun 2024 09:39:31 -0700 Subject: [PATCH 13/18] [Core] Remove unnecessary copies in flash attn backend (#5138) --- requirements-cuda.txt | 2 +- vllm/attention/backends/flash_attn.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 5109f17356178..3536179835967 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -6,4 +6,4 @@ ray >= 2.9 nvidia-ml-py # for pynvml package torch == 2.3.0 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 -vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0 +vllm-flash-attn == 2.5.9 # Requires PyTorch 2.3.0 diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 0b9d6283493f2..070c074e511bc 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -317,7 +317,7 @@ class FlashAttentionImpl(AttentionImpl): # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. 
- out = flash_attn_varlen_func( + flash_attn_varlen_func( q=query, k=key, v=value, @@ -329,14 +329,13 @@ class FlashAttentionImpl(AttentionImpl): causal=True, window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, + out=output[:num_prefill_tokens], ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out else: # prefix-enabled attention assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) - output[:num_prefill_tokens] = flash_attn_varlen_func( + flash_attn_varlen_func( q=query, k=key_cache, v=value_cache, @@ -348,11 +347,12 @@ class FlashAttentionImpl(AttentionImpl): causal=True, alibi_slopes=self.alibi_slopes, block_table=prefill_meta.block_tables, + out=output[:num_prefill_tokens], ) if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output[num_prefill_tokens:] = flash_attn_with_kvcache( + flash_attn_with_kvcache( decode_query.unsqueeze(1), key_cache, value_cache, @@ -361,7 +361,8 @@ class FlashAttentionImpl(AttentionImpl): softmax_scale=self.scale, causal=True, alibi_slopes=self.alibi_slopes, - ).squeeze(1) + out=output[num_prefill_tokens:].unsqueeze(1), + ) # Reshape the output tensor. return output.view(num_tokens, hidden_size) From cbb2f59cc853731f5607ac0130bb6cdebfdc89c7 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 3 Jun 2024 12:52:30 -0400 Subject: [PATCH 14/18] [Kernel] Pass a device pointer into the quantize kernel for the scales (#5159) --- csrc/ops.h | 4 ++-- .../compressed_tensors/int8_quant_kernels.cu | 15 +++++++++------ tests/kernels/test_int8_quant.py | 4 +++- vllm/_custom_ops.py | 2 +- .../compressed_tensors_w8a8_statictensor.py | 2 +- 5 files changed, 16 insertions(+), 11 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 567d9fae4bd2a..4952e826ec8ac 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -94,8 +94,8 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, #endif -void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input, - float scale); +void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor const& scale); void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, torch::Tensor lookup_table); diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 4902e4c23434c..11baa5d414c19 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -28,9 +28,10 @@ namespace vllm { template __global__ void static_scaled_int8_quant_kernel( const scalar_t* __restrict__ input, int8_t* __restrict__ out, - scale_type scale, const int hidden_size) { + const scale_type* scale_ptr, const int hidden_size) { const int tid = threadIdx.x; const int token_idx = blockIdx.x; + scale_type scale = *scale_ptr; for (int i = tid; i < hidden_size; i += blockDim.x) { out[token_idx * hidden_size + i] = @@ -39,11 +40,13 @@ __global__ void static_scaled_int8_quant_kernel( } } // namespace vllm -void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - float scale) { +void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& scale) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scale.numel() == 1); + int hidden_size = input.size(-1); int num_tokens = 
input.numel() / hidden_size; dim3 grid(num_tokens); @@ -53,7 +56,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { vllm::static_scaled_int8_quant_kernel <<>>(input.data_ptr(), - out.data_ptr(), scale, - hidden_size); + out.data_ptr(), + scale.data_ptr(), hidden_size); }); } diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index b9aa00ce13f56..29890118c93dc 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -26,6 +26,8 @@ def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max).to(torch.int8) out2 = torch.empty_like(x, dtype=torch.int8) - ops.static_scaled_int8_quant(out2, x, scale) + scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda") + + ops.static_scaled_int8_quant(out2, x, scale_argument) assert torch.allclose(out1, out2, atol=1) # big atol to account for rounding errors diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 22cf5a44e341f..8a6f6d96d81f3 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -265,7 +265,7 @@ def scaled_fp8_quant( # int8 def static_scaled_int8_quant(input: torch.Tensor, - scale: float) -> torch.Tensor: + scale: torch.Tensor) -> torch.Tensor: """ Quantize the input tensor to int8 and return the quantized tensor. diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index 7e3e932cfe14a..2dfc6e2b07782 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -97,7 +97,7 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): act_scale = layer.input_scale # Input quantize - x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item()) + x_q = custom_ops.static_scaled_int8_quant(x, act_scale) return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, weight_scale, x.dtype) From cafb8e06c5ffa359ac7fa4b53795e6eaa1a200c7 Mon Sep 17 00:00:00 2001 From: Yuan Date: Tue, 4 Jun 2024 01:39:50 +0800 Subject: [PATCH 15/18] [CI/BUILD] enable intel queue for longer CPU tests (#4113) --- .buildkite/run-cpu-test.sh | 14 +++- .buildkite/test-template.j2 | 2 + Dockerfile.cpu | 6 +- csrc/cpu/pos_encoding.cpp | 105 ++++++++++++++-------------- tests/conftest.py | 36 ++++++---- tests/models/test_aqlm.py | 11 +-- tests/models/test_big_models.py | 10 ++- tests/models/test_fp8.py | 11 +-- tests/models/test_gptq_marlin.py | 11 +-- tests/models/test_gptq_marlin_24.py | 11 +-- tests/models/test_marlin.py | 11 +-- 11 files changed, 138 insertions(+), 90 deletions(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 414045fe163e5..d1200ee84dfe4 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -10,5 +10,15 @@ remove_docker_container() { docker rm -f cpu-test || true; } trap remove_docker_container EXIT remove_docker_container -# Run the image and launch offline inference -docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 vllm/examples/offline_inference.py +# Run the image +docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --network host -e HF_TOKEN 
--env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test + +# offline inference +docker exec cpu-test bash -c "python3 examples/offline_inference.py" + +# Run basic model test +docker exec cpu-test bash -c "cd tests; + pip install pytest Pillow protobuf + bash ../.buildkite/download-images.sh + cd ../ + pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py" diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 265833e2ccf6e..7e986c988407c 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -40,6 +40,8 @@ steps: - label: "Intel Test" depends_on: ~ + agents: + queue: intel command: bash .buildkite/run-cpu-test.sh {% for step in steps %} diff --git a/Dockerfile.cpu b/Dockerfile.cpu index aec79824213f3..ae23e27b413ba 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -1,6 +1,6 @@ # This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform. -FROM ubuntu:22.04 +FROM ubuntu:22.04 AS cpu-test-1 RUN apt-get update -y \ && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \ @@ -9,6 +9,8 @@ RUN apt-get update -y \ RUN pip install --upgrade pip \ && pip install wheel packaging ninja setuptools>=49.4.0 numpy +FROM cpu-test-1 AS build + COPY ./ /workspace/vllm WORKDIR /workspace/vllm @@ -19,4 +21,6 @@ RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install WORKDIR /workspace/ +RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks + CMD ["/bin/bash"] diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index 73bf77e46f538..e8aead17ae5a7 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -21,7 +21,57 @@ void rotary_embedding_impl( constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); const int embed_dim = rot_dim / 2; - TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); + bool flag = (embed_dim % VEC_ELEM_NUM == 0); + const int loop_upper = flag ? 
embed_dim : embed_dim - VEC_ELEM_NUM; + + auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr, + scalar_t* qk) { + int j = 0; + for (; j < loop_upper; j += VEC_ELEM_NUM) { + const int rot_offset = j; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + + const scalar_vec_t cos(cache_ptr + x_index); + const scalar_vec_t sin(cache_ptr + y_index); + + const scalar_vec_t q_x(qk + out_x); + const scalar_vec_t q_y(qk + out_y); + + vec_op::FP32Vec8 fp32_cos(cos); + vec_op::FP32Vec8 fp32_sin(sin); + + vec_op::FP32Vec8 fp32_q_x(q_x); + vec_op::FP32Vec8 fp32_q_y(q_y); + + auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; + scalar_vec_t(out1).save(qk + out_x); + + auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; + scalar_vec_t(out2).save(qk + out_y); + } + if (!flag) { + for (; j < embed_dim; ++j) { + const int x_index = j; + const int y_index = embed_dim + j; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + + const float fp32_cos = cache_ptr[x_index]; + const float fp32_sin = cache_ptr[y_index]; + + const float fp32_q_x = qk[out_x]; + const float fp32_q_y = qk[out_y]; + + qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; + qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; + } + } + }; #pragma omp parallel for for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { @@ -32,62 +82,13 @@ void rotary_embedding_impl( const int head_idx = i; const int64_t token_head = token_idx * query_stride + head_idx * head_size; - for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { - const int rot_offset = j; - const int x_index = rot_offset; - const int y_index = embed_dim + rot_offset; - - const int64_t out_x = token_head + x_index; - const int64_t out_y = token_head + y_index; - - const scalar_vec_t cos(cache_ptr + x_index); - const scalar_vec_t sin(cache_ptr + y_index); - - const scalar_vec_t q_x(query + out_x); - const scalar_vec_t q_y(query + out_y); - - vec_op::FP32Vec8 fp32_cos(cos); - vec_op::FP32Vec8 fp32_sin(sin); - - vec_op::FP32Vec8 fp32_q_x(q_x); - vec_op::FP32Vec8 fp32_q_y(q_y); - - auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; - scalar_vec_t(out1).save(query + out_x); - - auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; - scalar_vec_t(out2).save(query + out_y); - } + compute_loop(token_head, cache_ptr, query); } for (int i = 0; i < num_kv_heads; ++i) { const int head_idx = i; const int64_t token_head = token_idx * key_stride + head_idx * head_size; - for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { - const int rot_offset = j; - const int x_index = rot_offset; - const int y_index = embed_dim + rot_offset; - - const int64_t out_x = token_head + x_index; - const int64_t out_y = token_head + y_index; - - const scalar_vec_t cos(cache_ptr + x_index); - const scalar_vec_t sin(cache_ptr + y_index); - - const scalar_vec_t k_x(key + out_x); - const scalar_vec_t k_y(key + out_y); - - vec_op::FP32Vec8 fp32_cos(cos); - vec_op::FP32Vec8 fp32_sin(sin); - - vec_op::FP32Vec8 fp32_k_x(k_x); - vec_op::FP32Vec8 fp32_k_y(k_y); - - auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin; - scalar_vec_t(out1).save(key + out_x); - auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin; - scalar_vec_t(out2).save(key + out_y); - } + compute_loop(token_head, cache_ptr, key); } } } diff --git a/tests/conftest.py b/tests/conftest.py index e749338e1095a..764374a779d9e 100644 --- a/tests/conftest.py 
+++ b/tests/conftest.py @@ -18,6 +18,7 @@ from vllm.logger import init_logger from vllm.multimodal import MultiModalData from vllm.multimodal.image import ImageFeatureData, ImagePixelData from vllm.sequence import SampleLogprobs +from vllm.utils import is_cpu logger = init_logger(__name__) @@ -58,7 +59,8 @@ def cleanup(): with contextlib.suppress(AssertionError): torch.distributed.destroy_process_group() gc.collect() - torch.cuda.empty_cache() + if not is_cpu(): + torch.cuda.empty_cache() @pytest.fixture() @@ -151,6 +153,12 @@ _EMBEDDING_MODELS = [ class HfRunner: + def wrap_device(self, input: any): + if not is_cpu(): + return input.to("cuda") + else: + return input.to("cpu") + def __init__( self, model_name: str, @@ -164,16 +172,18 @@ class HfRunner: if model_name in _EMBEDDING_MODELS: # Lazy init required for AMD CI from sentence_transformers import SentenceTransformer - self.model = SentenceTransformer( - model_name, - device="cpu", - ).to(dtype=torch_dtype).cuda() + self.model = self.wrap_device( + SentenceTransformer( + model_name, + device="cpu", + ).to(dtype=torch_dtype)) else: - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ).cuda() + self.model = self.wrap_device( + AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + )) self.tokenizer = AutoTokenizer.from_pretrained( model_name, @@ -214,7 +224,7 @@ class HfRunner: inputs = self.processor(**processor_kwargs) output_ids = self.model.generate( - **inputs.to("cuda"), + **self.wrap_device(inputs), use_cache=True, **kwargs, ) @@ -271,7 +281,7 @@ class HfRunner: for prompt in prompts: input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids output = self.model.generate( - input_ids.cuda(), + self.wrap_device(input_ids), use_cache=True, do_sample=False, max_new_tokens=max_tokens, @@ -306,7 +316,7 @@ class HfRunner: for prompt in prompts: input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids output = self.model.generate( - input_ids.cuda(), + self.wrap_device(input_ids), use_cache=True, do_sample=False, max_new_tokens=max_tokens, diff --git a/tests/models/test_aqlm.py b/tests/models/test_aqlm.py index a7abc011f57d7..85d74f7f5b03d 100644 --- a/tests/models/test_aqlm.py +++ b/tests/models/test_aqlm.py @@ -8,10 +8,13 @@ import torch from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -capability = torch.cuda.get_device_capability() -capability = capability[0] * 10 + capability[1] -aqlm_not_supported = (capability < - QUANTIZATION_METHODS["aqlm"].get_min_capability()) +aqlm_not_supported = True + +if torch.cuda.is_available(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + aqlm_not_supported = (capability < + QUANTIZATION_METHODS["aqlm"].get_min_capability()) # In this test we hardcode prompts and generations for the model so we don't # need to require the AQLM package as a dependency diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index 10e7c64e34e75..ea95e6a49f03a 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -5,6 +5,7 @@ This tests bigger models and use half precision. Run `pytest tests/models/test_big_models.py`. 
""" import pytest +import torch MODELS = [ "meta-llama/Llama-2-7b-hf", @@ -16,9 +17,14 @@ MODELS = [ # "Qwen/Qwen1.5-0.5B" # Broken, ] +#TODO: remove this after CPU float16 support ready +target_dtype = "float" +if torch.cuda.is_available(): + target_dtype = "half" + @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [32]) def test_models( hf_runner, @@ -46,7 +52,7 @@ def test_models( @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dtype", [target_dtype]) def test_model_print( vllm_runner, model: str, diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index 0a5819ea3f054..61aee0d0a6e93 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -67,10 +67,13 @@ EXPECTED_STRS_MAP = { }, } -capability = torch.cuda.get_device_capability() -capability = capability[0] * 10 + capability[1] -fp8_not_supported = (capability < - QUANTIZATION_METHODS["fp8"].get_min_capability()) +fp8_not_supported = True + +if torch.cuda.is_available(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + fp8_not_supported = (capability < + QUANTIZATION_METHODS["fp8"].get_min_capability()) @pytest.mark.skipif(fp8_not_supported, diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py index 1fc0b3f239127..814471b47763d 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/test_gptq_marlin.py @@ -22,10 +22,13 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true" MAX_MODEL_LEN = 1024 -capability = torch.cuda.get_device_capability() -capability = capability[0] * 10 + capability[1] -gptq_marlin_not_supported = ( - capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability()) +gptq_marlin_not_supported = True + +if torch.cuda.is_available(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + gptq_marlin_not_supported = ( + capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability()) MODELS = [ # act_order==False, group_size=channelwise diff --git a/tests/models/test_gptq_marlin_24.py b/tests/models/test_gptq_marlin_24.py index 3e6ffb7f90fcc..cc35ee803ff01 100644 --- a/tests/models/test_gptq_marlin_24.py +++ b/tests/models/test_gptq_marlin_24.py @@ -14,10 +14,13 @@ import torch from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -capability = torch.cuda.get_device_capability() -capability = capability[0] * 10 + capability[1] -marlin_not_supported = (capability < - QUANTIZATION_METHODS["marlin"].get_min_capability()) +marlin_not_supported = True + +if torch.cuda.is_available(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + marlin_not_supported = ( + capability < QUANTIZATION_METHODS["marlin"].get_min_capability()) @dataclass diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 37c1664afec55..8520b26718bf5 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -19,10 +19,13 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from .utils import check_logprobs_close -capability = torch.cuda.get_device_capability() -capability = capability[0] * 10 + capability[1] -marlin_not_supported = (capability < - QUANTIZATION_METHODS["marlin"].get_min_capability()) 
+marlin_not_supported = True + +if torch.cuda.is_available(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + marlin_not_supported = ( + capability < QUANTIZATION_METHODS["marlin"].get_min_capability()) @dataclass From 10c38e3e46d95a02576c92c25d5b3e88e3c6e282 Mon Sep 17 00:00:00 2001 From: Kaiyang Chen <48289729+Kaiyang-Chen@users.noreply.github.com> Date: Tue, 4 Jun 2024 04:37:11 +0800 Subject: [PATCH 16/18] [Misc]: Implement CPU/GPU swapping in BlockManagerV2 (#3834) --- format.sh | 2 +- tests/core/block/e2e/test_correctness.py | 49 ++++++-- tests/core/block/test_block_manager_v2.py | 58 ++++++++- vllm/config.py | 28 +++-- vllm/core/block/block_table.py | 4 + vllm/core/block/common.py | 1 - vllm/core/block/cpu_gpu_block_allocator.py | 82 +++++++++++- vllm/core/block/interfaces.py | 36 +++++- vllm/core/block/naive_block.py | 66 +++++++++- vllm/core/block/prefix_caching_block.py | 78 ++++++++++++ vllm/core/block_manager_v1.py | 6 +- vllm/core/block_manager_v2.py | 140 +++++++++++++++++++-- vllm/core/embedding_model_block_manager.py | 3 +- vllm/core/interfaces.py | 3 +- vllm/core/scheduler.py | 13 +- vllm/engine/arg_utils.py | 9 ++ 16 files changed, 529 insertions(+), 49 deletions(-) diff --git a/format.sh b/format.sh index ca828457f9999..6057b69af8ce8 100755 --- a/format.sh +++ b/format.sh @@ -118,7 +118,7 @@ mypy vllm/model_executor --config-file pyproject.toml # https://github.com/codespell-project/codespell/issues/1915 # Avoiding the "./" prefix and using "/**" globs for directories appears to solve the problem CODESPELL_EXCLUDES=( - '--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,tests/lora/data/**,build/**' + '--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**' ) # check spelling of specified files diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 3713ef2fed4d1..ad253635e0ba0 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -24,7 +24,13 @@ from .conftest import get_token_ids_from_llm_generator @pytest.mark.parametrize("baseline_llm_kwargs", [{ "use_v2_block_manager": False }]) -@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "use_v2_block_manager": True, + "preemption_mode": "swap" +}, { + "use_v2_block_manager": True, + "preemption_mode": "recompute" +}]) @pytest.mark.parametrize("batch_size", [10]) @pytest.mark.parametrize("seed", [1]) def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator, @@ -95,7 +101,13 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator, @pytest.mark.parametrize("baseline_llm_kwargs", [{ "use_v2_block_manager": False }]) -@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "use_v2_block_manager": True, + "preemption_mode": "swap" +}, { + "use_v2_block_manager": True, + "preemption_mode": "recompute" +}]) @pytest.mark.parametrize("batch_size", [10]) @pytest.mark.parametrize("seed", [1]) def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator, @@ -179,11 +191,18 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator, }]) @pytest.mark.parametrize( "test_llm_kwargs", - [{ - # We run one test with block_size < lookahead_slots, one test with - # block_size > lookahead_slots - "num_lookahead_slots": 10, - }]) + [ + { + # We run one test with block_size < lookahead_slots, 
one test with + # block_size > lookahead_slots + "num_lookahead_slots": 10, + "preemption_mode": "swap", + }, + { + "num_lookahead_slots": 10, + "preemption_mode": "recompute", + } + ]) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, @@ -322,7 +341,13 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator, @pytest.mark.parametrize("baseline_llm_kwargs", [{ "use_v2_block_manager": False }]) -@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "use_v2_block_manager": True, + "preemption_mode": "swap" +}, { + "use_v2_block_manager": True, + "preemption_mode": "recompute" +}]) @pytest.mark.parametrize("batch_size", [10]) @pytest.mark.parametrize("seed", [1]) def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption( @@ -397,7 +422,13 @@ def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption( @pytest.mark.parametrize("baseline_llm_kwargs", [{ "enable_prefix_caching": False }]) -@pytest.mark.parametrize("test_llm_kwargs", [{"enable_prefix_caching": True}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "enable_prefix_caching": True, + "preemption_mode": "swap" +}, { + "enable_prefix_caching": True, + "preemption_mode": "recompute" +}]) @pytest.mark.parametrize("batch_size", [10]) @pytest.mark.parametrize("seed", [1]) def test_auto_prefix_caching_with_preemption(baseline_llm_generator, diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index f98fc0e217278..d0ca09c4be0d4 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -7,7 +7,8 @@ from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list -from ..utils import create_seq_group, create_seq_group_encoder_decoder +from ..utils import (create_dummy_prompt, create_seq_group, + create_seq_group_encoder_decoder) @pytest.mark.parametrize("block_size", [16]) @@ -255,6 +256,61 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append, assert num_consumed_blocks == expected_consumed_blocks +@pytest.mark.parametrize("block_size", [8]) +@pytest.mark.parametrize("num_cpu_blocks", [4]) +@pytest.mark.parametrize("num_gpu_blocks", [4]) +@pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10]) +@pytest.mark.parametrize("enable_caching", [False, True]) +def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots, + enable_caching): + """Verify blocks number on src/desc device is correct after swapping in/out + sequence group (not missing or extra blocks). + """ + block_manager = BlockSpaceManagerV2(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=enable_caching) + prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1) + prompt.status = SequenceStatus.WAITING + block_manager.allocate(seq_group) + # Emulate a forward pass by appending a single token. + # The block manager then knows how many unprocessed + # tokens will be written in the next forward pass. + token_id = 0 + prompt.status = SequenceStatus.RUNNING + prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) + + # Swap seq group from GPU -> CPU. 
+ gpu_blocks = block_manager.get_block_table(prompt) + assert block_manager.can_swap_out(seq_group) + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_out(seq_group) + mapping_keys = [key for key, _ in mapping] + assert mapping_keys == gpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) + assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks + prompt.status = SequenceStatus.SWAPPED + + # Swap seq group from CPU -> GPU. + assert block_manager.can_swap_in(seq_group, num_lookahead_slots) + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_in(seq_group) + cpu_blocks = block_manager.get_block_table(prompt) + mapping_keys = [key for key, _ in mapping] + assert mapping_keys == [cpu_blocks[0]] + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) + + +# TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level. + + @pytest.mark.parametrize("block_size", [8, 16]) @pytest.mark.parametrize("prompt_len", [10, 300, 1000]) @pytest.mark.parametrize("num_slots_to_append", [50]) diff --git a/vllm/config.py b/vllm/config.py index eee62d2683835..7fd417bd745a9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -651,19 +651,24 @@ class SchedulerConfig: enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. embedding_mode: Whether the running model is for embedding. + preemption_mode: Whether to perform preemption by swapping or + recomputation. If not specified, we determine the mode as follows: + We use recomputation by default since it incurs lower overhead than + swapping. However, when the sequence group has multiple sequences + (e.g., beam search), recomputation is not currently supported. In + such a case, we use swapping instead. 
""" - def __init__( - self, - max_num_batched_tokens: Optional[int], - max_num_seqs: int, - max_model_len: int, - use_v2_block_manager: bool = False, - num_lookahead_slots: int = 0, - delay_factor: float = 0.0, - enable_chunked_prefill: bool = False, - embedding_mode: Optional[bool] = False, - ) -> None: + def __init__(self, + max_num_batched_tokens: Optional[int], + max_num_seqs: int, + max_model_len: int, + use_v2_block_manager: bool = False, + num_lookahead_slots: int = 0, + delay_factor: float = 0.0, + enable_chunked_prefill: bool = False, + embedding_mode: Optional[bool] = False, + preemption_mode: Optional[str] = None) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens else: @@ -689,6 +694,7 @@ class SchedulerConfig: self.delay_factor = delay_factor self.chunked_prefill_enabled = enable_chunked_prefill self.embedding_mode = embedding_mode + self.preemption_mode = preemption_mode self._verify_args() diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index 26c704b8de901..26f378ba24b76 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -283,6 +283,10 @@ class BlockTable: def _is_allocated(self) -> bool: return len(self._blocks) > 0 + @property + def blocks(self) -> Optional[List[Block]]: + return self._blocks + @property def _num_empty_slots(self) -> int: assert self._is_allocated diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index 4d7a12165cb01..d2787d69616f0 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -140,7 +140,6 @@ class CopyOnWriteTracker: assert refcount != 0 if refcount > 1: src_block_id = block_id - # Decrement refcount of the old block. self._allocator.free(block) diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index d28a684376974..255aae9d17318 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -90,11 +90,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): gpu_block_allocator=gpu_allocator, ) - def __init__( - self, - cpu_block_allocator: BlockAllocator, - gpu_block_allocator: BlockAllocator, - ): + def __init__(self, cpu_block_allocator: BlockAllocator, + gpu_block_allocator: BlockAllocator): assert not ( cpu_block_allocator.all_block_ids & gpu_block_allocator.all_block_ids @@ -105,6 +102,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): Device.GPU: gpu_block_allocator, } + self._swap_mapping: Dict[int, int] = {} self._null_block: Optional[Block] = None self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} @@ -198,6 +196,68 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): def get_num_total_blocks(self, device: Device) -> int: return self._allocators[device].get_num_total_blocks() + def get_physical_block_id(self, device: Device, absolute_id: int) -> int: + """Returns the zero-offset block id on certain device given the + absolute block id. + + Args: + device (Device): The device for which to query relative block id. + absolute_id (int): The absolute block id for the block in + whole allocator. + + Returns: + int: The zero-offset block id on certain device. 
+ """ + return self._allocators[device].get_physical_block_id(absolute_id) + + def swap(self, blocks: List[Block], source_device: Device, + dest_device: Device) -> Dict[int, int]: + """Execute the swap for the given blocks from source_device + on to dest_device, save the current swap mapping and append + them to the accumulated `self._swap_mapping` for each + scheduling move. + + Args: + blocks: List of blocks to be swapped. + source_device (Device): Device to swap the 'blocks' from. + dest_device (Device): Device to swap the 'blocks' to. + + Returns: + Dict[int, int]: Swap mapping from source_device + on to dest_device. + """ + source_block_ids = [block.block_id for block in blocks] + self._allocators[source_device].swap_out(blocks) + self._allocators[dest_device].swap_in(blocks) + dest_block_ids = [block.block_id for block in blocks] + + current_swap_mapping: Dict[int, int] = {} + for src, dest in zip(source_block_ids, dest_block_ids): + if src is not None and dest is not None: + self._swap_mapping[src] = dest + current_swap_mapping[src] = dest + return current_swap_mapping + + def get_num_blocks_touched(self, + blocks: List[Block], + device: Device, + num_lookahead_slots: int = 0) -> int: + """Returns the number of blocks that will be touched by + swapping in/out the given blocks on to the 'device'. + + Args: + blocks: List of blocks to be swapped. + device (Device): Device to swap the 'blocks' on. + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + int: the number of blocks that will be touched by + swapping in/out the given blocks on to the 'device'. + """ + return self._allocators[device].get_num_blocks_touched( + blocks, num_lookahead_slots) + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: """Clears the copy-on-write (CoW) state and returns the mapping of source to destination block IDs. @@ -240,6 +300,18 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: raise NotImplementedError + def get_and_reset_swaps(self) -> List[Tuple[int, int]]: + """Returns and clears the mapping of source to destination block IDs. + Will be called after every swapping operations for now, and after every + schedule when BlockManagerV2 become default. Currently not useful. + + Returns: + List[Tuple[int, int]]: A mapping of source to destination block IDs. 
+ """ + mapping = self._swap_mapping.copy() + self._swap_mapping.clear() + return list(mapping.items()) + class NullBlock(Block): """ diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 8fc4c601106cd..4b20856a1b42d 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import FrozenSet, List, Optional, Protocol, Tuple +from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple from vllm.utils import Device @@ -116,6 +116,18 @@ class BlockAllocator(ABC): def get_num_free_blocks(self) -> int: pass + @abstractmethod + def get_physical_block_id(self, absolute_id: int) -> int: + pass + + @abstractmethod + def swap_out(self, blocks: List[Block]) -> None: + pass + + @abstractmethod + def swap_in(self, blocks: List[Block]) -> None: + pass + @property @abstractmethod def all_block_ids(self) -> FrozenSet[int]: @@ -149,6 +161,12 @@ class BlockAllocator(ABC): """NOTE: This should not be used besides Block""" pass + @abstractmethod + def get_num_blocks_touched(self, + blocks: List[Block], + num_lookahead_slots: int = 0) -> int: + pass + class NoFreeBlocksError(ValueError): pass @@ -204,6 +222,22 @@ class DeviceAwareBlockAllocator(ABC): self, seq_block_ids: List[List[int]]) -> List[int]: pass + @abstractmethod + def get_num_blocks_touched(self, + blocks: List[Block], + device: Device, + num_lookahead_slots: int = 0) -> int: + pass + + @abstractmethod + def swap(self, blocks: List[Block], source_device: Device, + dest_device: Device) -> Dict[int, int]: + pass + + @abstractmethod + def get_physical_block_id(self, device: Device, absolute_id: int) -> int: + pass + @abstractmethod def allocate_or_get_null_block(self) -> Block: """ diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index ae01930878254..d033787122d7a 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -3,6 +3,7 @@ from typing import FrozenSet, Iterable, List, Optional, Set, Tuple from vllm.core.block.common import (CopyOnWriteTracker, RefCounter, get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device +from vllm.utils import cdiv Refcount = int @@ -95,8 +96,6 @@ class NaiveBlockAllocator(BlockAllocator): def free(self, block: Block) -> None: assert block.block_id is not None self._free_block_id(block.block_id) - - # Mark the block as having no allocation. block.block_id = None def fork(self, last_block: Block) -> List[Block]: @@ -153,6 +152,19 @@ class NaiveBlockAllocator(BlockAllocator): if refcount == 0: self._free_block_indices.add(block_id) + def get_physical_block_id(self, absolute_id: int) -> int: + """Returns the zero-offset block id on certain block allocator + given the absolute block id. + + Args: + absolute_id (int): The absolute block id for the block + in whole allocator. + + Returns: + int: The zero-offset block id on certain device. + """ + return sorted(self._all_block_indices).index(absolute_id) + @property def refcounter(self): return self._refcounter @@ -213,6 +225,56 @@ class NaiveBlockAllocator(BlockAllocator): def promote_to_immutable_block(self, block: Block) -> BlockId: raise NotImplementedError + def get_num_blocks_touched(self, + blocks: List[Block], + num_lookahead_slots: int = 0) -> int: + """Determine the number of blocks that will be touched by + swapping in/out the given blocks from certain sequence + group with the provided num_lookahead_slots. 
+ + Args: + blocks (List[Block]): The potential blocks to swap. + num_lookahead_slots (int): number of lookahead slots (0 for swap + out). + + Returns: + int: the number of blocks that will be touched by + swapping in/out the given blocks and num_lookahead_slots. + """ + # NOTE: for naive block, we use set to eliminate common blocks among + # seqs, also we compare the empty slots in the mutable blocks with + # lookahead slots to get the number of unique new block that are + # needed. + old_block_set = set() + new_block_count = 0 + # TODO(cade): make sure the logic is correct and clean it up. + for block in blocks: + if not block.is_full and num_lookahead_slots != 0: + if block.num_empty_slots >= num_lookahead_slots: + new_block_count += 1 + else: + new_block_count += cdiv( + num_lookahead_slots - block.num_empty_slots, + self._block_size) + else: + old_block_set.add(block.block_id) + num_touched_blocks = new_block_count + len(old_block_set) + return num_touched_blocks + + def swap_out(self, blocks: List[Block]) -> None: + for block in blocks: + self.free(block) + + def swap_in(self, blocks: List[Block]) -> None: + for block in blocks: + if block.is_full: + alloc = self.allocate_immutable(block.prev_block, + block.token_ids) + else: + alloc = self.allocate_mutable(block.prev_block) + alloc.append_token_ids(block.token_ids) + block.block_id = alloc.block_id + class NaiveBlock(Block): """An implementation of the Block class that does not support prefix diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 4eb32f145b05b..405e9705659df 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,4 +1,5 @@ """Token blocks.""" + from itertools import takewhile from os.path import commonprefix from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple @@ -8,6 +9,7 @@ from vllm.core.block.common import (CopyOnWriteTracker, from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor +from vllm.utils import cdiv PrefixHash = int @@ -294,10 +296,29 @@ class PrefixCachingBlockAllocator(BlockAllocator): def get_num_total_blocks(self) -> int: return self._hashless_allocator.get_num_total_blocks() + def get_physical_block_id(self, absolute_id: int) -> int: + """Returns the zero-offset block id on certain block allocator + given the absolute block id. + + Args: + absolute_id (int): The absolute block id for the block + in whole allocator. + + Returns: + int: The rzero-offset block id on certain device. + """ + return sorted(self.all_block_ids).index(absolute_id) + @property def all_block_ids(self) -> FrozenSet[int]: return self._hashless_allocator.all_block_ids + def is_block_cached(self, block: Block) -> bool: + assert block.content_hash is not None + if block.content_hash in self._cached_blocks: + return True + return False + def promote_to_immutable_block(self, block: Block) -> BlockId: """Once a mutable block is full, it can be promoted to an immutable block. 
This means that its content can be referenced by future blocks @@ -411,6 +432,63 @@ class PrefixCachingBlockAllocator(BlockAllocator): if ids != [] ]) + def get_num_blocks_touched(self, + blocks: List[Block], + num_lookahead_slots: int = 0) -> int: + """Determine the number of blocks that will be touched by + swapping in/out the given blocks from certain sequence + group with the provided num_lookahead_slots. + + Args: + blocks (List[Block]): The potential blocks to swap. + num_lookahead_slots (int): number of lookahead slots (0 for + swap out). + + Returns: + int: the number of blocks that will be touched by + swapping in/out the given blocks and num_lookahead_slots. + """ + num_touched_blocks = 0 + for block in blocks: + if not block.is_full: + if block.num_empty_slots >= num_lookahead_slots: + num_touched_blocks += 1 + else: + num_touched_blocks += cdiv( + num_lookahead_slots - block.num_empty_slots, + self._block_size) + else: + if not self.is_block_cached(block): + num_touched_blocks += 1 + return num_touched_blocks + + def swap_out(self, blocks: List[Block]) -> None: + """Execute the swap out actions. Basically just free the + given blocks. + + Args: + blocks: List of blocks to be swapped out. + """ + for block in blocks: + self.free(block) + + def swap_in(self, blocks: List[Block]) -> None: + """Execute the swap in actions. Change the block id from + the old allocator to the current allocator for each block to finish + the block table update. + + Args: + blocks: List of blocks to be swapped in. + """ + for block in blocks: + if block.is_full: + alloc = self.allocate_immutable(block.prev_block, + block.token_ids) + else: + alloc = self.allocate_mutable(block.prev_block) + alloc.append_token_ids(block.token_ids) + block.block_id = alloc.block_id + class PrefixCachingBlock(Block): """A block implementation that supports prefix caching. diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 201cba309f6ef..4010aaf02b828 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -541,11 +541,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): return new_block_table - def swap_in(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> List[Tuple[int, int]]: - assert (num_lookahead_slots == 0 - ), "BlockSpaceManagerV1 does not support lookahead allocation" + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: request_id = seq_group.request_id diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index cad42ab3c1ba2..121092cf189bd 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -1,10 +1,12 @@ """A block manager that manages token blocks.""" +from itertools import chain from typing import Dict, List, Optional from typing import Sequence as GenericSequence from typing import Tuple from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.interfaces import Block from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus @@ -217,7 +219,6 @@ class BlockSpaceManagerV2(BlockSpaceManager): num_lookahead_slots=num_lookahead_slots, num_computed_slots=seq.data.get_num_computed_tokens(), ) - # Return any new copy-on-writes.
new_cows = self.block_allocator.clear_copy_on_writes() return new_cows @@ -297,20 +298,145 @@ class BlockSpaceManagerV2(BlockSpaceManager): def can_swap_in(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> AllocStatus: - return AllocStatus.LATER + """Returns the AllocStatus for the given sequence_group + with num_lookahead_slots. - def swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> List[Tuple[int, int]]: - raise NotImplementedError + Args: + seq_group (SequenceGroup): The sequence group to swap in. + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + AllocStatus: The AllocStatus for the given sequence group. + """ + return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED, + num_lookahead_slots) + + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + """Returns the block id mapping (from CPU to GPU) generated by + swapping in the given seq_group. + + Args: + seq_group (SequenceGroup): The sequence group to swap in. + + Returns: + List[Tuple[int, int]]: The mapping of swapped blocks from CPU + to GPU. + """ + blocks = self._get_blocks_for_swap(seq_group, SequenceStatus.SWAPPED) + current_swap_mapping = self.block_allocator.swap( + blocks=blocks, source_device=Device.CPU, dest_device=Device.GPU) + + block_number_mapping = { + self.block_allocator.get_physical_block_id(Device.CPU, + cpu_block_id): + self.block_allocator.get_physical_block_id(Device.GPU, + gpu_block_id) + for cpu_block_id, gpu_block_id in current_swap_mapping.items() + } + # convert to list of tuples once here + return list(block_number_mapping.items()) def can_swap_out(self, seq_group: SequenceGroup) -> bool: + """Returns whether we can swap out the given sequence_group. + + Args: + seq_group (SequenceGroup): The sequence group to swap out. + + Returns: + bool: Whether it's possible to swap out the current sequence group. + """ + alloc_status = self._can_swap(seq_group, Device.CPU, + SequenceStatus.RUNNING) + if alloc_status == AllocStatus.OK: + return True return False - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - raise NotImplementedError + def swap_out(self, sequence_group: SequenceGroup) -> List[Tuple[int, int]]: + """Returns the block id mapping (from GPU to CPU) generated by + swapping out the given sequence_group. + + Args: + sequence_group (SequenceGroup): The sequence group to swap out. + + Returns: + List[Tuple[int, int]]: The mapping of swapped blocks from + GPU to CPU.
+ """ + blocks = self._get_blocks_for_swap(sequence_group, + SequenceStatus.RUNNING) + current_swap_mapping = self.block_allocator.swap( + blocks=blocks, source_device=Device.GPU, dest_device=Device.CPU) + block_number_mapping = { + self.block_allocator.get_physical_block_id(Device.GPU, + gpu_block_id): + self.block_allocator.get_physical_block_id(Device.CPU, + cpu_block_id) + for gpu_block_id, cpu_block_id in current_swap_mapping.items() + } + # convert to list of tuples once here + return list(block_number_mapping.items()) def get_num_free_gpu_blocks(self) -> int: return self.block_allocator.get_num_free_blocks(Device.GPU) def get_num_free_cpu_blocks(self) -> int: return self.block_allocator.get_num_free_blocks(Device.CPU) + + def _can_swap(self, + seq_group: SequenceGroup, + device: Device, + status: SequenceStatus, + num_lookahead_slots: int = 0) -> AllocStatus: + """Returns the AllocStatus for swapping in/out the given sequence_group + on to the 'device'. + + Args: + sequence_group (SequenceGroup): The sequence group to swap in. + device (Device): device to swap the 'seq_group' on. + status (SequenceStatus): The status of sequence which is needed + for action. RUNNING for swap out and SWAPPED for swap in + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + AllocStatus: The AllocStatus for swapping in/out the given + sequence_group on to the 'device'. + """ + blocks = self._get_blocks_for_swap(seq_group, status) + num_blocks_touched = self.block_allocator.get_num_blocks_touched( + blocks, device, num_lookahead_slots) + watermark_blocks = 0 + if device == Device.GPU: + watermark_blocks = self.watermark_blocks + if self.block_allocator.get_num_total_blocks( + device) < num_blocks_touched: + return AllocStatus.NEVER + elif self.block_allocator.get_num_free_blocks( + device) - num_blocks_touched >= watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def _get_blocks_for_swap(self, seq_group: SequenceGroup, + status: SequenceStatus) -> List[Block]: + """Returns the list of blocks those are touched by the seq_group + + Args: + sequence_group (SequenceGroup): The sequence group to swap in. + status (SequenceStatus): The status of sequence which is needed + for action. RUNNING for swap out and SWAPPED for swap in + + Returns: + The list of blocks those are touched by the seq_group. 
+ """ + blocks: Dict[int, List[Block]] = {} + for seq in seq_group.get_seqs(status=status): + block_table = self.block_tables[seq.seq_id] + if block_table.blocks is not None: + blocks[seq.seq_id] = block_table.blocks + combined_blocks = list(chain(*blocks.values())) + return combined_blocks diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index a09d79ec3c420..f2d67306d7ceb 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -46,8 +46,7 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager): num_lookahead_slots: int) -> AllocStatus: return AllocStatus.OK - def swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> List[Tuple[int, int]]: + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: return None # type: ignore def can_swap_out(self, seq_group: SequenceGroup) -> bool: diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 689cbc2179ee1..8759ee06795b8 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -73,8 +73,7 @@ class BlockSpaceManager(ABC): pass @abstractmethod - def swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> List[Tuple[int, int]]: + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 7c70b1b244f7d..399665082f838 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -297,6 +297,8 @@ class Scheduler: self.prev_prompt = False # Latency of the last prompt step self.last_prompt_latency = 0.0 + # preemption mode, RECOMPUTE or SWAP + self.user_specified_preemption_mode = scheduler_config.preemption_mode # The following field is test-only. It is used to inject artificial # preemption. @@ -522,7 +524,9 @@ class Scheduler: seq_group = swapped_queue[0] # If the sequence group cannot be swapped in, stop. - alloc_status = self.block_manager.can_swap_in(seq_group) + is_prefill = seq_group.is_prefill() + alloc_status = self.block_manager.can_swap_in( + seq_group, self._get_num_lookahead_slots(is_prefill)) if alloc_status == AllocStatus.LATER: break elif alloc_status == AllocStatus.NEVER: @@ -1067,12 +1071,17 @@ class Scheduler: # over sequence groups with a single sequence. # TODO(woosuk): Support recomputation for sequence groups with multiple # sequences. This may require a more sophisticated CUDA kernel. - if preemption_mode is None: + if self.user_specified_preemption_mode is None: if seq_group.get_max_num_running_seqs() == 1: preemption_mode = PreemptionMode.RECOMPUTE else: preemption_mode = PreemptionMode.SWAP + elif self.user_specified_preemption_mode == "swap": + preemption_mode = PreemptionMode.SWAP + else: + preemption_mode = PreemptionMode.RECOMPUTE + if self.num_cumulative_preemption % 50 == 0: logger.warning( "Sequence group %s is preempted by %s mode because there is " diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b315d4d2ece29..72787d369c0f8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -75,6 +75,7 @@ class EngineArgs: num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 model_loader_extra_config: Optional[dict] = None + preemption_mode: Optional[str] = None # Related to Vision-language models such as llava image_input_type: Optional[str] = None @@ -564,6 +565,13 @@ class EngineArgs: 'corresponding to the chosen load_format. 
' 'This should be a JSON string that will be ' 'parsed into a dictionary.') + parser.add_argument( + '--preemption_mode', + type=str, + default=None, + help='If \'recompute\', the engine performs preemption by ' + 'recomputing; if \'swap\', the engine performs preemption by block ' + 'swapping.') parser.add_argument( "--served-model-name", @@ -667,6 +675,7 @@ class EngineArgs: delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, embedding_mode=model_config.embedding_mode, + preemption_mode=self.preemption_mode, ) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, From 4f0d17c05cdb220f2f45a20e956f766dec29acbc Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Mon, 3 Jun 2024 16:16:43 -0700 Subject: [PATCH 17/18] New CI template on AWS stack (#5110) Signed-off-by: kevin --- .buildkite/test-template-aws.j2 | 59 +++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 .buildkite/test-template-aws.j2 diff --git a/.buildkite/test-template-aws.j2 b/.buildkite/test-template-aws.j2 new file mode 100644 index 0000000000000..9f7d07acca298 --- /dev/null +++ b/.buildkite/test-template-aws.j2 @@ -0,0 +1,59 @@ +{% set docker_image = "public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT" %} +{% set default_working_dir = "/vllm-workspace/tests" %} + +steps: + - label: ":docker: build image" + agents: + queue: cpu_queue + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ." + - "docker push {{ docker_image }}" + env: + DOCKER_BUILDKIT: "1" + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 5 + - exit_status: -10 # Agent was lost + limit: 5 + - wait + + {% for step in steps %} + - label: "{{ step.label }}" + agents: + {% if step.no_gpu %} + queue: cpu_queue + {% elif step.num_gpus == 2 or step.num_gpus == 4 %} + queue: gpu_4_queue + {% else %} + queue: gpu_1_queue + {% endif %} + soft_fail: true + {% if step.parallelism %} + parallelism: {{ step.parallelism }} + {% endif %} + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 5 + - exit_status: -10 # Agent was lost + limit: 5 + plugins: + - docker#v5.2.0: + image: {{ docker_image }} + always-pull: true + propagate-environment: true + {% if not step.no_gpu %} + gpus: all + {% endif %} + command: ["bash", "-c", "cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}"] + environment: + - VLLM_USAGE_SOURCE=ci-test + - HF_TOKEN + {% if step.label == "Speculative decoding tests" %} + - VLLM_ATTENTION_BACKEND=XFORMERS + {% endif %} + volumes: + - /dev/shm:/dev/shm + {% endfor %} From f775a07e30fdeafc14f53fe502b262b00540dd71 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Tue, 4 Jun 2024 01:25:29 +0200 Subject: [PATCH 18/18] [FRONTEND] OpenAI `tools` support named functions (#5032) --- .../serving/openai_compatible_server.md | 13 +- tests/entrypoints/test_openai_server.py | 185 ++++++++++++++++++ tests/utils.py | 3 +- vllm/entrypoints/openai/protocol.py | 57 +++++- vllm/entrypoints/openai/serving_chat.py | 37 +++- .../guided_decoding/__init__.py | 30 ++- 6 files changed, 314 insertions(+), 11 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 15a8761eb5738..a912949352b86 100644 ---
a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -109,4 +109,15 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) :module: vllm.entrypoints.openai.cli_args :func: make_arg_parser :prog: -m vllm.entrypoints.openai.api_server -``` \ No newline at end of file +``` + +## Tool calling in the chat completion API +vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but are on the roadmap. + +To use a named function, you need to define the function in the `tools` parameter and call it in the `tool_choice` parameter. + +It is the caller's responsibility to prompt the model with the tool information; vLLM will not automatically manipulate the prompt. **This may change in the future.** + +vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter. + +Please refer to the OpenAI API reference documentation for more information. \ No newline at end of file diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 972137030f46f..bff2487117837 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -906,6 +906,191 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, for token in top_logprobs) + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_named_tool_use(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": + "user", + "content": + f"Give an example JSON for an employee profile that " + f"fits this schema: {TEST_SCHEMA}" + }] + + # non-streaming + + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=1000, + tools=[{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": TEST_SCHEMA + } + }], + tool_choice={ + "type": "function", + "function": { + "name": "dummy_function_name" + } + }) + message = chat_completion.choices[0].message + assert len(message.content) == 0 + json_string = message.tool_calls[0].function.arguments + json1 = json.loads(json_string) + jsonschema.validate(instance=json1, schema=TEST_SCHEMA) + + messages.append({"role": "assistant", "content": json_string}) + messages.append({ + "role": + "user", + "content": + "Give me another one with a different name and age" + }) + + # streaming + + stream = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=1000, + tools=[{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": TEST_SCHEMA + } + }], + tool_choice={ + "type": "function", + "function": { + "name": "dummy_function_name" + } + }, + stream=True) + + output = [] + finish_reason_count = 0 + async for chunk in stream: + delta = chunk.choices[0].delta + if delta.role: + assert delta.role == "assistant" + assert delta.content is None or len(delta.content) == 0 + if delta.tool_calls: + output.append(delta.tool_calls[0].function.arguments) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count
== 1 + json2 = json.loads("".join(output)) + jsonschema.validate(instance=json2, schema=TEST_SCHEMA) + assert json1["name"] != json2["name"] + assert json1["age"] != json2["age"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", ["outlines"]) +async def test_required_tool_use_not_yet_supported( + server, client: openai.AsyncOpenAI, guided_decoding_backend: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": + "user", + "content": + f"Give an example JSON for an employee profile that " + f"fits this schema: {TEST_SCHEMA}" + }] + + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=1000, + tools=[{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": TEST_SCHEMA + } + }], + tool_choice="required") + + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=1000, + tools=[{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": TEST_SCHEMA + } + }], + tool_choice="auto") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", ["outlines"]) +async def test_inconsistent_tool_choice_and_tools( + server, client: openai.AsyncOpenAI, guided_decoding_backend: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": + "user", + "content": + f"Give an example JSON for an employee profile that " + f"fits this schema: {TEST_SCHEMA}" + }] + + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create(model=MODEL_NAME, + messages=messages, + max_tokens=1000, + tool_choice={ + "type": "function", + "function": { + "name": + "dummy_function_name" + } + }) + + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=1000, + tools=[{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": TEST_SCHEMA + } + }], + tool_choice={ + "type": "function", + "function": { + "name": "nondefined_function_name" + } + }) + + @pytest.mark.asyncio async def test_response_format_json_object(server, client: openai.AsyncOpenAI): for _ in range(2): diff --git a/tests/utils.py b/tests/utils.py index 329842911e159..cc8b862769475 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -24,7 +24,8 @@ class ServerRunner: env = os.environ.copy() env["PYTHONUNBUFFERED"] = "1" self.proc = subprocess.Popen( - ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, + [sys.executable, "-m", "vllm.entrypoints.openai.api_server"] + + args, env=env, stdout=sys.stdout, stderr=sys.stderr, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index bbd61a2c5dd59..15bdae38d1d46 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -102,6 +102,26 @@ class ResponseFormat(OpenAIBaseModel): type: Literal["text", "json_object"] +class FunctionDefinition(OpenAIBaseModel): + name: str + description: Optional[str] = None + parameters: Optional[Dict[str, Any]] = None + + +class ChatCompletionToolsParam(OpenAIBaseModel): + type: Literal["function"] = "function" + function: FunctionDefinition + + +class 
ChatCompletionNamedFunction(OpenAIBaseModel): + name: str + + +class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): + function: ChatCompletionNamedFunction + type: Literal["function"] = "function" + + class ChatCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create @@ -122,6 +142,9 @@ class ChatCompletionRequest(OpenAIBaseModel): stream: Optional[bool] = False temperature: Optional[float] = 0.7 top_p: Optional[float] = 1.0 + tools: Optional[List[ChatCompletionToolsParam]] = None + tool_choice: Optional[Union[Literal["none"], + ChatCompletionNamedToolChoiceParam]] = "none" user: Optional[str] = None # doc: begin-chat-completion-sampling-params @@ -245,10 +268,27 @@ class ChatCompletionRequest(OpenAIBaseModel): "guided_regex" in data and data["guided_regex"] is not None, "guided_choice" in data and data["guided_choice"] is not None ]) + # you can only use one kind of guided decoding if guide_count > 1: raise ValueError( "You can only use one kind of guided decoding " "('guided_json', 'guided_regex' or 'guided_choice').") + # you can only either use guided decoding or tools, not both + if guide_count >= 1 and "tool_choice" in data and data[ + "tool_choice"] != "none": + raise ValueError( + "You can only either use guided decoding or tools, not both.") + return data + + @model_validator(mode="before") + @classmethod + def check_tool_choice(cls, data): + if "tool_choice" in data and data["tool_choice"] != "none": + if not isinstance(data["tool_choice"], dict): + raise ValueError("Currently only named tools are supported.") + if "tools" not in data or data["tools"] is None: + raise ValueError( + "When using `tool_choice`, `tools` must be set.") return data @model_validator(mode="before") @@ -506,9 +546,21 @@ class EmbeddingResponse(BaseModel): usage: UsageInfo +class FunctionCall(OpenAIBaseModel): + name: str + arguments: str + + +class ToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + function: FunctionCall + + class ChatMessage(OpenAIBaseModel): role: str content: str + tool_calls: List[ToolCall] = Field(default_factory=list) class ChatCompletionLogProb(OpenAIBaseModel): @@ -535,7 +587,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): class ChatCompletionResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") - object: str = "chat.completion" + object: Literal["chat.completion"] = "chat.completion" created: int = Field(default_factory=lambda: int(time.time())) model: str choices: List[ChatCompletionResponseChoice] @@ -545,6 +597,7 @@ class ChatCompletionResponse(OpenAIBaseModel): class DeltaMessage(OpenAIBaseModel): role: Optional[str] = None content: Optional[str] = None + tool_calls: List[ToolCall] = Field(default_factory=list) class ChatCompletionResponseStreamChoice(OpenAIBaseModel): @@ -557,7 +610,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel): class ChatCompletionStreamResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") - object: str = "chat.completion.chunk" + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" created: int = Field(default_factory=lambda: int(time.time())) model: str choices: List[ChatCompletionResponseStreamChoice] diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index cc5b896e0e56c..7b52e10952462 100644 ---
a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -14,10 +14,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( ChatCompletionContentPartParam, ChatCompletionLogProb, ChatCompletionLogProbs, ChatCompletionLogProbsContent, - ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, - UsageInfo) + FunctionCall, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) from vllm.logger import init_logger @@ -298,11 +299,24 @@ class OpenAIServingChat(OpenAIServing): delta_text = output.text[len(previous_texts[i]):] previous_texts[i] = output.text previous_num_tokens[i] = len(output.token_ids) + + if request.tool_choice and type( + request.tool_choice + ) is ChatCompletionNamedToolChoiceParam: + delta_message = DeltaMessage(tool_calls=[ + ToolCall(function=FunctionCall( + name=request.tool_choice.function.name, + arguments=delta_text)) + ]) + else: + delta_message = DeltaMessage(content=delta_text) + if output.finish_reason is None: # Send token-by-token response for each request.n + choice_data = ChatCompletionResponseStreamChoice( index=i, - delta=DeltaMessage(content=delta_text), + delta=delta_message, logprobs=logprobs, finish_reason=None) chunk = ChatCompletionStreamResponse( @@ -324,7 +338,7 @@ class OpenAIServingChat(OpenAIServing): ) choice_data = ChatCompletionResponseStreamChoice( index=i, - delta=DeltaMessage(content=delta_text), + delta=delta_message, logprobs=logprobs, finish_reason=output.finish_reason, stop_reason=output.stop_reason) @@ -381,9 +395,22 @@ class OpenAIServingChat(OpenAIServing): else: logprobs = None + if request.tool_choice and type( + request.tool_choice) is ChatCompletionNamedToolChoiceParam: + message = ChatMessage( + role=role, + content="", + tool_calls=[ + ToolCall(function=FunctionCall( + name=request.tool_choice.function.name, + arguments=output.text)) + ]) + elif not request.tool_choice or request.tool_choice == "none": + message = ChatMessage(role=role, content=output.text) + choice_data = ChatCompletionResponseChoice( index=output.index, - message=ChatMessage(role=role, content=output.text), + message=message, logprobs=logprobs, finish_reason=output.finish_reason, stop_reason=output.stop_reason) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 0558d6c95d97b..50aa3ec379f4a 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -1,7 +1,8 @@ from typing import Optional, Union -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, + CompletionRequest) from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( get_lm_format_enforcer_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.outlines_decoding import ( @@ -13,6 +14,8 @@ async def get_guided_decoding_logits_processor( guided_decoding_backend: str, request: Union[CompletionRequest, ChatCompletionRequest], tokenizer) -> Optional[LogitsProcessor]: + request = 
_adapt_request_for_tool_use(request) + if guided_decoding_backend == 'outlines': return await get_outlines_guided_decoding_logits_processor( request, tokenizer) @@ -23,3 +26,26 @@ async def get_guided_decoding_logits_processor( raise ValueError( f"Unknown guided decoding backend '{guided_decoding_backend}'. " "Must be one of 'outlines, 'lm-format-enforcer'") + + +def _adapt_request_for_tool_use(request: Union[CompletionRequest, + ChatCompletionRequest]): + # the legacy completion API does not support tool use + if type(request) is CompletionRequest: + return request + + # user has chosen to not use any tool + if request.tool_choice == "none": + return request + + # user has chosen to use a named tool + if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam: + tool_name = request.tool_choice.function.name + tools = {tool.function.name: tool.function for tool in request.tools} + if tool_name not in tools: + raise ValueError( + f"Tool '{tool_name}' has not been passed in `tools`.") + tool = tools[tool_name] + request.guided_json = tool.parameters + + return request
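
For reference, the named tool calling added in [PATCH 18/18] can be exercised from a client as sketched below. This is only an illustrative sketch, not part of the patch: the base URL, API key, model id, and JSON schema are placeholders, while the `tools`/`tool_choice` request shape mirrors the new `test_named_tool_use` test.

```python
# Minimal client-side sketch (assumptions: a vLLM OpenAI-compatible server is
# already running at localhost:8000; the model id below is a placeholder).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Hypothetical schema used only for this example.
person_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

completion = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-hf",  # placeholder model id
    messages=[{
        "role": "user",
        "content": f"Give an example JSON that fits this schema: {person_schema}",
    }],
    max_tokens=500,
    tools=[{
        "type": "function",
        "function": {
            "name": "dummy_function_name",
            "description": "This is a dummy function",
            "parameters": person_schema,
        },
    }],
    # Only named tool choice is supported; "auto" and "required" are rejected.
    tool_choice={"type": "function", "function": {"name": "dummy_function_name"}},
)

# The arguments are produced with guided decoding, so they conform to person_schema.
print(completion.choices[0].message.tool_calls[0].function.arguments)
```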