diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 7f90181048d0f..aa4cc7b35a543 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -86,10 +86,6 @@ if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"} fi -if [[ $commands == *"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"* ]]; then - commands=${commands//"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"/"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2 and not BambaForCausalLM and not Gemma2ForCausalLM and not Grok1ModelForCausalLM and not Zamba2ForCausalLM and not Gemma2Model and not GritLM'"} -fi - if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"} fi diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index c42ec4f2503d0..fe4796b35786c 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -6,24 +6,28 @@ # to generate the final pipeline yaml file. # Documentation -# label(str): the name of the test. emoji allowed. -# fast_check(bool): whether to run this on each commit on fastcheck pipeline. -# torch_nightly(bool): whether to run this on vllm against torch nightly pipeline. -# fast_check_only(bool): run this test on fastcheck pipeline only -# optional(bool): never run this test by default (i.e. need to unblock manually) unless it's scheduled nightly run. +# label(str): the name of the test. emojis allowed. +# fast_check(bool): whether to run this on each commit on the fastcheck pipeline. +# torch_nightly(bool): whether to run this on vllm against the torch nightly pipeline. +# fast_check_only(bool): run this test on the fastcheck pipeline only +# optional(bool): never run this test by default (i.e. need to unblock manually) unless it's a scheduled nightly run. +# soft_fail(bool): allow this step to fail without failing the entire pipeline (useful for flaky or experimental tests). # command(str): the single command to run for tests. incompatible with commands. -# commands(list): the list of commands to run for test. incompatbile with command. -# mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd] -# gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100 -# num_gpus(int): override the number of GPUs for the test. default to 1 GPU. currently support 2,4. -# num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host, -# in this case, commands must be specified. the first command runs on first host, the second +# commands(list): the list of commands to run for the test. incompatible with command. +# mirror_hardwares(list): the list of hardware to run the test on as well. currently only supports [amdexperimental] +# gpu(str): override the GPU selection for the test. default is L4 GPUs. supports a100, b200, h200 +# num_gpus(int): override the number of GPUs for the test. defaults to 1 GPU. currently supports 2,4. 
+# num_nodes(int): whether to simulate multi-node setup by launching multiple containers on one host, +# in this case, commands must be specified. the first command runs on the first host, the second # command runs on the second host. -# working_dir(str): specify the place where command should execute, default to /vllm-workspace/tests -# source_file_dependencies(list): the list of prefix to opt-in the test for, if empty, the test will always run. +# timeout_in_minutes(int): sets a timeout for the step in minutes. if not specified, uses the default timeout. +# parallelism(int): number of parallel jobs to run for this step. enables test sharding using $$BUILDKITE_PARALLEL_JOB +# and $$BUILDKITE_PARALLEL_JOB_COUNT environment variables. +# working_dir(str): specify the place where the command should execute, default to /vllm-workspace/tests +# source_file_dependencies(list): the list of prefixes to opt-in the test for, if empty, the test will always run. # When adding a test -# - If the test belong to an existing group, add it there +# - If the test belongs to an existing group, add it there # - If the test is short, add to any existing step # - If the test takes more than 10min, then it is okay to create a new step. # Note that all steps execute in parallel. @@ -110,7 +114,7 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests + - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - label: Entrypoints Integration Test (API Server) # 100min timeout_in_minutes: 130 @@ -148,7 +152,6 @@ steps: num_gpus: 4 source_file_dependencies: - vllm/distributed/ - - vllm/core/ - tests/distributed/test_utils - tests/distributed/test_pynccl - tests/distributed/test_events @@ -163,7 +166,6 @@ steps: - tests/v1/engine/test_engine_core_client.py commands: # test with tp=2 and external_dp=2 - - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py # test with tp=2 and pp=2 - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py @@ -314,12 +316,11 @@ steps: - python3 offline_inference/vision_language.py --seed 0 - python3 offline_inference/vision_language_pooling.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0 - - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors + - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/basic/classify.py - python3 offline_inference/basic/embed.py - python3 offline_inference/basic/score.py - - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 - label: Platform Tests (CUDA) # 4min timeout_in_minutes: 15 @@ -869,8 +870,6 @@ 
steps: - tests/distributed/ - vllm/compilation - vllm/worker/worker_base.py - - vllm/worker/worker.py - - vllm/worker/model_runner.py - entrypoints/llm/test_collective_rpc.py - tests/v1/test_async_llm_dp.py - tests/v1/test_external_lb_dp.py @@ -894,7 +893,7 @@ steps: - pytest -v -s distributed/test_sequence_parallel.py # this test fails consistently. # TODO: investigate and fix - - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - pytest -v -s models/multimodal/generation/test_maverick.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 323675993467a..37bd0ace98a97 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -4,11 +4,8 @@ # This lists cover the "core" components of vLLM that require careful review /vllm/attention @LucasWilkinson /vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/core @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn /vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn -/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/model_executor/layers/fused_moe @mgoin /vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @NickLucche /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 diff --git a/examples/offline_inference/profiling.py b/examples/offline_inference/profiling.py deleted file mode 100644 index 392fba8fc5ead..0000000000000 --- a/examples/offline_inference/profiling.py +++ /dev/null @@ -1,510 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import inspect -import json -import os -import sys -from argparse import RawTextHelpFormatter -from collections.abc import Generator -from dataclasses import asdict, dataclass -from typing import Any, Optional, TypeAlias - -import torch -import tqdm - -from vllm import LLM, SamplingParams -from vllm.engine.arg_utils import EngineArgs -from vllm.profiler.layerwise_profile import layerwise_profile -from vllm.utils import FlexibleArgumentParser - -BATCH_SIZE_DEFAULT = 1 -PROMPT_LEN_DEFAULT = 256 - - -@dataclass -class ProfileContext: - engine_args: EngineArgs - prompt_len: int - batch_size: int - - # The profiler can run in 2 modes, - # 1. Run profiler for user specified num_steps - num_steps: Optional[int] = None - # 2. Run profiler until all requests complete - complete_num_requests_per_step: Optional[int] = None - - save_chrome_traces_folder: Optional[str] = None - - -def get_dtype(dtype: str): - if dtype == "torch.float": - return torch.float - else: - return dtype - - -OutputLen_NumReqs_Map: TypeAlias = dict[int, int] - - -def compute_request_output_lengths( - batch_size: int, step_requests: list[int] -) -> OutputLen_NumReqs_Map: - """ - Given the number of requests, batch_size, and the number of requests - that each engine-step should process, step_requests, determine the - output lengths of the requests such that step_request is honoured. 
- - Example: - if batch size = 128 and step_request = [128, 128, 96, 64, 32, 1] - then return, - {2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning, - 32 requests should have output length 2, - 32 requests should have output length 3, - 32 requests should have output length 4, - 31 requests should have output length 5, - 1 request should have output length 6. - - Args: - batch_size (int): Number of requests submitted for profile. This is - args.batch_size. - step_requests (list[int]): step_requests[i] is the number of requests - that the ith engine step should process. - - Returns: - OutputLen_NumReqs_Map : A dictionary with output-length as keys and the - number of requests required to have that output-length as values. - """ - ol_nr: OutputLen_NumReqs_Map = {} - - # Number of request that are assigned an output-length - num_reqs_assigned: int = 0 - num_steps: int = len(step_requests) - - # sanity check. The first step (prefill-step), must process all requests. - assert step_requests[0] == batch_size - - # Begin assignments from the last step. - output_length: int = num_steps - for num_requests_at_step in reversed(step_requests): - if num_reqs_assigned == batch_size: - break - - assert num_reqs_assigned < batch_size - - # Remove the number of requests that have been determined - # to participate in this step and beyond. - num_reqs_unassigned_at_step = num_requests_at_step - num_reqs_assigned - assert num_reqs_unassigned_at_step >= 0 - - if num_reqs_unassigned_at_step > 0: - ol_nr[output_length] = num_reqs_unassigned_at_step - num_reqs_assigned += num_reqs_unassigned_at_step - - output_length -= 1 - - # sanity checks. - assert sum(ol_nr.values()) == batch_size, ( - "Number of requests in output-length assignment does not match " - f"batch-size.\n batch size {batch_size} - " - f"step requests {step_requests} - assignments {ol_nr}" - ) - - # Check that the output-length is in [1, num-steps]. Output length must be - # at least 1 as all requests must participate in the prefill-step. - assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), ( - "Output lengths of requests should be in range " - f"[1, num-engine-steps].\n batch size {batch_size} - " - f"step requests {step_requests} - assignments {ol_nr}" - ) - - return ol_nr - - -def determine_requests_per_step(context: ProfileContext) -> list[int]: - """ - Determine number of requests each engine step should process. - If context.num_steps is set, then all engine steps process the - same number of requests and the output list is of length - context.num_steps. - - If context.complete_num_requests_per_step is set, then each decode step - processes fewer and fewer requests until there are no requests to process. - In this case, the output list is as big as the number of steps - required to process all requests. - - Args: - context: ProfileContext object. - - Returns: - list[int]: Number of requests to process for all engine-steps. - output[i], contains the number of requests that the ith step - should process. - """ - if context.num_steps: - # All requests must run until num_engine_steps. This implies - # that their output lengths must be equal to num_engine_steps. - return [context.batch_size] * context.num_steps - - assert ( - context.complete_num_requests_per_step - and context.complete_num_requests_per_step > 0 - ), ( - f"Expected a positive complete_num_requests_per_step argument." - f"Instead got {context.complete_num_requests_per_step}" - ) - - # We start dropping after the first decode step. 
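The schedule and output-length assignment described in the docstrings above are easy to sanity-check. Below is a minimal standalone sketch (simplified names, same algorithm as the two deleted helpers), verified against the docstring example of batch size 128 with 32 requests completing per decode step:

```python
def requests_per_step(batch_size: int, completed_per_step: int) -> list[int]:
    # One prefill step, then one full decode step before requests start draining.
    schedule = [batch_size, batch_size]
    running = batch_size - completed_per_step
    while running > 0:
        schedule.append(running)
        running -= completed_per_step
    if schedule[-1] != 1:
        # Tack on a one-request step at the end; often useful to profile.
        schedule.append(1)
    return schedule


def output_lengths(batch_size: int, schedule: list[int]) -> dict[int, int]:
    # Assign the longest output lengths first by walking the schedule backwards.
    ol_nr: dict[int, int] = {}
    assigned = 0
    length = len(schedule)
    for num_requests in reversed(schedule):
        if assigned == batch_size:
            break
        extra = num_requests - assigned
        if extra > 0:
            ol_nr[length] = extra
            assigned += extra
        length -= 1
    return ol_nr


assert requests_per_step(128, 32) == [128, 128, 96, 64, 32, 1]
assert output_lengths(128, [128, 128, 96, 64, 32, 1]) == {
    6: 1, 5: 31, 4: 32, 3: 32, 2: 32,
}
```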
- step_requests = [ - context.batch_size, # prefill - context.batch_size, # decode - ] - - num_running_requests = context.batch_size - num_running_requests -= context.complete_num_requests_per_step - while num_running_requests > 0: - step_requests.append(num_running_requests) - num_running_requests -= context.complete_num_requests_per_step - - if step_requests[-1] != 1: - # have 1 request running at the last step. This is often - # useful - step_requests.append(1) - - return step_requests - - -def run_profile( - context: ProfileContext, csv_output: Optional[str], json_output: Optional[str] -): - print("Run profile with:") - for key, value in asdict(context).items(): - print(f" {key} = {value}") - - requests_per_step: list[int] = determine_requests_per_step(context) - - ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths( - context.batch_size, requests_per_step - ) - - num_steps_to_profile: int = len(requests_per_step) - max_output_len: int = max(ol_nr.keys()) - assert max_output_len >= 1 - - # Create sampling params - sampling_params = SamplingParams( - temperature=0.8, - top_p=0.95, - # max_tokens is set on a per-request basis. - max_tokens=None, - ignore_eos=True, - ) - - # Create LLM - llm = LLM(**asdict(context.engine_args)) - batch_size = context.batch_size - prompt_len = context.prompt_len - - scheduler_config = llm.llm_engine.vllm_config.scheduler_config - max_model_len = llm.llm_engine.model_config.max_model_len - max_num_batched_tokens = scheduler_config.max_num_batched_tokens - max_num_seqs = scheduler_config.max_num_seqs - - if batch_size * prompt_len > max_num_batched_tokens: - print( - f"ERROR: chosen batch_size * prompt_len " - f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is " - f"larger than max_num_batched_tokens ({max_num_batched_tokens}) " - f"and therefore cannot be run in a single profile step, please " - f"choose a smaller batch size or prompt length, or increase " - f"--max-num-batched-tokens" - ) - sys.exit(-1) - if batch_size > max_num_seqs: - print( - f"ERROR: chosen batch_size ({batch_size}) is larger than " - f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " - f"single profile step, please choose a smaller batch size" - ) - sys.exit(-1) - print( - "llm.llm_engine.model_config.max_model_len: ", - llm.llm_engine.model_config.max_model_len, - ) - if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len: - print( - f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + " - f"{max_output_len} = {prompt_len + max_output_len}) is larger " - f"than the model's max_model_len ({max_model_len}), please " - f"choose a smaller prompt_len or max_output_len, or increase " - f"--max-model-len" - ) - sys.exit(-1) - - def add_requests(): - def get_output_len_generator() -> Generator[int, Any, Any]: - for output_len, num_reqs in ol_nr.items(): - for _ in range(num_reqs): - yield output_len - - output_len_generator = get_output_len_generator() - for i in range(batch_size): - sampling_params.max_tokens = next(output_len_generator) - assert isinstance(sampling_params.max_tokens, int) - - prompt_token_ids = torch.randint( - llm.get_tokenizer().vocab_size, size=(prompt_len,) - ).tolist() - - llm.llm_engine.add_request( - request_id=f"seq{i}", - prompt={"prompt_token_ids": prompt_token_ids}, - params=sampling_params, - ) - - def abort_requests(): - for i in range(batch_size): - llm.llm_engine.abort_request(f"seq{i}") - - # Warm up run - print("Warm up run ...") - add_requests() - llm.llm_engine.step() # Prefill - 
llm.llm_engine.step() # Decode - abort_requests() - - print("Profile run ...") - add_requests() - - with layerwise_profile() as prefill_prof: - llm.llm_engine.step() # First step is prefill - - decode_profs = [] - for _ in tqdm.tqdm(range(num_steps_to_profile - 1)): - num_running_seqs = llm.llm_engine.scheduler[0].get_num_unfinished_seq_groups() - with layerwise_profile(num_running_seqs=num_running_seqs) as decode_prof: - llm.llm_engine.step() - decode_profs.append(decode_prof) - - decode_results_list = [prof.results for prof in decode_profs] - prefill_results = prefill_prof.results - has_decode = len(decode_results_list) > 0 - - LINE_WIDTH = 80 - print("=" * LINE_WIDTH) - print(f"= Prefill Model Table (prompt_len={prompt_len}, batch_size={batch_size})") - print("=" * LINE_WIDTH) - print() - prefill_results.print_model_table() - - if has_decode: - print() - print("=" * LINE_WIDTH) - print( - f"= First Decode Step Model Table " - f"(prompt_len={prompt_len}, batch_size={batch_size})" - ) - print("=" * LINE_WIDTH) - print() - decode_results_list[0].print_model_table() - - print() - print("=" * LINE_WIDTH) - print(f"= Prefill Summary Table (prompt_len={prompt_len}, batch_size={batch_size})") - print("=" * LINE_WIDTH) - print() - prefill_results.print_summary_table() - - if has_decode: - print() - print("=" * LINE_WIDTH) - print( - f"= First Decode Step Summary Table " - f"(prompt_len={prompt_len}, batch_size={batch_size})" - ) - print("=" * LINE_WIDTH) - print() - decode_results_list[0].print_summary_table() - - if csv_output: - csv_filename_base = ( - csv_output[:-4] if csv_output.endswith(".csv") else csv_output - ) - prefill_results.export_model_stats_table_csv( - csv_filename_base + "_prefill_model_table.csv" - ) - prefill_results.export_summary_stats_table_csv( - csv_filename_base + "_prefill_summary_table.csv" - ) - - if has_decode: - decode_results_list[0].export_model_stats_table_csv( - csv_filename_base + "_decode_model_table.csv" - ) - decode_results_list[0].export_summary_stats_table_csv( - csv_filename_base + "_decode_summary_table.csv" - ) - - if json_output: - cuda_devices = [ - torch.cuda.get_device_properties(dev_idx) - for dev_idx in range(torch.cuda.device_count()) - ] - - json_dict = { - "context": { - "python_version": f"{sys.version}", - "torch_version": f"{torch.__version__}", - "torch_cuda_version": f"{torch.version.cuda}", - "cuda_devices": f"{cuda_devices}", - **asdict(context), - }, - "prefill": prefill_results.convert_stats_to_dict(), - } - - if has_decode: - for idx, dr in enumerate(decode_results_list): - json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() - - # Add .json to json_output filename if it doesn't exist already. - json_output_file = ( - json_output if json_output.endswith(".json") else json_output + ".json" - ) - with open(json_output_file, "w+") as f: - json.dump(json_dict, f, indent=2) - pass - - if context.save_chrome_traces_folder is not None: - os.makedirs(context.save_chrome_traces_folder, exist_ok=True) - prefill_prof.profiler.export_chrome_trace( - context.save_chrome_traces_folder + "/prefill.json" - ) - for idx, decode_prof in enumerate(decode_profs): - decode_prof.profiler.export_chrome_trace( - context.save_chrome_traces_folder + f"/decode_{idx + 1}.json" - ) - print( - "Traces saved as prefill.json and decode_1.json, etc." 
- f" in folder {context.save_chrome_traces_folder}" - ) - - -def parse_args(): - parser = FlexibleArgumentParser( - description=""" -Profile a model - - example: - ``` - python examples/offline_inference/profiling.py \\ - --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\ - --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\ - --enforce-eager run_num_steps -n 2 - ``` - - then you can use various tools to analyze the json output - terminal ascii tables: - ``` - python tools/profiler/print_layerwise_table.py \\ - --json-trace Llama31-8b-FP8.json --phase prefill --table summary - ``` - or create matplotlib stacked bar charts: - ``` - python tools/profiler/visualize_layerwise_profile.py \\ - --json-trace Llama31-8b-FP8.json \\ - --output-directory profile_breakdown --plot-metric pct_cuda_time - ``` -""", - formatter_class=RawTextHelpFormatter, - ) - parser.add_argument( - "--csv", - type=str, - default=None, - help="Export the results as multiple csv file. This should be the root " - "filename, will create _prefill_model_table.csv, " - "_prefill_summary_table.csv, " - "_decode_model_table.csv, and " - "_decode_summary_table.csv", - ) - parser.add_argument( - "--json", - type=str, - default=None, - help="Export the results as a json file. This should be the filename", - ) - parser.add_argument( - "--save-chrome-traces-folder", - type=str, - help="Save chrome traces for the prefill and decode " - "will save traces as prefill.json and decode_1.json, " - "etc. inside this folder", - ) - parser.add_argument( - "--prompt-len", - type=int, - default=PROMPT_LEN_DEFAULT, - help=f"Length of the random prompt to use when profiling, all batched " - f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}", - ) - parser.add_argument( - "--batch-size", - type=int, - default=BATCH_SIZE_DEFAULT, - help=f"Number of requests to run as a single batch, " - f"default={BATCH_SIZE_DEFAULT}", - ) - - subparsers = parser.add_subparsers(dest="cmd") - - run_num_steps_parser = subparsers.add_parser( - "run_num_steps", help="This variation profiles n engine.step() invocations." - ) - run_num_steps_parser.add_argument( - "-n", - "--num-steps", - type=int, - help="Number of engine steps to profile.\n" - "Setting it to 1, profiles only the prefill step.\n" - "Setting it to 2, profiles the prefill and first decode step\n" - "Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n" - "and so on ...", - ) - - run_to_completion_parser = subparsers.add_parser( - "run_to_completion", - help="This variation profiles all the engine.step() invocations" - "until the engine exhausts all submitted requests.", - ) - run_to_completion_parser.add_argument( - "-n", - "--complete-num-requests-per-step", - type=int, - help="Complete complete_num_requests_per_step requests every decode step." 
- "For e.g., with batch_size 128 and complete_num_requests_per_step 32," - "the profiler is run for 6 engine steps, with the steps processing, " - "128, 128, 96, 64, 32, 1 requests respectively.\n" - "Note that we tack-on a one-request step at the end as it is often " - "useful.", - ) - - EngineArgs.add_cli_args(parser) - - return parser.parse_args() - - -def main(args): - context = ProfileContext( - engine_args=EngineArgs.from_cli_args(args), - **{ - k: v - for k, v in vars(args).items() - if k in inspect.signature(ProfileContext).parameters - }, - ) - run_profile(context, csv_output=args.csv, json_output=args.json) - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/mkdocs.yaml b/mkdocs.yaml index 6f2be65a18af8..1535fcc622cdc 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -102,6 +102,7 @@ plugins: - https://numpy.org/doc/stable/objects.inv - https://pytorch.org/docs/stable/objects.inv - https://psutil.readthedocs.io/en/stable/objects.inv + - https://huggingface.co/docs/transformers/main/en/objects.inv markdown_extensions: - attr_list diff --git a/pyproject.toml b/pyproject.toml index fe55461db00be..f43ae69e00bdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,6 @@ line-length = 80 "vllm/_version.py" = ["ALL"] # Python 3.8 typing - skip V0 code "vllm/attention/**/*.py" = ["UP006", "UP035"] -"vllm/core/**/*.py" = ["UP006", "UP035"] "vllm/engine/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"] "vllm/worker/**/*.py" = ["UP006", "UP035"] @@ -117,7 +116,6 @@ files = [ "vllm/*.py", "vllm/assets", "vllm/entrypoints", - "vllm/core", "vllm/inputs", "vllm/logging_utils", "vllm/multimodal", diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 24b1c9a93126c..411f3e01bc2cd 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -11,7 +11,7 @@ from unittest.mock import Mock import pytest import torch -from vllm import LLM, envs +from vllm import LLM from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 from ..conftest import HfRunner, VllmRunner @@ -26,14 +26,6 @@ MODELS = [ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - def test_vllm_gc_ed(): """Verify vllm instance is GC'ed when it is deleted""" llm = LLM("distilbert/distilgpt2") @@ -76,12 +68,6 @@ def test_models( model_executor: str, enable_prompt_embeds: bool, ) -> None: - if not envs.VLLM_USE_V1: - if async_scheduling: - pytest.skip("async_scheduling only supported in v1.") - if model_executor != "uni": - pytest.skip("only test uniproc executor for v0.") - if backend == "XFORMERS" and model == "google/gemma-2-2b-it": pytest.skip( f"{backend} does not support gemma2 with full context length.") diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index f3ad680b72b55..508740ab29389 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -122,11 +122,12 @@ def test_cumem_with_cudagraph(): # sleep mode with safetensors ("meta-llama/Llama-3.2-1B", True), # sleep mode with pytorch checkpoint - ("facebook/opt-125m", False), + ("facebook/opt-125m", True), ]) def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool): with 
monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") + assert use_v1 + m.setenv("VLLM_USE_V1", "1") free, total = torch.cuda.mem_get_info() used_bytes_baseline = total - free # in case other process is running llm = LLM(model, enable_sleep_mode=True) diff --git a/tests/build_cython.py b/tests/build_cython.py deleted file mode 100644 index 444434e8f0a79..0000000000000 --- a/tests/build_cython.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import Cython.Compiler.Options -from Cython.Build import cythonize -from setuptools import setup - -Cython.Compiler.Options.annotate = True - -infiles = [] - -infiles += [ - "vllm/engine/llm_engine.py", - "vllm/transformers_utils/detokenizer.py", - "vllm/engine/output_processor/single_step.py", - "vllm/outputs.py", - "vllm/engine/output_processor/stop_checker.py", -] - -infiles += [ - "vllm/core/scheduler.py", - "vllm/sequence.py", - "vllm/core/block_manager.py", -] - -infiles += [ - "vllm/model_executor/layers/sampler.py", - "vllm/sampling_params.py", - "vllm/utils/__init__.py", -] - -setup(ext_modules=cythonize(infiles, - annotate=False, - force=True, - compiler_directives={ - 'language_level': "3", - 'infer_types': True - })) - -# example usage: python3 build_cython.py build_ext --inplace diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 022f183b31932..b6bebbba915be 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -54,8 +54,7 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str, # Use global backends global backend, backend_unfused - use_v1 = False # can be made a param once V1 support added - monkeypatch.setenv("VLLM_USE_V1", str(int(use_v1))) + monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa))) # Prompt 4 seems too open-ended, differs between fused and unfused diff --git a/tests/conftest.py b/tests/conftest.py index e8e95357ff5b9..dc70c98359598 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,7 @@ import socket import tempfile import threading from collections.abc import Generator +from contextlib import nullcontext from enum import Enum from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast @@ -45,14 +46,14 @@ from vllm.connections import global_http_connection from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, initialize_model_parallel) -from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, - to_enc_dec_tuple_list, zip_enc_dec_prompts) +from vllm.inputs import TextPrompt from vllm.logger import init_logger +from vllm.logprobs import Logprob from vllm.multimodal.utils import fetch_image from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams -from vllm.sequence import Logprob from vllm.transformers_utils.utils import maybe_model_redirect +from vllm.utils import set_default_torch_num_threads logger = init_logger(__name__) @@ -159,26 +160,6 @@ def cleanup_VLLM_USE_V1(monkeypatch): monkeypatch.delenv("VLLM_USE_V1") -@pytest.fixture(params=[True, False]) -def run_with_both_engines(request, monkeypatch): - # Automatically runs tests twice, once with V1 and once without - use_v1 = request.param - # Tests decorated with `@skip_v1` are only run without v1 - skip_v0 = request.node.get_closest_marker("skip_v0") - skip_v1 = 
request.node.get_closest_marker("skip_v1") - - if use_v1: - if skip_v1: - pytest.skip("Skipping test on vllm V1") - monkeypatch.setenv('VLLM_USE_V1', '1') - else: - if skip_v0: - pytest.skip("Skipping test on vllm V0") - monkeypatch.setenv('VLLM_USE_V1', '0') - - yield - - @pytest.fixture(autouse=True) def init_test_http_connection(): # pytest_asyncio may use a different event loop per test @@ -306,6 +287,35 @@ class HfRunner: is_cross_encoder: bool = False, skip_tokenizer_init: bool = False, auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM, + # Set this to avoid hanging issue + default_torch_num_threads: Optional[int] = None, + ) -> None: + init_ctx = (nullcontext() if default_torch_num_threads is None else + set_default_torch_num_threads(default_torch_num_threads)) + + with init_ctx: + self._init( + model_name=model_name, + dtype=dtype, + model_kwargs=model_kwargs, + trust_remote_code=trust_remote_code, + is_sentence_transformer=is_sentence_transformer, + is_cross_encoder=is_cross_encoder, + skip_tokenizer_init=skip_tokenizer_init, + auto_cls=auto_cls, + ) + + def _init( + self, + model_name: str, + dtype: str = "auto", + *, + model_kwargs: Optional[dict[str, Any]] = None, + trust_remote_code: bool = True, + is_sentence_transformer: bool = False, + is_cross_encoder: bool = False, + skip_tokenizer_init: bool = False, + auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM, ) -> None: model_name = maybe_model_redirect(model_name) self.model_name = model_name @@ -714,26 +724,32 @@ class VllmRunner: enable_chunked_prefill: Optional[bool] = False, swap_space: int = 4, enforce_eager: Optional[bool] = False, + # Set this to avoid hanging issue + default_torch_num_threads: Optional[int] = None, **kwargs, ) -> None: - self.llm = LLM( - model=model_name, - runner=runner, - convert=convert, - tokenizer=tokenizer_name, - tokenizer_mode=tokenizer_mode, - trust_remote_code=trust_remote_code, - dtype=dtype, - seed=seed, - swap_space=swap_space, - enforce_eager=enforce_eager, - disable_log_stats=disable_log_stats, - tensor_parallel_size=tensor_parallel_size, - max_model_len=max_model_len, - block_size=block_size, - enable_chunked_prefill=enable_chunked_prefill, - **kwargs, - ) + init_ctx = (nullcontext() if default_torch_num_threads is None else + set_default_torch_num_threads(default_torch_num_threads)) + + with init_ctx: + self.llm = LLM( + model=model_name, + runner=runner, + convert=convert, + tokenizer=tokenizer_name, + tokenizer_mode=tokenizer_mode, + trust_remote_code=trust_remote_code, + dtype=dtype, + seed=seed, + swap_space=swap_space, + enforce_eager=enforce_eager, + disable_log_stats=disable_log_stats, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + block_size=block_size, + enable_chunked_prefill=enable_chunked_prefill, + **kwargs, + ) def get_inputs( self, diff --git a/tests/detokenizer/test_stop_strings.py b/tests/detokenizer/test_stop_strings.py index cb87c44cc3999..46f7d58c438cb 100644 --- a/tests/detokenizer/test_stop_strings.py +++ b/tests/detokenizer/test_stop_strings.py @@ -32,10 +32,6 @@ def _test_stopping(llm: LLM, assert output.stop_reason == expected_reason -def _set_async_mode(llm, is_async): - llm.llm_engine.scheduler[0].use_async_output_proc = is_async - - def _stop_basic(llm): _test_stopping(llm, stop=["."], @@ -103,40 +99,8 @@ def test_stop_strings(): # async output processing below. 
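The surviving V1-only assertions boil down to ordinary stop-string handling in `SamplingParams`. A minimal sketch of what `_stop_basic` exercises follows; the model and prompt are illustrative, and the asserts assume the model emits a period within the token budget:

```python
from vllm import LLM, SamplingParams

llm = LLM("distilbert/distilgpt2", enforce_eager=True)
params = SamplingParams(max_tokens=32, stop=["."])
output = llm.generate("The best thing about Paris is", params)[0].outputs[0]

# By default the matched stop string is trimmed from the returned text
# and surfaced through stop_reason instead.
assert "." not in output.text
assert output.stop_reason == "."
```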
llm = LLM(MODEL, enforce_eager=envs.VLLM_USE_V1) - if envs.VLLM_USE_V1: - _stop_basic(llm) - else: - _set_async_mode(llm, True) - _stop_basic(llm) - - _set_async_mode(llm, False) - _stop_basic(llm) - - if envs.VLLM_USE_V1: - _stop_multi_tokens(llm) - else: - _set_async_mode(llm, True) - _stop_multi_tokens(llm) - - _set_async_mode(llm, False) - _stop_multi_tokens(llm) - - if envs.VLLM_USE_V1: - _stop_partial_token(llm) - else: - _set_async_mode(llm, True) - _stop_partial_token(llm) - - _set_async_mode(llm, False) - _stop_partial_token(llm) - - if envs.VLLM_USE_V1: - # FIXME: this does not respect include_in_output=False - # _stop_token_id(llm) - pass - else: - _set_async_mode(llm, True) - _stop_token_id(llm) - - _set_async_mode(llm, False) - _stop_token_id(llm) + _stop_basic(llm) + _stop_multi_tokens(llm) + _stop_partial_token(llm) + # FIXME: this does not respect include_in_output=False + # _stop_token_id(llm) diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 3bbbcc755d134..e0ecb02d4f563 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -25,12 +25,6 @@ TOKEN_IDS = [ ] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - """We can run both engines for this test.""" - pass - - @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py index 1b7be15d5d691..b219b33d1760e 100644 --- a/tests/entrypoints/llm/test_prompt_validation.py +++ b/tests/entrypoints/llm/test_prompt_validation.py @@ -6,14 +6,6 @@ import pytest from vllm import LLM -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass - - def test_empty_prompt(): llm = LLM(model="openai-community/gpt2", enforce_eager=True) with pytest.raises(ValueError, match='decoder prompt cannot be empty'): diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py index 176c1825530e4..9c62595ad280b 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -60,6 +60,7 @@ def create_dummy_embeds(num_tokens: int = 5) -> str: return base64.b64encode(buffer.getvalue()).decode('utf-8') +@pytest.mark.skip("This test is skipped because it is flaky.") @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_completions_with_prompt_embeds( diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 8917aa5a5efb9..f0b61902eb568 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -432,7 +432,7 @@ def test_metrics_exist_run_batch(use_v1: bool): "--port", port, ], - env={"VLLM_USE_V1": "1" if use_v1 else "0"}) + env={"VLLM_USE_V1": "1"}) def is_server_up(url): try: diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 190c92e1251c2..f8454ad0a4c4d 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -69,28 +69,20 @@ def generate_params(): @pytest.mark.parametrize("device, name, use_mla, block_size", generate_params()) 
-@pytest.mark.parametrize("use_v1", [True, False]) def test_env( device: str, name: str, use_mla: bool, block_size: int, - use_v1: bool, monkeypatch: pytest.MonkeyPatch, ): """Test attention backend selection with valid device-backend pairs.""" with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") + m.setenv("VLLM_USE_V1", "1") m.setenv(STR_BACKEND_ENV_VAR, name) m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") - if name == "FLASHINFER" and not use_v1: - pytest.skip("FlashInfer backend is only available on V1 engine") - if device == "cpu": - if not use_v1: - pytest.skip("CPU backend only supports V1") - with patch("vllm.attention.selector.current_platform", CpuPlatform()): backend = get_attn_backend(16, torch.float16, None, block_size, @@ -137,7 +129,7 @@ def test_env( block_size, False, use_mla=use_mla) - expected = f"{name}_VLLM_V1" if use_v1 else name + expected = f"{name}_VLLM_V1" assert backend.get_name() == expected else: backend = get_attn_backend(16, @@ -146,7 +138,7 @@ def test_env( block_size, False, use_mla=use_mla) - expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH" + expected = "TRITON_ATTN_VLLM_V1" assert backend.get_name() == expected elif device == "cuda": @@ -163,11 +155,7 @@ def test_env( # - TRITON_MLA: fallback for other cases if name == "CUTLASS_MLA": - if not use_v1: - # CUTLASS_MLA only supported on V1 engine - pytest.skip( - "CUTLASS_MLA only supported on V1 engine") - elif block_size != 128: + if block_size != 128: # CUTLASS_MLA only supports block_size == 128 pytest.skip( "CUTLASS_MLA only supports block_size 128") @@ -181,11 +169,7 @@ def test_env( expected = "CUTLASS_MLA_VLLM_V1" assert backend.get_name() == expected elif name == "FLASHINFER_MLA": - if not use_v1: - # FlashInfer MLA only supported on V1 engine - pytest.skip( - "FlashInfer MLA only supported on V1 engine") - elif block_size not in [32, 64]: + if block_size not in [32, 64]: # FlashInfer MLA only supports block_size 32 or 64 pytest.skip( "FlashInfer MLA only supports block_size 32 " @@ -217,23 +201,17 @@ def test_env( block_size, False, use_mla=use_mla) - expected = f"{name}_VLLM_V1" if use_v1 else name + expected = f"{name}_VLLM_V1" assert backend.get_name() == expected elif name == "FLASH_ATTN_MLA": - if not use_v1: - # FlashAttention MLA only supported on V1 engine - pytest.skip( - "FlashAttention MLA only supported on V1 engine" - ) - else: - backend = get_attn_backend(16, - torch.float16, - None, - block_size, - False, - use_mla=use_mla) - expected = "FLASH_ATTN_MLA" - assert backend.get_name() == expected + backend = get_attn_backend(16, + torch.float16, + None, + block_size, + False, + use_mla=use_mla) + expected = "FLASH_ATTN_MLA" + assert backend.get_name() == expected else: # TRITON_MLA or other fallback backend = get_attn_backend(16, @@ -242,8 +220,7 @@ def test_env( block_size, False, use_mla=use_mla) - expected = ("TRITON_MLA_VLLM_V1" - if use_v1 else "TRITON_MLA") + expected = "TRITON_MLA_VLLM_V1" assert backend.get_name() == expected elif name == "FLASHINFER": backend = get_attn_backend(16, @@ -252,7 +229,7 @@ def test_env( block_size, False, use_mla=use_mla) - expected = "FLASHINFER_VLLM_V1" if use_v1 else name + expected = "FLASHINFER_VLLM_V1" assert backend.get_name() == expected else: backend = get_attn_backend(32, @@ -261,36 +238,30 @@ def test_env( block_size, False, use_mla=use_mla) - expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name + expected = "FLASH_ATTN_VLLM_V1" assert backend.get_name() == expected - if use_v1: - backend = 
get_attn_backend(16, - torch.float16, - None, - block_size, - False, - use_mla=use_mla) - assert backend.get_name() == "FLEX_ATTENTION", ( - "Should fallback to FlexAttention if head size is " - "not supported by FlashAttention") + backend = get_attn_backend(16, + torch.float16, + None, + block_size, + False, + use_mla=use_mla) + assert backend.get_name() == "FLEX_ATTENTION", ( + "Should fallback to FlexAttention if head size is " + "not supported by FlashAttention") @pytest.mark.parametrize("device", ["cpu", "cuda"]) -@pytest.mark.parametrize("use_v1", [True, False]) def test_fp32_fallback( device: str, - use_v1: bool, monkeypatch: pytest.MonkeyPatch, ): """Test attention backend selection with fp32.""" with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") + m.setenv("VLLM_USE_V1", "1") if device == "cpu": - if not use_v1: - pytest.skip("CPU backend only supports V1") - with patch("vllm.attention.selector.current_platform", CpuPlatform()): backend = get_attn_backend(16, torch.float32, None, 16, False) @@ -300,8 +271,7 @@ def test_fp32_fallback( with patch("vllm.attention.selector.current_platform", CudaPlatform()): backend = get_attn_backend(16, torch.float32, None, 16, False) - assert (backend.get_name() == "FLEX_ATTENTION" - if use_v1 else "XFORMERS") + assert backend.get_name() == "FLEX_ATTENTION" def test_flash_attn(monkeypatch: pytest.MonkeyPatch): @@ -357,12 +327,11 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): assert backend.get_name() != STR_FLASH_ATTN_VAL -@pytest.mark.parametrize("use_v1", [True, False]) -def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch): +def test_invalid_env(monkeypatch: pytest.MonkeyPatch): """Test that invalid attention backend names raise ValueError.""" with monkeypatch.context() as m, patch( "vllm.attention.selector.current_platform", CudaPlatform()): - m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") + m.setenv("VLLM_USE_V1", "1") m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) # Should raise ValueError for invalid backend diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 3475993ff8f07..b539a7bf5d76c 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -17,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.platforms import current_platform @@ -97,7 +96,6 @@ def dummy_model() -> nn.Module: # Special handling for lm_head & sampler ("lm_head", ParallelLMHead(512, 10)), ("logits_processor", LogitsProcessor(512)), - ("sampler", Sampler()) ])) model.config = MagicMock() model.embedding_modules = {"lm_head": "lm_head"} @@ -125,7 +123,6 @@ def dummy_model_gate_up() -> nn.Module: # Special handling for lm_head & sampler ("lm_head", ParallelLMHead(512, 10)), ("logits_processor", LogitsProcessor(512)), - ("sampler", Sampler()) ])) model.config = MagicMock() model.packed_modules_mapping = { diff --git a/tests/lora/test_lora_functions.py b/tests/lora/test_lora_functions.py index 50c60341f0d88..221d5237823ca 100644 --- a/tests/lora/test_lora_functions.py +++ b/tests/lora/test_lora_functions.py @@ -6,10 +6,10 @@ Script to test add_lora, remove_lora, pin_lora, list_loras functions. 
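For orientation, the four hooks this file covers compose roughly as follows. This is a hypothetical sketch against the V1 `LLMEngine` the file now imports (see the import swap just below); the model and adapter paths are taken from this file's constants, and exact method signatures should be checked against the engine API:

```python
from vllm.engine.arg_utils import EngineArgs
from vllm.lora.request import LoRARequest
from vllm.v1.engine.llm_engine import LLMEngine

engine = LLMEngine.from_engine_args(
    EngineArgs(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_loras=4))

# Register an adapter, pin it so it is never evicted, then tear it down.
lora = LoRARequest("sql_adapter", 1, "yard1/llama-2-7b-sql-lora-test")
engine.add_lora(lora)
assert 1 in engine.list_loras()
engine.pin_lora(1)
engine.remove_lora(1)
```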
import pytest from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs -from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) from vllm.lora.request import LoRARequest +from vllm.v1.engine.llm_engine import LLMEngine MODEL_PATH = "meta-llama/Llama-2-7b-hf" LORA_MODULE_PATH = "yard1/llama-2-7b-sql-lora-test" diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index c14e71cbdb96d..39c4dd735b725 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -15,7 +15,8 @@ from ...utils import check_logprobs_close # have a clean way to fall back, so we fail with # a clear msg when it happens. # https://github.com/vllm-project/vllm/issues/14524 -REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"] +# NOTE(woosuk): Skipping these tests until V1 supports them. +# REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"] # This list contains the model that are using AITER kernel. # Skip model that are not using AITER tests. @@ -113,9 +114,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") - if model in REQUIRES_V0: - monkeypatch.setenv("VLLM_USE_V1", "0") - if use_rocm_aiter and (model in AITER_MODEL_LIST): monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") elif use_rocm_aiter and model not in AITER_MODEL_LIST: diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 206ad1352e06e..0b1f90e27db82 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -8,7 +8,7 @@ from tests.utils import multi_gpu_test from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import SamplingParams -from ...utils import check_logprobs_close, check_outputs_equal +from ...utils import check_logprobs_close # Mark all tests as hybrid pytestmark = pytest.mark.hybrid_model @@ -88,15 +88,6 @@ def test_models( hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "0") - if model not in V0_UNSUPPORTED_MODELS: - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - else: - vllm_v0_outputs = None - if model in V1_SUPPORTED_MODELS: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( @@ -104,14 +95,6 @@ def test_models( else: vllm_v1_outputs = None - if vllm_v0_outputs is not None: - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_v0_outputs, - name_0="hf", - name_1="vllm-v0", - ) - if model in V1_SUPPORTED_MODELS: check_logprobs_close( outputs_0_lst=hf_outputs, @@ -157,45 +140,6 @@ def test_batching( ) -@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) -def test_chunked_prefill( - vllm_runner, - example_prompts, - model: str, - max_tokens: int, - num_logprobs: int, - chunked_prefill_token_size: int, - monkeypatch, -) -> None: - max_num_seqs = chunked_prefill_token_size - 
max_num_batched_tokens = chunked_prefill_token_size - - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "0") - with vllm_runner(model, - enable_chunked_prefill=True, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs) as vllm_model: - chunked = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - - with vllm_runner(model, - enable_chunked_prefill=False, - max_num_seqs=max_num_seqs) as vllm_model: - non_chunked = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - - check_logprobs_close( - outputs_0_lst=chunked, - outputs_1_lst=non_chunked, - name_0="chunked", - name_1="non_chunked", - ) - - @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [10]) def test_chunked_prefill_with_parallel_sampling( @@ -257,38 +201,6 @@ def test_mamba_cache_cg_padding( "Could be related to mamba cache not padded correctly") -@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) -@pytest.mark.parametrize("max_tokens", [20]) -def test_models_preemption_recompute( - vllm_runner, - example_prompts, - model: str, - max_tokens: int, - monkeypatch, -) -> None: - """ - Tests that outputs are identical with and w/o preemptions (recompute). - """ - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "0") - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - scheduler = vllm_model.llm.llm_engine.scheduler[0] - scheduler.ENABLE_ARTIFICIAL_PREEMPT = True - preempt_vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) - - scheduler.ENABLE_ARTIFICIAL_PREEMPT = False - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) - - check_outputs_equal( - outputs_0_lst=preempt_vllm_outputs, - outputs_1_lst=vllm_outputs, - name_0="vllm_preepmtions", - name_1="vllm", - ) - - @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( vllm_runner, @@ -386,27 +298,10 @@ def test_full_cuda_graph( hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "0") - if model not in V0_UNSUPPORTED_MODELS: - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - else: - vllm_v0_outputs = None - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - if vllm_v0_outputs is not None: - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_v0_outputs, - name_0="hf", - name_1="vllm-v0", - ) - check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v1_outputs, @@ -442,27 +337,12 @@ def test_fp32_cache_state( hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "0") - with vllm_runner(model, - max_num_seqs=MAX_NUM_SEQS, - **{cache_dtype_param: "float32"}) as vllm_model: - vllm_v0_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, **{cache_dtype_param: "float32"}) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - check_logprobs_close( - 
outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_v0_outputs, - name_0="hf", - name_1="vllm-v0", - ) - check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v1_outputs, diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py index 08722ac98b7ed..4ac91b5aed506 100644 --- a/tests/models/language/pooling/test_reward.py +++ b/tests/models/language/pooling/test_reward.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import pytest import torch @@ -82,7 +81,7 @@ def test_prm_models( check_transformers_version("Qwen/Qwen2.5-Math-PRM-7B", max_transformers_version="4.53.2") - if current_platform.is_cpu() and os.environ.get("VLLM_USE_V1", "0") == "0": + if current_platform.is_cpu(): pytest.skip("CPU only supports V1") if current_platform.is_rocm(): diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 79f9d607f3386..e76b58e61ec16 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -32,13 +32,6 @@ from .vlm_utils.types import (CustomTestOptions, ExpandableVLMTestArgs, if current_platform.is_rocm(): os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" -REQUIRES_V0_MODELS = [ - # V1 Test: not enough KV cache space in C1. - "fuyu", - # V1 Test: Deadlock issue when processing mm_inputs - "llava-onevision-transformers", -] - # yapf: disable COMMON_BROADCAST_SETTINGS = { "test_type": VLMTestType.IMAGE, @@ -186,8 +179,11 @@ VLM_TEST_SETTINGS = { image_size_factors=[(0.25, 0.5, 1.0)], vllm_runner_kwargs={ "model_impl": "transformers", + "default_torch_num_threads": 1, }, - marks=[pytest.mark.core_model], + # FIXME: Investigate why the test hangs + # when processing the 3rd prompt in vLLM + marks=[pytest.mark.core_model, pytest.mark.skip(reason="Test hangs")], ), "idefics3-transformers": VLMTestInfo( models=["HuggingFaceTB/SmolVLM-256M-Instruct"], @@ -320,6 +316,7 @@ VLM_TEST_SETTINGS = { vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output, num_logprobs=10, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + marks=[large_gpu_mark(min_gb=32)], ), "gemma3": VLMTestInfo( models=["google/gemma-3-4b-it"], @@ -861,13 +858,14 @@ VLM_TEST_SETTINGS = _mark_splits(VLM_TEST_SETTINGS, num_groups=2) test_type=VLMTestType.IMAGE, create_new_process_for_each_test=False, )) -def test_single_image_models(tmp_path: PosixPath, model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_single_image_models( + tmp_path: PosixPath, + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_single_image_test( tmp_path=tmp_path, @@ -886,13 +884,14 @@ def test_single_image_models(tmp_path: PosixPath, model_type: str, test_type=VLMTestType.MULTI_IMAGE, create_new_process_for_each_test=False, )) -def test_multi_image_models(tmp_path: PosixPath, model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") +def 
test_multi_image_models( + tmp_path: PosixPath, + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_multi_image_test( tmp_path=tmp_path, @@ -911,13 +910,13 @@ def test_multi_image_models(tmp_path: PosixPath, model_type: str, test_type=VLMTestType.EMBEDDING, create_new_process_for_each_test=False, )) -def test_image_embedding_models(model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_image_embedding_models( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_embedding_test( model_test_info=model_test_info, @@ -935,11 +934,13 @@ def test_image_embedding_models(model_type: str, test_type=VLMTestType.VIDEO, create_new_process_for_each_test=False, )) -def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - video_assets: VideoTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_video_models( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + video_assets: VideoTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_video_test( model_test_info=model_test_info, @@ -957,11 +958,13 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, test_type=VLMTestType.AUDIO, create_new_process_for_each_test=False, )) -def test_audio_models(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - audio_assets: AudioTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_audio_models( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + audio_assets: AudioTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_audio_test( model_test_info=model_test_info, @@ -984,10 +987,7 @@ def test_custom_inputs_models( test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - monkeypatch, ): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_custom_inputs_test( model_test_info=model_test_info, @@ -1006,13 +1006,14 @@ def test_custom_inputs_models( create_new_process_for_each_test=True, )) @create_new_process_for_each_test() -def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_single_image_models_heavy( + tmp_path: PosixPath, + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_single_image_test( tmp_path=tmp_path, @@ -1032,13 +1033,14 
@@ def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str, create_new_process_for_each_test=True, )) @create_new_process_for_each_test() -def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_multi_image_models_heavy( + tmp_path: PosixPath, + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_multi_image_test( tmp_path=tmp_path, @@ -1058,14 +1060,13 @@ def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str, create_new_process_for_each_test=True, )) @create_new_process_for_each_test() -def test_image_embedding_models_heavy(model_type: str, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, - monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_image_embedding_models_heavy( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_embedding_test( model_test_info=model_test_info, @@ -1083,12 +1084,13 @@ def test_image_embedding_models_heavy(model_type: str, test_type=VLMTestType.VIDEO, create_new_process_for_each_test=True, )) -def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - video_assets: VideoTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_video_models_heavy( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + video_assets: VideoTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_video_test( model_test_info=model_test_info, @@ -1106,12 +1108,13 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, test_type=VLMTestType.AUDIO, create_new_process_for_each_test=True, )) -def test_audio_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - audio_assets: AudioTestAssets, monkeypatch): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") +def test_audio_models_heavy( + model_type: str, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + audio_assets: AudioTestAssets, +): model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_audio_test( model_test_info=model_test_info, @@ -1135,10 +1138,7 @@ def test_custom_inputs_models_heavy( test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - monkeypatch, ): - if model_type in REQUIRES_V0_MODELS: - monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] runners.run_custom_inputs_test( model_test_info=model_test_info, diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py index f2e6fbfad6e80..c1305e0ae31ce 100644 --- 
a/tests/models/multimodal/generation/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -7,8 +7,8 @@ from typing import Optional import pytest from transformers import AutoModelForSpeechSeq2Seq +from vllm.logprobs import SampleLogprobs from vllm.lora.request import LoRARequest -from vllm.sequence import SampleLogprobs from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput, VllmRunner) diff --git a/tests/models/multimodal/generation/test_phi4mm.py b/tests/models/multimodal/generation/test_phi4mm.py index 67d35213d6422..77e2b90dd5e96 100644 --- a/tests/models/multimodal/generation/test_phi4mm.py +++ b/tests/models/multimodal/generation/test_phi4mm.py @@ -12,10 +12,10 @@ from huggingface_hub import snapshot_download from transformers import AutoTokenizer from vllm.assets.image import ImageAsset +from vllm.logprobs import SampleLogprobs from vllm.lora.request import LoRARequest from vllm.multimodal.image import convert_image_mode, rescale_image_size from vllm.platforms import current_platform -from vllm.sequence import SampleLogprobs from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput, PromptImageInput, VllmRunner) diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index a4e21aface41f..715b08ef90e54 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -12,13 +12,12 @@ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.multimodal import image_from_chunk from transformers import AutoProcessor -from vllm import RequestOutput, SamplingParams, TextPrompt, TokensPrompt +from vllm import SamplingParams, TextPrompt, TokensPrompt +from vllm.logprobs import Logprob, SampleLogprobs from vllm.multimodal import MultiModalDataBuiltins -from vllm.multimodal.inputs import PlaceholderRange -from vllm.sequence import Logprob, SampleLogprobs from ....utils import VLLM_PATH, large_gpu_test -from ...utils import check_logprobs_close, dummy_hf_overrides +from ...utils import check_logprobs_close if TYPE_CHECKING: from _typeshed import StrPath @@ -185,47 +184,3 @@ def test_chat(vllm_runner, max_model_len: int, model: str, dtype: str, outputs_1_lst=logprobs, name_0="h100_ref", name_1="output") - - -@pytest.mark.parametrize( - "image_urls,expected_ranges", - [(IMG_URLS[:1], [PlaceholderRange(offset=11, length=494)]), - (IMG_URLS[1:4], [ - PlaceholderRange(offset=11, length=266), - PlaceholderRange(offset=277, length=1056), - PlaceholderRange(offset=1333, length=418) - ])]) -def test_multi_modal_placeholders(vllm_runner, image_urls: list[str], - expected_ranges: list[PlaceholderRange], - local_asset_server, monkeypatch) -> None: - local_image_urls = [local_asset_server.url_for(u) for u in image_urls] - prompt = _create_engine_inputs_hf(local_image_urls) - - # This placeholder checking test only works with V0 engine - # where `multi_modal_placeholders` is returned with `RequestOutput` - monkeypatch.setenv("VLLM_USE_V1", "0") - with vllm_runner( - "mistral-community/pixtral-12b", - max_model_len=8192, - limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, - load_format="dummy", - hf_overrides=dummy_hf_overrides, - ) as vllm_model: - outputs = vllm_model.llm.generate(prompt) - - assert len(outputs) == 1, f"{len(outputs)=}" - output: RequestOutput = outputs[0] - assert hasattr(output, - "multi_modal_placeholders"), f"{output.__dict__=}" - assert "image" in 
output.multi_modal_placeholders, \ - f"{output.multi_modal_placeholders.keys()=}" - image_placeholder_ranges: list[ - PlaceholderRange] = output.multi_modal_placeholders["image"] - assert len(image_placeholder_ranges) == len( - expected_ranges), f"{image_placeholder_ranges=}" - for real_range, expected_range in zip(image_placeholder_ranges, - expected_ranges): - assert real_range.offset == expected_range.offset, \ - f"{real_range=} {expected_range=}" - assert real_range.length == expected_range.length, \ - f"{real_range=} {expected_range=}" diff --git a/tests/models/multimodal/generation/test_qwen2_vl.py b/tests/models/multimodal/generation/test_qwen2_vl.py index e56f4e4075be4..8336ebc0d59cb 100644 --- a/tests/models/multimodal/generation/test_qwen2_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_vl.py @@ -10,7 +10,6 @@ from PIL import Image from vllm.multimodal.image import rescale_image_size from vllm.multimodal.video import rescale_video_size, sample_frames_from_video -from vllm.utils import set_default_torch_num_threads from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput, PromptVideoInput, VllmRunner) @@ -264,8 +263,7 @@ def run_embedding_input_test( processor = AutoProcessor.from_pretrained(model) # max_model_len should be greater than image_feature_size - with set_default_torch_num_threads(1): - vllm_model = vllm_runner( + with vllm_runner( model, runner="generate", max_model_len=4000, @@ -277,9 +275,8 @@ def run_embedding_input_test( }, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, - ) - - with vllm_model: + default_torch_num_threads=1, + ) as vllm_model: outputs_per_case_for_original_input = [ vllm_model.generate_greedy_logprobs(prompts, max_tokens, diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index 8b7d051218f14..ba55450ec8a90 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -19,7 +19,7 @@ from transformers import (AutoConfig, AutoTokenizer, BatchFeature, GenerationConfig, GenerationMixin) from transformers.video_utils import VideoMetadata -from vllm.sequence import SampleLogprobs +from vllm.logprobs import SampleLogprobs from vllm.utils import is_list_of from .....conftest import HfRunner, ImageAsset, ImageTestAssets diff --git a/tests/models/multimodal/generation/vlm_utils/types.py b/tests/models/multimodal/generation/vlm_utils/types.py index 9451131960885..e39ca40fbbf5e 100644 --- a/tests/models/multimodal/generation/vlm_utils/types.py +++ b/tests/models/multimodal/generation/vlm_utils/types.py @@ -12,7 +12,7 @@ from transformers import AutoModelForCausalLM from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm.config import RunnerOption -from vllm.sequence import SampleLogprobs +from vllm.logprobs import SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer from .....conftest import (AUDIO_ASSETS, IMAGE_ASSETS, HfRunner, ImageAsset, diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index b503d42567022..7309660ea5261 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -4,8 +4,6 @@ import pytest import torch -from vllm.utils import set_default_torch_num_threads - from ....conftest import VllmRunner @@ -30,19 +28,17 @@ def 
_run_test( } for _ in range(10) ] - with ( - set_default_torch_num_threads(1), - vllm_runner( - model, - runner="pooling", - dtype=torch.float16, - enforce_eager=True, - skip_tokenizer_init=True, - # Limit the maximum number of sequences to avoid the - # test going OOM during the warmup run - max_num_seqs=32, - ) as vllm_model, - ): + with vllm_runner( + model, + runner="pooling", + dtype="half", + enforce_eager=True, + skip_tokenizer_init=True, + # Limit the maximum number of sequences to avoid the + # test going OOM during the warmup run + max_num_seqs=32, + default_torch_num_threads=1, + ) as vllm_model: vllm_model.encode(prompt) diff --git a/tests/models/quantization/test_awq.py b/tests/models/quantization/test_awq.py index 7005e435ecf46..e741e4ad90a09 100644 --- a/tests/models/quantization/test_awq.py +++ b/tests/models/quantization/test_awq.py @@ -45,12 +45,15 @@ def run_awq_test( # will hurt multiprocessing backend with fork method (the default method). # max_model_len should be greater than image_feature_size - with vllm_runner(source_model, - max_model_len=4096, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + with vllm_runner( + source_model, + max_model_len=4096, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + default_torch_num_threads=1, + ) as vllm_model: source_outputs_per_image = [ vllm_model.generate_greedy_logprobs(prompts, max_tokens, @@ -59,13 +62,16 @@ def run_awq_test( for prompts, images in inputs_per_image ] - with vllm_runner(quant_model, - quantization="awq", - max_model_len=4096, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + with vllm_runner( + quant_model, + quantization="awq", + max_model_len=4096, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + default_torch_num_threads=1, + ) as vllm_model: quant_outputs_per_image = [ vllm_model.generate_greedy_logprobs(prompts, max_tokens, @@ -108,12 +114,8 @@ def run_awq_test( @pytest.mark.parametrize("num_logprobs", [5]) @torch.inference_mode() def test_awq_models(vllm_runner, image_assets, source_model, quant_model, - size_factors, dtype, max_tokens, num_logprobs, - monkeypatch) -> None: + size_factors, dtype, max_tokens, num_logprobs) -> None: - # Test V1: this test hangs during setup on single-scale input. - # TODO: figure out why and re-enable this on V1. - monkeypatch.setenv("VLLM_USE_V1", "0") run_awq_test( vllm_runner, image_assets, diff --git a/tests/models/quantization/test_bitsandbytes.py b/tests/models/quantization/test_bitsandbytes.py index e0e919b62b217..25fc44fee90dd 100644 --- a/tests/models/quantization/test_bitsandbytes.py +++ b/tests/models/quantization/test_bitsandbytes.py @@ -5,10 +5,7 @@ Run `pytest tests/quantization/test_bitsandbytes.py`. 
''' -import gc - import pytest -import torch from transformers import BitsAndBytesConfig from tests.quantization.utils import is_quant_method_supported @@ -131,12 +128,15 @@ def test_4bit_bnb_moe_model(hf_runner, vllm_runner, example_prompts, )) with vllm_runner(model_name, quantization='bitsandbytes', - enforce_eager=False) as llm: + enforce_eager=False, + default_torch_num_threads=1) as llm: vllm_outputs = llm.generate_greedy_logprobs(example_prompts, max_tokens=32, num_logprobs=5) - with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm: + with hf_runner(model_name, + model_kwargs=hf_model_kwargs, + default_torch_num_threads=1) as llm: transformers_outputs = llm.generate_greedy_logprobs_limit( example_prompts, max_tokens=32, num_logprobs=5) check_logprobs_close( @@ -174,7 +174,8 @@ def test_4bit_bnb_embedding_model( runner="pooling", dtype=dtype, gpu_memory_utilization=0.5, - quantization="bitsandbytes") as vllm_model: + quantization="bitsandbytes", + default_torch_num_threads=1) as vllm_model: vllm_outputs = vllm_model.embed(example_prompts) hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( @@ -184,6 +185,7 @@ def test_4bit_bnb_embedding_model( dtype=dtype, model_kwargs=hf_model_kwargs, is_sentence_transformer=True, + default_torch_num_threads=1, ) as hf_model: hf_outputs = hf_model.encode(example_prompts) @@ -222,26 +224,22 @@ def validate_generated_texts(hf_runner, with vllm_runner(model_name, quantization=None if pre_quant else 'bitsandbytes', tensor_parallel_size=vllm_tp_size, - enforce_eager=False) as llm: + enforce_eager=False, + default_torch_num_threads=1) as llm: vllm_outputs = llm.generate_greedy(prompts, max_tokens) vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") - # Clean up the GPU memory for the next test - gc.collect() - torch.cuda.empty_cache() - if hf_model_kwargs is None: hf_model_kwargs = {} # Run with HF runner - with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm: + with hf_runner(model_name, + model_kwargs=hf_model_kwargs, + default_torch_num_threads=1) as llm: hf_outputs = llm.generate_greedy(prompts, max_tokens) hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") - # Clean up the GPU memory for the next test - gc.collect() - torch.cuda.empty_cache() # Compare the generated strings for hf_log, vllm_log in zip(hf_logs, vllm_logs): hf_str = hf_log["generated_text"] diff --git a/tests/models/quantization/test_fp8.py b/tests/models/quantization/test_fp8.py index afc27b6e0566e..bb8ae741b6149 100644 --- a/tests/models/quantization/test_fp8.py +++ b/tests/models/quantization/test_fp8.py @@ -32,13 +32,10 @@ from ..utils import check_logprobs_close # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) @pytest.mark.parametrize("enforce_eager", [True]) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN"]) # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. 
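# e.g. parametrize "tensor_parallel_size" with [2] for a local two-GPU check (illustrative; keep [1] in CI).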
@pytest.mark.parametrize("tensor_parallel_size", [1]) -# Due to low-precision numerical divergence, this test is too sensitive for -# the async postprocessor -@pytest.mark.parametrize("disable_async_output_proc", [True]) def test_models( vllm_runner, example_prompts, @@ -49,7 +46,6 @@ def test_models( enforce_eager: bool, backend: str, tensor_parallel_size: int, - disable_async_output_proc: bool, monkeypatch: pytest.MonkeyPatch, ) -> None: """ @@ -61,6 +57,9 @@ def test_models( pytest.skip( f"{kv_cache_dtype} is currently not supported on ROCm/HIP.") + if not current_platform.is_kv_cache_dtype_supported(kv_cache_dtype, None): + pytest.skip(f"{kv_cache_dtype} is not supported on this platform.") + with monkeypatch.context() as m: m.setenv("TOKENIZERS_PARALLELISM", 'true') m.setenv(STR_BACKEND_ENV_VAR, backend) @@ -74,7 +73,6 @@ def test_models( tensor_parallel_size=tensor_parallel_size, enforce_eager=enforce_eager, kv_cache_dtype="auto", - disable_async_output_proc=disable_async_output_proc, ) as vllm_model: baseline_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS) @@ -85,7 +83,6 @@ def test_models( tensor_parallel_size=tensor_parallel_size, enforce_eager=enforce_eager, kv_cache_dtype=kv_cache_dtype, - disable_async_output_proc=disable_async_output_proc, ) as vllm_model: test_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS) @@ -110,9 +107,6 @@ def test_models( ]) # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) -# Due to low-precision numerical divergence, this test is too sensitive for -# the async postprocessor -@pytest.mark.parametrize("disable_async_output_proc", [True]) def test_cpu_models( vllm_runner, example_prompts, @@ -120,7 +114,6 @@ def test_cpu_models( base_model: str, test_model: str, max_tokens: int, - disable_async_output_proc: bool, monkeypatch: pytest.MonkeyPatch, ) -> None: """ @@ -138,7 +131,6 @@ def test_cpu_models( max_model_len=MAX_MODEL_LEN, dtype="bfloat16", kv_cache_dtype="auto", - disable_async_output_proc=disable_async_output_proc, ) as vllm_model: baseline_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS) @@ -148,7 +140,6 @@ def test_cpu_models( max_model_len=MAX_MODEL_LEN, dtype="bfloat16", kv_cache_dtype=kv_cache_dtype, - disable_async_output_proc=disable_async_output_proc, ) as vllm_model: test_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 9281579b71e74..b9601114a3183 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -7,7 +7,6 @@ from unittest.mock import patch import pytest from vllm import LLM -from vllm.engine.llm_engine import LLMEngine as V0LLMEngine from vllm.utils import GiB_bytes from vllm.v1.core.kv_cache_utils import get_kv_cache_configs from vllm.v1.engine.core import EngineCore as V1EngineCore @@ -61,10 +60,6 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, False)) # Avoid calling model.forward() - def _initialize_kv_caches_v0(self) -> None: - self.cache_config.num_gpu_blocks = 0 - self.cache_config.num_cpu_blocks = 0 - def _initialize_kv_caches_v1(self, vllm_config): kv_cache_specs = self.model_executor.get_kv_cache_specs() scheduler_kv_cache_config = get_kv_cache_configs( @@ -76,12 +71,12 @@ def can_initialize(model_arch: str, monkeypatch: 
pytest.MonkeyPatch, # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config return 1, 0, scheduler_kv_cache_config - with (patch.object(V0LLMEngine, "_initialize_kv_caches", - _initialize_kv_caches_v0), - patch.object(V1EngineCore, "_initialize_kv_caches", + with (patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1), monkeypatch.context() as m): if model_info.v0_only: - m.setenv("VLLM_USE_V1", "0") + # NOTE(woosuk): skip the test for V0-only models + return + if model_arch in ("Phi4FlashForCausalLM", "MotifForCausalLM"): # Phi4FlashForCausalLM and MotifForCausalLM # only support the DIFFERENTIAL_FLASH_ATTN backend diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index 4aa7bb7297893..9b376f2a260ac 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -42,6 +42,7 @@ def test_oot_registration_text_generation( assert rest == "" +@pytest.mark.skip(reason="Fails on V1.") @create_new_process_for_each_test() def test_oot_registration_embedding( monkeypatch: pytest.MonkeyPatch, @@ -62,6 +63,7 @@ def test_oot_registration_embedding( image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") +@pytest.mark.skip(reason="Fails on V1.") @create_new_process_for_each_test() def test_oot_registration_multimodal( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/models/test_terratorch.py b/tests/models/test_terratorch.py index d6d43ca2f7e15..842e37ea26f67 100644 --- a/tests/models/test_terratorch.py +++ b/tests/models/test_terratorch.py @@ -5,7 +5,6 @@ import pytest import torch from tests.conftest import VllmRunner -from vllm.utils import set_default_torch_num_threads @pytest.mark.parametrize( @@ -25,19 +24,17 @@ def test_inference( prompt = dict(prompt_token_ids=[1], multi_modal_data=dict(pixel_values=pixel_values, location_coords=location_coords)) - with ( - set_default_torch_num_threads(1), - vllm_runner( - model, - runner="pooling", - dtype=torch.float16, - enforce_eager=True, - skip_tokenizer_init=True, - # Limit the maximum number of sequences to avoid the - # test going OOM during the warmup run - max_num_seqs=32, - ) as vllm_model, - ): + with vllm_runner( + model, + runner="pooling", + dtype="half", + enforce_eager=True, + skip_tokenizer_init=True, + # Limit the maximum number of sequences to avoid the + # test going OOM during the warmup run + max_num_seqs=32, + default_torch_num_threads=1, + ) as vllm_model: vllm_output = vllm_model.llm.encode(prompt) assert torch.equal( diff --git a/tests/models/utils.py b/tests/models/utils.py index 76c6e4823a12c..5da2382cef814 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -12,7 +12,7 @@ from transformers import PretrainedConfig from vllm.config import ModelConfig, ModelDType, RunnerOption from vllm.inputs import InputContext -from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs +from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs from .registry import HF_EXAMPLE_MODELS diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index da97cf7e2b40b..b431ad1ed0928 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -9,7 +9,6 @@ from vllm.model_executor.models.llava import (LlavaDummyInputsBuilder, LlavaForConditionalGeneration,
LlavaMultiModalProcessor, LlavaProcessingInfo) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -18,11 +17,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY dummy_inputs=LlavaDummyInputsBuilder) class MyLlava(LlavaForConditionalGeneration): - def compute_logits( - self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + def compute_logits(self, + hidden_states: torch.Tensor) -> Optional[torch.Tensor]: # this dummy model always predicts the first token - logits = super().compute_logits(hidden_states, sampling_metadata) + logits = super().compute_logits(hidden_states) if logits is not None: logits.zero_() logits[:, 0] += 1.0 diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py index 8c34407e3e071..a6fafff98e9c5 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py @@ -6,16 +6,14 @@ from typing import Optional import torch from vllm.model_executor.models.opt import OPTForCausalLM -from vllm.model_executor.sampling_metadata import SamplingMetadata class MyOPTForCausalLM(OPTForCausalLM): - def compute_logits( - self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + def compute_logits(self, + hidden_states: torch.Tensor) -> Optional[torch.Tensor]: # this dummy model always predicts the first token - logits = super().compute_logits(hidden_states, sampling_metadata) + logits = super().compute_logits(hidden_states) if logits is not None: logits.zero_() logits[:, 0] += 1.0 diff --git a/tests/plugins_tests/test_platform_plugins.py b/tests/plugins_tests/test_platform_plugins.py index 6e2089ea2e0e2..1d7e4475011d0 100644 --- a/tests/plugins_tests/test_platform_plugins.py +++ b/tests/plugins_tests/test_platform_plugins.py @@ -7,15 +7,6 @@ import torch from vllm.plugins import load_general_plugins -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - def test_platform_plugins(): # simulate workload by running an example import runpy diff --git a/tests/plugins_tests/test_scheduler_plugins.py b/tests/plugins_tests/test_scheduler_plugins.py index 8c21216108685..099869a82ad21 100644 --- a/tests/plugins_tests/test_scheduler_plugins.py +++ b/tests/plugins_tests/test_scheduler_plugins.py @@ -3,47 +3,18 @@ import pytest -from vllm.core.scheduler import Scheduler from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine from vllm.sampling_params import SamplingParams -from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler -from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.engine.llm_engine import LLMEngine -class DummyV0Scheduler(Scheduler): - - def schedule(self): - raise Exception("Exception raised by DummyV0Scheduler") - - -class DummyV1Scheduler(V1Scheduler): +class DummyV1Scheduler(Scheduler): def schedule(self): raise Exception("Exception raised by DummyV1Scheduler") -def test_scheduler_plugins_v0(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "0") - with pytest.raises(Exception) as exception_info: - - engine_args = EngineArgs( - model="facebook/opt-125m", - enforce_eager=True, # reduce test time - scheduler_cls=DummyV0Scheduler, - ) - - engine = LLMEngine.from_engine_args(engine_args=engine_args) - - sampling_params = SamplingParams(max_tokens=1) - engine.add_request("0", "foo", sampling_params) - engine.step() - - assert str( - exception_info.value) == "Exception raised by DummyV0Scheduler" - - def test_scheduler_plugins_v1(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -59,7 +30,7 @@ def test_scheduler_plugins_v1(monkeypatch: pytest.MonkeyPatch): scheduler_cls=DummyV1Scheduler, ) - engine = V1LLMEngine.from_engine_args(engine_args=engine_args) + engine = LLMEngine.from_engine_args(engine_args=engine_args) sampling_params = SamplingParams(max_tokens=1) engine.add_request("0", "foo", sampling_params) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index b7949a488ad05..c0ab3fbb10622 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -357,6 +357,9 @@ def test_compressed_tensors_fp8(vllm_runner): assert output +@pytest.mark.skipif( + not current_platform.is_kv_cache_dtype_supported("fp8", None), + reason="FP8 KV cache is not supported on this device.") @pytest.mark.skipif(not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform.") def test_compressed_tensors_kv_cache(vllm_runner): @@ -738,4 +741,4 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt, with vllm_runner(model, enforce_eager=True) as llm: perplexity = llm.generate_prompt_perplexity([prompt])[0] print(perplexity) - assert perplexity <= exp_perplexity \ No newline at end of file + assert perplexity <= exp_perplexity diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index 0320a5ef31a65..2960ffcbd9eab 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -10,13 +10,6 @@ from transformers import AutoModelForSeq2SeqLM from vllm.assets.audio import AudioAsset - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - """We can run both engines for this test.""" - 
pass - - # FIXME(zhuohan): The test cannot pass if we: # 1. Increase max_tokens to 256. # 2. Increase beam_width to 8. diff --git a/tests/samplers/test_ignore_eos.py b/tests/samplers/test_ignore_eos.py index ea4a17dd2306f..1d77d37a5d581 100644 --- a/tests/samplers/test_ignore_eos.py +++ b/tests/samplers/test_ignore_eos.py @@ -9,13 +9,6 @@ import pytest from vllm import SamplingParams - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - """We can run both engines for this test.""" - pass - - # We also test with llama because it has generation_config to specify EOS # (past regression). MODELS = ["distilbert/distilgpt2", "meta-llama/Llama-3.2-1B"] diff --git a/tests/samplers/test_ranks.py b/tests/samplers/test_ranks.py index 86fc14dc85f80..220a4a53f4671 100644 --- a/tests/samplers/test_ranks.py +++ b/tests/samplers/test_ranks.py @@ -8,12 +8,6 @@ from vllm import SamplingParams MODELS = ["distilbert/distilgpt2"] -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - """We can run both engines for this test.""" - pass - - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) def test_ranks( diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/speculative_decoding/speculators/test_eagle3.py index 45ddb2178722a..368238b3a7200 100644 --- a/tests/speculative_decoding/speculators/test_eagle3.py +++ b/tests/speculative_decoding/speculators/test_eagle3.py @@ -3,38 +3,52 @@ import pytest import torch +from vllm.config import SpeculativeConfig from vllm.model_executor.models.interfaces import supports_eagle3 -@pytest.mark.parametrize( - "model_path", - [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")]) -def test_llama(vllm_runner, example_prompts, model_path, monkeypatch): +@pytest.mark.parametrize("model_path", [ + pytest.param( + "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized", + id="llama3-eagle3-speculator"), + pytest.param( + "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized", + id="qwen3-eagle3-speculator"), +]) +def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path, + monkeypatch): + """ + Test that Eagle3 speculators models properly initialize speculative decoding. + + This test verifies: + 1. Eagle3 support is detected for the model + 2. Speculative config is automatically initialized from embedded config + 3. The draft model path is correctly set to the speculators model + 4. Speculative tokens count is valid + 5.
Text generation works with speculative decoding enabled + """ # Set environment variable for V1 engine serialization monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model: + # Verify Eagle3 support is detected eagle3_supported = vllm_model.apply_model(supports_eagle3) - assert eagle3_supported + assert eagle3_supported, f"Eagle3 should be supported for {model_path}" + + vllm_config = vllm_model.llm.llm_engine.vllm_config + + assert isinstance(vllm_config.speculative_config, SpeculativeConfig), \ + "Speculative config should be initialized for speculators model" + + spec_config = vllm_config.speculative_config + assert spec_config.num_speculative_tokens > 0, \ + (f"Expected positive speculative tokens, " + f"got {spec_config.num_speculative_tokens}") + + assert spec_config.model == model_path, \ + f"Draft model should be {model_path}, got {spec_config.model}" vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens=20) - print(vllm_outputs) - assert vllm_outputs - - -@pytest.mark.parametrize( - "model_path", - [("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")]) -def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch): - # Set environment variable for V1 engine serialization - monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - - with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model: - eagle3_supported = vllm_model.apply_model(supports_eagle3) - assert eagle3_supported - - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens=20) - print(vllm_outputs) - assert vllm_outputs + assert vllm_outputs, \ + f"No outputs generated for speculators model {model_path}" diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index 42afdfa3c7468..fd5b5fad0999c 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -57,10 +57,19 @@ def llama_3p2_1b_files(): def _run_writer(input_dir, output_dir, weights_patterns, **kwargs): llm_sharded_writer = LLM(model=input_dir, **kwargs) - + # Check which engine version is being used + is_v1_engine = hasattr(llm_sharded_writer.llm_engine, "engine_core") # Dump worker states to output directory - llm_sharded_writer.llm_engine.model_executor.save_sharded_state( - path=output_dir) + if is_v1_engine: + # For V1 engine, we need to use engine_core.save_sharded_state + print("Using V1 engine save path") + llm_sharded_writer.llm_engine.engine_core.save_sharded_state( + path=output_dir) + else: + # For V0 engine + print("Using V0 engine save path") + model_executor = llm_sharded_writer.llm_engine.model_executor + model_executor.save_sharded_state(path=output_dir) # Copy metadata files to output directory for file in os.listdir(input_dir): @@ -91,8 +100,6 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available, gpu_memory_utilization = 0.8 input_dir = llama_3p2_1b_files ctx = mp.get_context("spawn") - # The interface in v1 engine has changed, run in v1 engine will hang. 
- monkeypatch.setenv("VLLM_USE_V1", "0") # Run in separate processes for memory & CUDA isolation with TemporaryDirectory() as output_dir: @@ -100,7 +107,6 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available, args=(input_dir, output_dir, weights_patterns), kwargs=dict( tensor_parallel_size=tp_size, - distributed_executor_backend="mp", gpu_memory_utilization=gpu_memory_utilization, enforce_eager=True, )) @@ -112,7 +118,6 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available, p = ctx.Process(target=_run_generate, args=(input_dir, queue), kwargs=dict( - distributed_executor_backend="mp", enable_lora=enable_lora, gpu_memory_utilization=gpu_memory_utilization, tensor_parallel_size=tp_size, @@ -133,7 +138,6 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available, p = ctx.Process(target=_run_generate, args=(output_dir, queue), kwargs=dict( - distributed_executor_backend="mp", enable_lora=enable_lora, gpu_memory_utilization=gpu_memory_utilization, tensor_parallel_size=tp_size, diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 15ea55afe963b..fe6c313d2966f 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -8,10 +8,7 @@ import pytest from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) -from vllm.inputs import token_inputs -from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer, @@ -217,193 +214,3 @@ def test_oov_decode(tokenizer, fast): assert decoded_text == '' assert out_ids == [len(tokenizer)] - - -@pytest.fixture -def detokenizer(tokenizer_name: str) -> Detokenizer: - tokenizer = get_tokenizer( - tokenizer_name, - tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto", - trust_remote_code=False, - revision=None, - ) - - return Detokenizer(tokenizer) - - -@pytest.fixture(name="complete_sequence_token_ids") -def create_complete_sequence_token_ids(complete_sequence: str, - tokenizer) -> list[int]: - return tokenizer(complete_sequence, add_special_tokens=False).input_ids - - -def create_sequence(prompt_token_ids=None): - prompt_token_ids = prompt_token_ids or [] - return Sequence( - seq_id=0, - inputs=token_inputs(prompt_token_ids), - block_size=16, - ) - - -def create_dummy_logprobs( - complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]: - return [{ - token_id: Logprob(logprob=0.0), - token_id + 1: Logprob(logprob=0.1) - } for token_id in complete_sequence_token_ids] - - -def create_dummy_prompt_logprobs( - complete_sequence_token_ids: list[int] -) -> list[Optional[dict[int, Any]]]: - # logprob for the first prompt token is None. 
- logprobs: list[Optional[dict[int, Any]]] = [None] - logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:]) - return logprobs - - -@pytest.mark.parametrize("complete_sequence", TRUTH) -@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) -@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True) -def test_decode_sequence_logprobs(complete_sequence: str, - complete_sequence_token_ids: list[int], - detokenizer: Detokenizer, - skip_special_tokens: bool): - """Verify Detokenizer decodes logprobs correctly.""" - sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, - logprobs=2) - - # Run sequentially. - seq = create_sequence() - dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids) - sequential_logprobs_text_chosen_token: list[str] = [] - sequential_logprobs_text_other_token: list[str] = [] - for new_token, logprobs in zip(complete_sequence_token_ids, - dummy_logprobs): - seq.append_token_id(new_token, logprobs) - detokenizer.decode_sequence_inplace(seq, sampling_params) - sequential_logprobs_text_chosen_token.append( - seq.output_logprobs[-1][new_token].decoded_token) - sequential_logprobs_text_other_token.append( - seq.output_logprobs[-1][new_token + 1].decoded_token) - sequential_result = seq.output_text - - assert sequential_result == "".join(sequential_logprobs_text_chosen_token) - assert sequential_result != "".join(sequential_logprobs_text_other_token) - - if not skip_special_tokens: - # Text for logprobs for the chosen token should be the same as the - # generated text. Note that this will only be true if we skip - # special tokens. - assert sequential_result == complete_sequence - - -@pytest.mark.parametrize("complete_sequence", TRUTH) -@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) -def test_decode_prompt_logprobs(complete_sequence: str, - complete_sequence_token_ids: list[int], - detokenizer: Detokenizer): - - # We want to use skip_special_tokens=False here but Mistral tokenizers - # don't support that. - if complete_sequence not in SPECIAL_TOKS_TRUTH: - skip_special_tokens = True - elif not isinstance(detokenizer.tokenizer, MistralTokenizer): - skip_special_tokens = False - else: - pytest.skip("MistralTokenizers don't support " - "skip_special_tokens=False") - return - """Verify Detokenizer decodes prompt logprobs correctly.""" - sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, - prompt_logprobs=1) - - # Run sequentially. - seq = create_sequence(complete_sequence_token_ids) - seq_group = SequenceGroup(request_id="1", - seqs=[seq], - sampling_params=sampling_params, - arrival_time=0.0) - dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids) - detokenizer.decode_prompt_logprobs_inplace(seq_group, - dummy_logprobs, - position_offset=0) - # First logprob is None. - decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[ - 1:] # type: ignore - - # decoded_prompt_logprobs doesn't contain the first token. - token_ids = complete_sequence_token_ids - tokenizer = detokenizer.tokenizer - text_full = tokenizer.decode(token_ids, - skip_special_tokens=skip_special_tokens) - text_first = tokenizer.decode(token_ids[0], - skip_special_tokens=skip_special_tokens) - text = text_full[len(text_first):] - - # Text for logprobs for the chosen token should be the same as the - # prompt text. Note that the first logprob is None. 
- assert text == "".join([ - logprobs[token_id].decoded_token - for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs) - ]) - assert text != "".join([ - logprobs[token_id + 1].decoded_token - for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs) - ]) - - -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1]) -def test_decode_prompt_logprobs_chunked_prefill( - vllm_runner, - model, - chunked_prefill_token_size: int, - example_prompts, - monkeypatch, -): - # VLLM V1 does not use incremental detokenization for - # prompt logprobs, so this test strategy is irrelevant. - monkeypatch.setenv("VLLM_USE_V1", "0") - - max_num_seqs = 256 - enable_chunked_prefill = False - max_num_batched_tokens = None - if chunked_prefill_token_size != -1: - enable_chunked_prefill = True - max_num_seqs = min(chunked_prefill_token_size, max_num_seqs) - max_num_batched_tokens = chunked_prefill_token_size - - with vllm_runner(model, - dtype="half", - max_logprobs=5, - gpu_memory_utilization=0.5, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs) as vllm_model: - - vllm_sampling_params = SamplingParams(max_tokens=10, - logprobs=5, - prompt_logprobs=5, - temperature=0.0) - vllm_results = vllm_model.llm.generate( - example_prompts, sampling_params=vllm_sampling_params) - - for idx, result in enumerate(vllm_results): - assert result.prompt_logprobs is not None - assert result.prompt_logprobs[0] is None - - # Compared detokenized prompts ids to original prompt. - generated_string = "" - for (prompt_token, - prompt_logprobs) in zip(result.prompt_token_ids[1:], - result.prompt_logprobs[1:]): - # prompt_logprobs is a dict of the token_id: logprob - # We select the token_id corresponding to the actual prompt - # Decoded token in the detokenized string corresponding to this - # prompt token. 
- generated_string += prompt_logprobs[prompt_token].decoded_token - - assert generated_string == example_prompts[idx], ( - "Detokenized prompt logprobs do not match original prompt") diff --git a/tests/tool_use/test_jamba_tool_parser.py b/tests/tool_use/test_jamba_tool_parser.py index 35153139350bf..57ace1fa22ac9 100644 --- a/tests/tool_use/test_jamba_tool_parser.py +++ b/tests/tool_use/test_jamba_tool_parser.py @@ -12,7 +12,7 @@ from partial_json_parser.core.options import Allow from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers import JambaToolParser -from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer MODEL = "ai21labs/Jamba-tiny-dev" diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_use/test_qwen3coder_tool_parser.py index ccb2acf512caf..f06fb2b9f2f04 100644 --- a/tests/tool_use/test_qwen3coder_tool_parser.py +++ b/tests/tool_use/test_qwen3coder_tool_parser.py @@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ToolCall) from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( Qwen3CoderToolParser) -from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8" diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_use/test_seed_oss_tool_parser.py index c276a598aa68c..118c7534622e2 100644 --- a/tests/tool_use/test_seed_oss_tool_parser.py +++ b/tests/tool_use/test_seed_oss_tool_parser.py @@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser -from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer # Use a common model that is likely to be available diff --git a/tests/tool_use/test_xlam_tool_parser.py b/tests/tool_use/test_xlam_tool_parser.py index 0bc22e4f1031c..c07ca0f56d6b0 100644 --- a/tests/tool_use/test_xlam_tool_parser.py +++ b/tests/tool_use/test_xlam_tool_parser.py @@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers import xLAMToolParser -from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer # Use a common model that is likely to be available diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 0b7e103beca63..8a4fc15791b08 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -1,15 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for v1 attention backends without GPUModelRunner dependency.""" +from functools import partial +from typing import Optional, Union 
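+# flex_attention and create_block_mask are used below to build the mask-driven reference outputs.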
import pytest import torch +from torch.nn.attention.flex_attention import create_block_mask, flex_attention from tests.v1.attention.utils import (BatchSpec, _Backend, create_common_attn_metadata, create_standard_kv_cache_spec, create_vllm_config, get_attention_backend) +from vllm.config import ModelConfig +from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, set_kv_cache_layout) @@ -183,13 +188,19 @@ class MockAttentionLayer: self._v_scale_float = 1.0 -def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, - layer_names: list[str], vllm_config, - device: torch.device, - common_attn_metadata: CommonAttentionMetadata, - query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor) -> torch.Tensor: +def run_attention_backend( + backend: _Backend, + kv_cache_spec: FullAttentionSpec, + layer_names: list[str], + vllm_config, + device: torch.device, + common_attn_metadata: CommonAttentionMetadata, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + sliding_window: Optional[int] = None, +) -> torch.Tensor: """Run attention computation using the specified backend's AttentionImpl.""" # Handle special case for FLEX_ATTENTION_SLOW @@ -253,7 +264,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, scale=scale, num_kv_heads=num_kv_heads, alibi_slopes=None, - sliding_window=None, + sliding_window=sliding_window, kv_cache_dtype="auto", ) @@ -275,13 +286,16 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, return output -@pytest.mark.parametrize("batch_spec_name", [ - "small_decode", "small_prefill", "mixed_small", "medium_decode", - "medium_prefill", "mixed_medium", "large_decode", "large_prefill", - "single_decode", "single_prefill" -]) -@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) -def test_backend_correctness(batch_spec_name: str, model: str): +def _test_backend_correctness( + batch_spec: BatchSpec, + model: str, + backend_to_test: list[Union[_Backend, str]], + mask_mod, + *, + block_size: int = 16, + atol: float = 1e-2, + rtol: float = 1e-2, +): """ Test that all backends produce similar outputs to a reference implementation using torch.nn.functional.scaled_dot_product_attention. @@ -297,9 +311,10 @@ def test_backend_correctness(batch_spec_name: str, model: str): simulated paged KV cache. 5. Comparing the vLLM backend's output to the ground-truth SDPA output. 
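(The reference is computed with torch's flex_attention and the supplied mask_mod, so the causal and sliding-window variants share a single ground-truth path.)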
""" - batch_spec = BATCH_SPECS[batch_spec_name] + current_platform.seed_everything(42) vllm_config = create_vllm_config(model_name=model, max_model_len=max(batch_spec.seq_lens), + block_size=block_size, num_gpu_blocks=8192) device = torch.device("cuda:0") @@ -314,6 +329,7 @@ def test_backend_correctness(batch_spec_name: str, model: str): num_kv_heads = vllm_config.model_config.get_num_kv_heads( vllm_config.parallel_config) head_size = vllm_config.model_config.get_head_size() + sliding_window = vllm_config.model_config.get_sliding_window() dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) block_size = vllm_config.cache_config.block_size scale = 1.0 / (head_size**0.5) @@ -361,22 +377,21 @@ def test_backend_correctness(batch_spec_name: str, model: str): # Create causal mask: query token i attends to positions 0 to # (context_len + i) kv_len = s_len - offset = context_len - attn_mask = torch.full((q_len, kv_len), - float('-inf'), - device=device, - dtype=dtype) - for i in range(q_len): - attn_mask[i, :offset + i + 1] = 0.0 - sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( - q_sdpa_in, - k_sdpa_in, - v_sdpa_in, - attn_mask=attn_mask, - scale=scale, - enable_gqa=True) - # Convert back to (L, H, D) + final_mask_mod = partial(mask_mod, context_len=context_len) + block_mask = create_block_mask(final_mask_mod, + B=None, + H=None, + Q_LEN=q_len, + KV_LEN=kv_len, + device=device) + sdpa_out_i = flex_attention(q_sdpa_in, + k_sdpa_in, + v_sdpa_in, + block_mask=block_mask, + scale=scale, + enable_gqa=True) + all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0)) # Inputs for vLLM backends are just the new tokens @@ -412,7 +427,7 @@ def test_backend_correctness(batch_spec_name: str, model: str): # 4. Run vLLM backends and compare # Note: flex_attention has known Triton kernel compatibility issues # with test infrastructures - for backend_name in BACKENDS_TO_TEST: + for backend_name in backend_to_test: # FlashAttentionm + FlexAttention: # [2, num_blocks, block_size, num_kv_heads, head_size] # FlashInfer: @@ -427,12 +442,19 @@ def test_backend_correctness(batch_spec_name: str, model: str): 2, 3).contiguous().transpose(2, 3) set_kv_cache_layout("HND") - backend_output = run_attention_backend(backend_name, kv_cache_spec, - ["placeholder"], vllm_config, - device, common_attn_metadata, - query_vllm, key_vllm, - value_vllm, - kv_cache_for_backend) + backend_output = run_attention_backend( + backend_name, + kv_cache_spec, + ["placeholder"], + vllm_config, + device, + common_attn_metadata, + query_vllm, + key_vllm, + value_vllm, + kv_cache_for_backend, + sliding_window=sliding_window, + ) # Check shape and dtype consistency assert backend_output.shape == sdpa_output.shape, ( @@ -446,18 +468,102 @@ def test_backend_correctness(batch_spec_name: str, model: str): f"[{backend_name}] produced non-finite values") # Check numerical similarity - rtol = 1e-2 - atol = 5e-3 + def error_msg(msg: str, backend_name: str): + return (f"[{backend_name}] output differs from SDPA baseline. " + f"{msg}") - max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() - max_rel_diff = torch.max( - torch.abs(backend_output - sdpa_output) / - torch.abs(sdpa_output)).item() - all_close = torch.allclose(backend_output, + torch.testing.assert_close(backend_output, sdpa_output, rtol=rtol, - atol=atol) + atol=atol, + msg=partial(error_msg, + backend_name=backend_name)) - assert all_close, ( - f"[{backend_name}] output differs from SDPA baseline. 
" - f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") \ No newline at end of file + +@pytest.mark.parametrize("batch_spec_name", [ + "small_decode", "small_prefill", "mixed_small", "medium_decode", + "medium_prefill", "mixed_medium", "large_decode", "large_prefill", + "single_decode", "single_prefill" +]) +@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) +def test_causal_backend_correctness(batch_spec_name: str, model: str): + """Test backend's correctness with causal attention.""" + + def causal_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + *, + context_len: int, + ): + return (q_idx + context_len) >= kv_idx + + batch_spec = BATCH_SPECS[batch_spec_name] + LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION] + if is_torch_equal_or_newer("2.9.0.dev0") else []) + SMALL_BLOCK_BACKENDS = [ + x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS + ] + _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, + causal_mask_mod) + + # Fast FlexAttention needs to run with block_size=128 + if LARGE_BLOCK_BACKENDS: + _test_backend_correctness(batch_spec, + model, + LARGE_BLOCK_BACKENDS, + causal_mask_mod, + block_size=128) + + +SLIDING_WINDOW_BACKENDS_TO_TEST = [ + _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLEX_ATTENTION, + _Backend.TRITON_ATTN_VLLM_V1, "FLEX_ATTENTION_SLOW" +] + + +@pytest.mark.parametrize("batch_spec_name", [ + "small_decode", "small_prefill", "mixed_medium", "large_decode", + "large_prefill" +]) +@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"]) +def test_sliding_window_backend_correctness(batch_spec_name: str, model: str): + """Test backend's correctness with sliding window attention.""" + + def sliding_window_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + *, + context_len: int, + sliding_window: int, + ): + causal_mask = q_idx + context_len >= kv_idx + window_mask = q_idx + context_len - kv_idx < sliding_window + return causal_mask & window_mask + + batch_spec = BATCH_SPECS[batch_spec_name] + model_config = ModelConfig(model=model, + max_model_len=max(batch_spec.seq_lens)) + sliding_window = model_config.get_sliding_window() + sliding_window_mask_mod_fn = partial(sliding_window_mask_mod, + sliding_window=sliding_window) + + LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION] + if is_torch_equal_or_newer("2.9.0.dev0") else []) + SMALL_BLOCK_BACKENDS = [ + x for x in SLIDING_WINDOW_BACKENDS_TO_TEST + if x not in LARGE_BLOCK_BACKENDS + ] + _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, + sliding_window_mask_mod_fn) + + # Fast FlexAttention needs to run with block_size=128 + if LARGE_BLOCK_BACKENDS: + _test_backend_correctness(batch_spec, + model, + LARGE_BLOCK_BACKENDS, + sliding_window_mask_mod_fn, + block_size=128) diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index a9632ce54eac8..bdb40be99aa3f 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -12,9 +12,9 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, STOP_STRINGS, DummyOutputProcessorTestVectors, MockEngineCore) +from vllm.logprobs import PromptLogprobs, SampleLogprobs from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.sequence import PromptLogprobs, SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.v1.engine 
import EngineCoreRequest from vllm.v1.engine.output_processor import (OutputProcessor, diff --git a/tests/v1/engine/test_processor_multi_modal_uuids.py b/tests/v1/engine/test_processor_multi_modal_uuids.py index bdd41eece2317..3a7bcb9571825 100644 --- a/tests/v1/engine/test_processor_multi_modal_uuids.py +++ b/tests/v1/engine/test_processor_multi_modal_uuids.py @@ -6,7 +6,6 @@ import pytest from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig -from vllm.platforms.interface import UnspecifiedPlatform from vllm.sampling_params import SamplingParams from vllm.v1.engine import processor as processor_mod from vllm.v1.engine.processor import Processor @@ -33,15 +32,6 @@ def _mk_processor(monkeypatch, "__post_init__", lambda self, *args: None, raising=True) - monkeypatch.setattr(UnspecifiedPlatform, - "is_async_output_supported", - classmethod(lambda cls, enforce_eager: True), - raising=True) - monkeypatch.setattr( - ModelConfig, - "verify_async_output_proc", - lambda self, parallel_config, speculative_config, device_config: None, - raising=True) monkeypatch.setattr(ModelConfig, "verify_with_parallel_config", lambda self, parallel_config: None, diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index 28c24f62895ab..f6b8a18dd7c2c 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -29,24 +29,6 @@ def test_unsupported_configs(monkeypatch): }, ).create_engine_config() - with pytest.raises(NotImplementedError): - AsyncEngineArgs( - model=MODEL, - preemption_mode="swap", - ).create_engine_config() - - with pytest.raises(NotImplementedError): - AsyncEngineArgs( - model=MODEL, - disable_async_output_proc=True, - ).create_engine_config() - - with pytest.raises(NotImplementedError): - AsyncEngineArgs( - model=MODEL, - scheduler_delay_factor=1.2, - ).create_engine_config() - def test_enable_by_default_fallback(monkeypatch): with monkeypatch.context() as m: diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index dfde67e1713c5..ab7ef2112b083 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -4,19 +4,14 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from dataclasses import dataclass, fields -from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, - Protocol, Set, Tuple, Type, TypeVar) +from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple, + Type, TypeVar) import torch from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.multimodal import MultiModalPlaceholderMap -if TYPE_CHECKING: - from vllm.worker.model_runner_base import (ModelRunnerBase, - ModelRunnerInputBase, - ModelRunnerInputBuilderBase) - class AttentionType: """ @@ -170,7 +165,7 @@ class AttentionState(ABC, Generic[T]): lifetime of the model runner.""" @abstractmethod - def __init__(self, runner: "ModelRunnerBase"): + def __init__(self, runner: Any): ... @abstractmethod @@ -210,7 +205,7 @@ class AttentionState(ABC, Generic[T]): ... @abstractmethod - def begin_forward(self, model_input: "ModelRunnerInputBase") -> None: + def begin_forward(self, model_input) -> None: """Prepare state for forward pass.""" ... 
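The hunks above replace the `vllm.worker` model-runner annotations with `Any` because those V0 modules are being deleted. A backend that still wants local static checking can use a structural type instead of importing the removed class; a minimal sketch, assuming the state only needs a `device` attribute (`RunnerLike` and `SketchState` are illustrative names, not part of this change):

from typing import Protocol

import torch


class RunnerLike(Protocol):
    """Structural stand-in for the removed ModelRunnerBase annotation."""
    device: torch.device


class SketchState:

    def __init__(self, runner: RunnerLike) -> None:
        # Mirrors AttentionState.__init__(self, runner: Any): any object
        # exposing the attributes the state actually reads is acceptable.
        self.runner = runner
        self._is_graph_capturing = False

This keeps the dependency duck-typed without reintroducing an import of a module slated for deletion.
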
@@ -219,7 +214,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]): """Abstract class for attention metadata builders.""" @abstractmethod - def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None: + def __init__(self, input_builder) -> None: """Create the builder, remember some configuration and parameters.""" raise NotImplementedError diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index a7d0e3afb517f..87a4558e377d0 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -4,7 +4,7 @@ from collections import defaultdict from dataclasses import dataclass from itertools import accumulate -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch from einops import rearrange @@ -34,9 +34,6 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.vllm_flash_attn import (flash_attn_varlen_func, flash_attn_with_kvcache) -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder - logger = init_logger(__name__) @@ -329,7 +326,7 @@ class DifferentialFlashAttentionMetadata(AttentionMetadata): class DifferentialFlashAttentionMetadataBuilder( AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]): - def __init__(self, input_builder: "ModelInputForGPUBuilder"): + def __init__(self, input_builder): self.input_builder = input_builder self.runner = input_builder.runner self.sliding_window = input_builder.sliding_window @@ -350,9 +347,8 @@ class DifferentialFlashAttentionMetadataBuilder( self.num_decode_tokens = 0 self.has_prefix_cache_hit = False - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool, prefix_cache_hit: bool): + def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, + prefix_cache_hit: bool): """Add a sequence group to the metadata. Specifically update/append 1. context length. 2. block table. 
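With the `InterDataForSeqGroup` annotation dropped, `_add_seq_group` now relies purely on duck typing. A self-contained toy showing the implicit contract, assuming only the attributes visible in these hunks (`FakeInterData` and `CountingBuilder` are stand-ins defined here, not vLLM classes):

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class FakeInterData:
    """Exposes only the attributes the metadata builders read."""
    is_prompt: bool
    prompt_lens: List[int]
    seq_lens: List[int]
    block_tables: Optional[Dict[int, List[int]]] = None


class CountingBuilder:

    def prepare(self) -> None:
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

    def _add_seq_group(self, inter_data,
                       chunked_prefill_enabled: bool) -> None:
        # Any object with the attributes above works, matching the
        # annotation-free signature introduced by this diff.
        if inter_data.is_prompt:
            self.num_prefills += 1
            self.num_prefill_tokens += sum(inter_data.prompt_lens)
        else:
            self.num_decode_tokens += len(inter_data.seq_lens)


builder = CountingBuilder()
builder.prepare()
builder._add_seq_group(FakeInterData(is_prompt=True,
                                     prompt_lens=[8],
                                     seq_lens=[8]),
                       chunked_prefill_enabled=False)
assert builder.num_prefill_tokens == 8
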
diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py index 85957bea1e26d..de47bb8ebd8fb 100644 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -4,7 +4,7 @@ """ import math from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch import torch.distributed @@ -22,9 +22,6 @@ from vllm.utils import async_tensor_h2d from vllm.vllm_flash_attn import (flash_attn_varlen_func, flash_attn_with_kvcache, sparse_attn_func) -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder - logger = init_logger(__name__) @@ -224,9 +221,8 @@ class DualChunkFlashAttentionMetadataBuilder(FlashAttentionMetadataBuilder): super().prepare() self.orig_seq_lens: List[int] = [] - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool, prefix_cache_hit: bool): + def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, + prefix_cache_hit: bool): super()._add_seq_group(inter_data, chunked_prefill_enabled, prefix_cache_hit) for prompt_len, seq_len in zip(inter_data.prompt_lens, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 78c768f92d3c2..edb3afb4aa075 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -4,7 +4,7 @@ from collections import defaultdict from dataclasses import dataclass from itertools import accumulate -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple, Type import torch @@ -31,9 +31,6 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.vllm_flash_attn import (flash_attn_varlen_func, flash_attn_with_kvcache) -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder - logger = init_logger(__name__) @@ -312,7 +309,7 @@ class FlashAttentionMetadata(AttentionMetadata): class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): - def __init__(self, input_builder: "ModelInputForGPUBuilder"): + def __init__(self, input_builder): self.input_builder = input_builder self.runner = input_builder.runner self.sliding_window = input_builder.sliding_window @@ -332,9 +329,8 @@ class FlashAttentionMetadataBuilder( self.num_decode_tokens = 0 self.has_prefix_cache_hit = False - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool, prefix_cache_hit: bool): + def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, + prefix_cache_hit: bool): """Add a sequence group to the metadata. Specifically update/append 1. context length. 2. block table. 
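Worth noting why the deleted `if TYPE_CHECKING:` blocks were runtime-safe in the first place: `typing.TYPE_CHECKING` is `False` when the program runs, so the guarded imports never execute, and string annotations that reference them are never evaluated. A minimal, runnable illustration (the module name is deliberately nonexistent):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Never executed at runtime, so the missing module is harmless.
    from nonexistent_worker_module import InterData


def consume(data: "InterData") -> None:
    # String annotation: nothing is imported or evaluated at call time.
    print(type(data).__name__)


consume(object())  # runs fine even though the module does not exist

These hunks remove the guards anyway because the `vllm.worker` modules are being deleted outright, which would break type checkers even though runtime behavior was never affected.
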
diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 789393eb39a73..826b63e1ccda8 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -193,8 +193,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, - Type, TypeVar) +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar import torch @@ -233,9 +232,6 @@ except ImportError: except ImportError: flash_attn_varlen_func = None -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder - is_hip = current_platform.is_rocm() @@ -638,7 +634,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): """ BLOCK_TABLE_EXTENDER: list[list[int]] = [] - def __init__(self, input_builder: "ModelInputForGPUBuilder"): + def __init__(self, input_builder): self.input_builder = input_builder self.runner = input_builder.runner self.sliding_window = input_builder.sliding_window @@ -668,9 +664,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): self.num_decode_tokens = 0 self.has_prefix_cache_hit = False - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool, prefix_cache_hit: bool): + def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, + prefix_cache_hit: bool): """Add a sequence group to the metadata. Specifically update/append 1. context length. 2. block table. diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index e630a6c6de8c4..f82d28938f45c 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -4,7 +4,7 @@ from collections import defaultdict from dataclasses import dataclass from itertools import accumulate -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple, Type import torch @@ -13,9 +13,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadataBuilder) from vllm.attention.backends.utils import CommonAttentionState from vllm.multimodal import MultiModalPlaceholderMap - -if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder) from vllm.utils import async_tensor_h2d # Placeholder attention backend for models like Mamba and pooling models that @@ -204,7 +201,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata): class PlaceholderAttentionMetadataBuilder( AttentionMetadataBuilder[PlaceholderAttentionMetadata]): - def __init__(self, input_builder: "ModelInputForGPUBuilder"): + def __init__(self, input_builder): self.input_builder = input_builder self.runner = input_builder.runner @@ -220,9 +217,7 @@ class PlaceholderAttentionMetadataBuilder( self.num_prefill_tokens = 0 self.num_decode_tokens = 0 - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool): + def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool): """Add a sequence group to the metadata. Specifically update/append 1. context length. 
""" diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index a2e9710437d95..587d08858b92b 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Type, Union +from typing import Optional, Type, Union import torch @@ -19,9 +19,6 @@ from vllm.attention.backends.utils import (compute_slot_mapping, from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, get_aiter_mla_metadata) -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder - def is_aiter_mla_enabled() -> bool: return envs.VLLM_ROCM_USE_AITER \ @@ -110,7 +107,7 @@ class AiterMLAMetadata(MLACommonMetadata): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] - def __init__(self, input_builder: "ModelInputForGPUBuilder"): + def __init__(self, input_builder): super().__init__(input_builder) assert self.block_size == 1, "AITER MLA requires only block size 1." diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 7b6c426b0f851..b28e6a4237cb6 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -5,8 +5,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, - TypeVar, Union) +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import numpy as np import torch @@ -21,9 +20,6 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad logger = init_logger(__name__) -if TYPE_CHECKING: - from vllm.worker.model_runner_base import ModelRunnerBase - # Error string(s) for encoder/decoder # unsupported attention scenarios STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " @@ -35,9 +31,6 @@ PAD_SLOT_ID = -1 # if we have at least this many elements. Could be tuned further. 
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256 -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder - def is_block_tables_empty(block_tables: Union[None, Dict]): """ @@ -129,7 +122,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): _metadata_cls: Type[TAttentionMetadata] - def __init__(self, input_builder: "ModelInputForGPUBuilder"): + def __init__(self, input_builder): self.input_builder = input_builder self.runner = input_builder.runner @@ -149,9 +142,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): self.num_prefill_tokens = 0 self.num_decode_tokens = 0 - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool): + def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool): is_prompt = inter_data.is_prompt block_tables = inter_data.block_tables @@ -291,7 +282,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): class CommonAttentionState(AttentionState): - def __init__(self, runner: "ModelRunnerBase"): + def __init__(self, runner): self.runner = runner self._is_graph_capturing = False diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index ddd8de4324f6f..e31a78ba33baa 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -454,9 +454,6 @@ class VllmConfig: self.try_verify_and_update_config() if self.model_config is not None: - self.model_config.verify_async_output_proc(self.parallel_config, - self.speculative_config, - self.device_config) self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_dual_chunk_attention_config( self.load_config) @@ -877,7 +874,6 @@ class VllmConfig: f"served_model_name={self.model_config.served_model_name}, " f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa - f"use_async_output_proc={self.model_config.use_async_output_proc}, " f"pooler_config={self.model_config.pooler_config!r}, " f"compilation_config={self.compilation_config!r}") diff --git a/vllm/config/model.py b/vllm/config/model.py index 921322bb475c5..95fe52883db0d 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -27,8 +27,7 @@ from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, - is_interleaved, maybe_override_with_speculators_target_model, - try_get_generation_config, try_get_safetensors_metadata, + is_interleaved, try_get_generation_config, try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope) from vllm.transformers_utils.runai_utils import (ObjectStorageModel, is_runai_obj_uri) @@ -223,8 +222,6 @@ class ModelConfig: that this name(s) will also be used in `model_name` tag content of prometheus metrics, if multiple names provided, metrics tag will take the first one.""" - use_async_output_proc: bool = True - """Whether to use async output processor.""" config_format: Union[str, ConfigFormat] = "auto" """The format of the model config to load:\n - "auto" will try to load the config in hf format if available else it @@ -418,15 +415,6 @@ class ModelConfig: self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer) - if self.runner != "draft": - # If we're not running the draft model, check for speculators config - # If speculators config, set model / tokenizer to be target 
model - self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501 - model=self.model, - tokenizer=self.tokenizer, - revision=self.revision, - trust_remote_code=self.trust_remote_code) - if (backend := envs.VLLM_ATTENTION_BACKEND ) and backend == "FLASHINFER" and find_spec("flashinfer") is None: raise ValueError( @@ -1119,37 +1107,6 @@ class ModelConfig: raise ValueError("please set VLLM_ATTENTION_BACKEND to " f"{STR_DUAL_CHUNK_FLASH_ATTN_VAL}") - def verify_async_output_proc(self, parallel_config, speculative_config, - device_config) -> None: - if not self.use_async_output_proc: - # Nothing to check - return - - if parallel_config.pipeline_parallel_size > 1: - self.use_async_output_proc = False - return - - # Reminder: Please update docs/features/compatibility_matrix.md - # If the feature combo become valid - from vllm.platforms import current_platform - if not current_platform.is_async_output_supported(self.enforce_eager): - self.use_async_output_proc = False - return - - if envs.VLLM_USE_RAY_SPMD_WORKER: - self.use_async_output_proc = False - return - - # Async postprocessor is not necessary for pooling models - # since there is no token generation - if self.runner_type == "pooling": - self.use_async_output_proc = False - - # Reminder: Please update docs/features/compatibility_matrix.md - # If the feature combo become valid - if speculative_config: - self.use_async_output_proc = False - def verify_with_parallel_config( self, parallel_config: ParallelConfig, @@ -1173,15 +1130,12 @@ class ModelConfig: self._verify_with_expert_parallelism() pipeline_parallel_size = parallel_config.pipeline_parallel_size - if pipeline_parallel_size > 1: - if not self.registry.is_pp_supported_model(self.architectures, - self): - raise NotImplementedError( - "Pipeline parallelism is not supported for this model. " - "Supported models implement the `SupportsPP` interface.") - - if self.use_async_output_proc: - self.use_async_output_proc = False + if (pipeline_parallel_size > 1 + and not self.registry.is_pp_supported_model( + self.architectures, self)): + raise NotImplementedError( + "Pipeline parallelism is not supported for this model. " + "Supported models implement the `SupportsPP` interface.") def get_sliding_window(self) -> Optional[int]: """Get the sliding window size from the HF text config if present.""" diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index f0f67bab9d6ff..daf094d2df5c8 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -3,7 +3,7 @@ import hashlib from dataclasses import field -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Union from pydantic import SkipValidation, model_validator from pydantic.dataclasses import dataclass @@ -18,7 +18,6 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, logger = init_logger(__name__) RunnerType = Literal["generate", "pooling", "draft"] -PreemptionMode = Literal["swap", "recompute"] SchedulerPolicy = Literal["fcfs", "priority"] @@ -78,10 +77,6 @@ class SchedulerConfig: 3. more than one value (e.g. 
1 2 128) is provided, then the capture list will follow the provided list.""" - delay_factor: float = 0.0 - """Apply a delay (of delay factor multiplied by previous - prompt latency) before scheduling next prompt.""" - enable_chunked_prefill: SkipValidation[bool] = None # type: ignore """If True, prefill requests can be chunked based on the remaining max_num_batched_tokens.""" @@ -103,14 +98,6 @@ class SchedulerConfig: NOTE: This is not currently configurable. It will be overridden by max_num_batched_tokens in case max multimodal embedding size is larger.""" - preemption_mode: Optional[PreemptionMode] = None - """Whether to perform preemption by swapping or - recomputation. If not specified, we determine the mode as follows: - We use recomputation by default since it incurs lower overhead than - swapping. However, when the sequence group has multiple sequences - (e.g., beam search), recomputation is not currently supported. In - such a case, we use swapping instead.""" - send_delta_data: bool = False """Private API. If used, scheduler sends delta data to workers instead of an entire data. It should be enabled only diff --git a/vllm/core/__init__.py b/vllm/core/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/vllm/core/block/__init__.py b/vllm/core/block/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py deleted file mode 100644 index 444bb25f2830a..0000000000000 --- a/vllm/core/block/block_table.py +++ /dev/null @@ -1,399 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math -from typing import List, Optional - -from vllm.core.block.common import BlockList -from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator -from vllm.utils import Device, cdiv, chunk_list - - -class BlockTable: - """A class to manage blocks for a specific sequence. - - The BlockTable maps a sequence of tokens to a list of blocks, where each - block represents a contiguous memory allocation for a portion of the - sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is - responsible for allocating and freeing memory for the blocks. - - Args: - block_size (int): The maximum number of tokens that can be stored in a - single block. - block_allocator (DeviceAwareBlockAllocator): The block allocator used to - manage memory for the blocks. - _blocks (Optional[List[Block]], optional): An optional list of existing - blocks to initialize the BlockTable with. If not provided, an empty - BlockTable is created. - max_block_sliding_window (Optional[int], optional): The number of - blocks to keep around for each sequence. If None, all blocks - are kept (eg., when sliding window is not used). - It should at least fit the sliding window size of the model. - - Attributes: - _block_size (int): The maximum number of tokens that can be stored in a - single block. - _allocator (DeviceAwareBlockAllocator): The block allocator used to - manage memory for the blocks. - _blocks (Optional[List[Block]]): The list of blocks managed by this - BlockTable. - _num_full_slots (int): The number of tokens currently stored in the - blocks. 
- """ - - def __init__( - self, - block_size: int, - block_allocator: DeviceAwareBlockAllocator, - _blocks: Optional[List[Block]] = None, - max_block_sliding_window: Optional[int] = None, - ): - self._block_size = block_size - self._allocator = block_allocator - if _blocks is None: - _blocks = [] - self._blocks: BlockList = BlockList(_blocks) - - self._max_block_sliding_window = max_block_sliding_window - self._num_full_slots = self._get_num_token_ids() - - @staticmethod - def get_num_required_blocks(token_ids: List[int], - block_size: int, - num_lookahead_slots: int = 0) -> int: - """Calculates the minimum number of blocks required to store a given - sequence of token IDs along with any look-ahead slots that may be - required (like in multi-step + chunked-prefill). - - This assumes worst-case scenario, where every block requires a new - allocation (e.g. ignoring prefix caching). - - Args: - token_ids (List[int]): The sequence of token IDs to be stored. - block_size (int): The maximum number of tokens that can be stored in - a single block. - num_lookahead_slots (int): look-ahead slots that the sequence may - require. - - Returns: - int: The minimum number of blocks required to store the given - sequence of token IDs along with any required look-ahead slots. - """ - return cdiv(len(token_ids) + num_lookahead_slots, block_size) - - def allocate(self, - token_ids: List[int], - device: Device = Device.GPU, - extra_hash: Optional[int] = None) -> None: - """Allocates memory blocks for storing the given sequence of token IDs. - - This method allocates the required number of blocks to store the given - sequence of token IDs. - - Args: - token_ids (List[int]): The sequence of token IDs to be stored. - device (Device, optional): The device on which the blocks should be - allocated. Defaults to Device.GPU. - extra_hash (Optional[int]): The hash value of additional - factors, such as adapters, that influence the block hash - in the prefixcaching block. - """ - assert not self._is_allocated - assert token_ids - blocks = self._allocate_blocks_for_token_ids(prev_block=None, - token_ids=token_ids, - device=device, - extra_hash=extra_hash) - self.update(blocks) - self._num_full_slots = len(token_ids) - - def update(self, blocks: List[Block]) -> None: - """Resets the table to the newly provided blocks - (with their corresponding block ids) - """ - self._blocks.update(blocks) - - def append_token_ids(self, - token_ids: List[int], - num_lookahead_slots: int = 0, - num_computed_slots: Optional[int] = None, - extra_hash: Optional[int] = None) -> None: - """Appends a sequence of token IDs to the existing blocks in the - BlockTable. - - This method appends the given sequence of token IDs to the existing - blocks in the BlockTable. If there is not enough space in the existing - blocks, new blocks are allocated using the `ensure_num_empty_slots` - method to accommodate the additional tokens. - - The token IDs are divided into chunks of size `block_size` (except for - the first chunk, which may be smaller), and each chunk is appended to a - separate block. - - Args: - token_ids (List[int]): The sequence of token IDs to be appended. - num_computed_slots (Optional[int]): The number of KV cache slots - that are already filled (computed). - When sliding window is enabled, this is used to compute how many - blocks to drop at the front of the sequence. - Without sliding window, None can be passed. - Without chunked prefill, it should be the same as - _num_full_slots. 
- extra_hash (Optional[int]): The hash value of additional - factors such as adapters that influence the block, apart - from the token_ids. - """ - assert self._is_allocated, "no blocks have been allocated" - assert len(self._blocks) > 0 - - # Drop blocks that are no longer needed due to sliding window - if self._max_block_sliding_window is not None: - null_block = self._allocator.allocate_or_get_null_block() - assert num_computed_slots is not None - end_block_idx = (num_computed_slots // - self._block_size) - self._max_block_sliding_window - for idx in range(0, end_block_idx): - b = self._blocks[idx] - if b is not null_block: - self._allocator.free(b) - self._blocks[idx] = null_block - - # Ensure there are enough empty slots for the new tokens plus - # lookahead slots - self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + - num_lookahead_slots, - extra_hash=extra_hash) - - # Update the blocks with the new tokens - first_block_idx = self._num_full_slots // self._block_size - token_blocks = self._chunk_token_blocks_for_append(token_ids) - - for i, token_block in enumerate(token_blocks): - self._blocks.append_token_ids(first_block_idx + i, token_block) - - self._num_full_slots += len(token_ids) - - def ensure_num_empty_slots(self, - num_empty_slots: int, - extra_hash: Optional[int] = None) -> None: - """Ensures that the BlockTable has at least the specified number of - empty slots available. - - This method checks if the BlockTable has enough empty slots (i.e., - available space) to accommodate the requested number of tokens. If not, - it allocates additional blocks on the GPU to ensure that the required - number of empty slots is available. - - Args: - num_empty_slots (int): The minimum number of empty slots required. - extra_hash (Optional[int]): The hash value of additional - factors such as adapters that influence the block, apart - from the token_ids. - """ - # Currently the block table only supports - # appending tokens to GPU blocks. - device = Device.GPU - assert self._is_allocated - - if self._num_empty_slots >= num_empty_slots: - return - - slots_to_allocate = num_empty_slots - self._num_empty_slots - blocks_to_allocate = cdiv(slots_to_allocate, self._block_size) - - for _ in range(blocks_to_allocate): - assert len(self._blocks) > 0 - self._blocks.append( - self._allocator.allocate_mutable_block( - prev_block=self._blocks[-1], - device=device, - extra_hash=extra_hash)) - - def fork(self) -> "BlockTable": - """Creates a new BlockTable instance with a copy of the blocks from the - current instance. - - This method creates a new BlockTable instance with the same block size, - block allocator, and a copy of the blocks from the current instance. The - new BlockTable has its own independent set of blocks, but shares the - same underlying memory allocation with the original BlockTable. - - Returns: - BlockTable: A new BlockTable instance with a copy of the blocks from - the current instance. - """ - assert self._is_allocated - assert len(self._blocks) > 0 - forked_blocks = self._allocator.fork(self._blocks[-1]) - return BlockTable( - block_size=self._block_size, - block_allocator=self._allocator, - _blocks=forked_blocks, - max_block_sliding_window=self._max_block_sliding_window, - ) - - def free(self) -> None: - """Frees the memory occupied by the blocks in the BlockTable. - - This method iterates over all the blocks in the `_blocks` list and calls - the `free` method of the `_allocator` object to release the memory - occupied by each block. 
After freeing all the blocks, the `_blocks` list - is set to `None`. - """ - for block in self.blocks: - self._allocator.free(block) - self._blocks.reset() - - @property - def physical_block_ids(self) -> List[int]: - """Returns a list of physical block indices for the blocks in the - BlockTable. - - This property returns a list of integers, where each integer represents - the physical block index of a corresponding block in the `_blocks` list. - The physical block index is a unique identifier for the memory location - occupied by the block. - - Returns: - List[int]: A list of physical block indices for the blocks in the - BlockTable. - """ - return self._blocks.ids() - - def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: - """Get the number of "unseen" tokens in the sequence. - - Unseen tokens are tokens in the sequence corresponding to this block - table, but are not yet appended to this block table. - - Args: - sequence_token_ids (List[int]): The list of token ids in the - sequence. - - Returns: - List[int]: The postfix of sequence_token_ids that has not yet been - appended to the block table. - """ - - # Since the block table is append-only, the unseen token ids are the - # ones after the appended ones. - return sequence_token_ids[self.num_full_slots:] - - def _allocate_blocks_for_token_ids( - self, - prev_block: Optional[Block], - token_ids: List[int], - device: Device, - extra_hash: Optional[int] = None) -> List[Block]: - blocks: List[Block] = [] - - block_token_ids = [] - tail_token_ids = [] - for cur_token_ids in chunk_list(token_ids, self._block_size): - if len(cur_token_ids) == self._block_size: - block_token_ids.append(cur_token_ids) - else: - tail_token_ids.append(cur_token_ids) - - if block_token_ids: - blocks.extend( - self._allocator.allocate_immutable_blocks( - prev_block, - block_token_ids=block_token_ids, - device=device, - extra_hash=extra_hash)) - prev_block = blocks[-1] - - if tail_token_ids: - assert len(tail_token_ids) == 1 - cur_token_ids = tail_token_ids[0] - - block = self._allocator.allocate_mutable_block( - prev_block=prev_block, device=device, extra_hash=extra_hash) - block.append_token_ids(cur_token_ids) - - blocks.append(block) - - return blocks - - def _get_all_token_ids(self) -> List[int]: - # NOTE: This function is O(seq_len); use sparingly. - token_ids: List[int] = [] - - if not self._is_allocated: - return token_ids - - for block in self.blocks: - token_ids.extend(block.token_ids) - - return token_ids - - def _get_num_token_ids(self) -> int: - res = 0 - for block in self.blocks: - res += len(block.token_ids) - - return res - - @property - def _is_allocated(self) -> bool: - return len(self._blocks) > 0 - - @property - def blocks(self) -> List[Block]: - return self._blocks.list() - - @property - def _num_empty_slots(self) -> int: - assert self._is_allocated - return len(self._blocks) * self._block_size - self._num_full_slots - - @property - def num_full_slots(self) -> int: - """Returns the total number of tokens currently stored in the - BlockTable. - - Returns: - int: The total number of tokens currently stored in the BlockTable. - """ - return self._num_full_slots - - def get_num_blocks_touched_by_append_slots( - self, token_ids: List[int], num_lookahead_slots: int) -> int: - """Determine how many blocks will be "touched" by appending the token - ids. - - This is required for the scheduler to determine whether a sequence can - continue generation, or if it must be preempted. 
- """ - # Math below is equivalent to: - # all_token_ids = token_ids + [-1] * num_lookahead_slots - # token_blocks = self._chunk_token_blocks_for_append(all_token_ids) - # return len(token_blocks) - - num_token_ids = len(token_ids) + num_lookahead_slots - first_chunk_size = self._block_size - (self._num_full_slots % - self._block_size) - num_token_blocks = (1 + math.ceil( - (num_token_ids - first_chunk_size) / self._block_size)) - return num_token_blocks - - def _chunk_token_blocks_for_append( - self, token_ids: List[int]) -> List[List[int]]: - """Split the token ids into block-sized chunks so they can be easily - appended to blocks. The first such "token block" may have less token ids - than the block size, since the last allocated block may be partially - full. - - If no token ids are provided, then no chunks are returned. - """ - - if not token_ids: - return [] - - first_chunk_size = self._block_size - (self._num_full_slots % - self._block_size) - token_blocks = [token_ids[:first_chunk_size]] - token_blocks.extend( - chunk_list(token_ids[first_chunk_size:], self._block_size)) - return token_blocks diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py deleted file mode 100644 index a337007a9eaa6..0000000000000 --- a/vllm/core/block/common.py +++ /dev/null @@ -1,371 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from collections import deque -from dataclasses import dataclass -from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple - -from vllm.core.block.interfaces import Block, BlockAllocator - -BlockId = int -RefCount = int - - -class RefCounterProtocol(Protocol): - - def incr(self, block_id: BlockId) -> RefCount: - raise NotImplementedError - - def decr(self, block_id: BlockId) -> RefCount: - raise NotImplementedError - - def get(self, block_id: BlockId) -> RefCount: - raise NotImplementedError - - -class RefCounter(RefCounterProtocol): - """A class for managing reference counts for a set of block indices. - - The RefCounter class maintains a dictionary that maps block indices to their - corresponding reference counts. It provides methods to increment, decrement, - and retrieve the reference count for a given block index. - - Args: - all_block_indices (Iterable[BlockId]): An iterable of block indices - to initialize the reference counter with. - """ - - def __init__(self, all_block_indices: Iterable[BlockId]): - deduped = set(all_block_indices) - self._refcounts: Dict[BlockId, RefCount] = { - index: 0 - for index in deduped - } - - def incr(self, block_id: BlockId) -> RefCount: - assert block_id in self._refcounts - pre_incr_refcount = self._refcounts[block_id] - - assert pre_incr_refcount >= 0 - - post_incr_refcount = pre_incr_refcount + 1 - self._refcounts[block_id] = post_incr_refcount - return post_incr_refcount - - def decr(self, block_id: BlockId) -> RefCount: - assert block_id in self._refcounts - refcount = self._refcounts[block_id] - - assert refcount > 0 - refcount -= 1 - - self._refcounts[block_id] = refcount - - return refcount - - def get(self, block_id: BlockId) -> RefCount: - assert block_id in self._refcounts - return self._refcounts[block_id] - - def as_readonly(self) -> "ReadOnlyRefCounter": - return ReadOnlyRefCounter(self) - - -class ReadOnlyRefCounter(RefCounterProtocol): - """A read-only view of the RefCounter class. - - The ReadOnlyRefCounter class provides a read-only interface to access the - reference counts maintained by a RefCounter instance. 
It does not allow - modifications to the reference counts. - - Args: - refcounter (RefCounter): The RefCounter instance to create a read-only - view for. - """ - - def __init__(self, refcounter: RefCounter): - self._refcounter = refcounter - - def incr(self, block_id: BlockId) -> RefCount: - raise ValueError("Incr not allowed") - - def decr(self, block_id: BlockId) -> RefCount: - raise ValueError("Decr not allowed") - - def get(self, block_id: BlockId) -> RefCount: - return self._refcounter.get(block_id) - - -class CopyOnWriteTracker: - """A class for tracking and managing copy-on-write operations for blocks. - - The CopyOnWriteTracker class maintains a mapping of source block indices to - their corresponding copy-on-write destination block indices. It works in - conjunction with a RefCounter. - - Args: - refcounter (RefCounter): The reference counter used to track block - reference counts. - """ - - def __init__(self, refcounter: RefCounterProtocol): - self._copy_on_writes: List[Tuple[BlockId, BlockId]] = [] - self._refcounter = refcounter - - def is_appendable(self, block: Block) -> bool: - """Checks if the block is shared or not. If shared, then it cannot - be appended and needs to be duplicated via copy-on-write - """ - block_id = block.block_id - if block_id is None: - return True - - refcount = self._refcounter.get(block_id) - return refcount <= 1 - - def record_cow(self, src_block_id: Optional[BlockId], - trg_block_id: Optional[BlockId]) -> None: - """Records a copy-on-write operation from source to target block id - Args: - src_block_id (BlockId): The source block id from which to copy - the data - trg_block_id (BlockId): The target block id to which the data - is copied - """ - assert src_block_id is not None - assert trg_block_id is not None - self._copy_on_writes.append((src_block_id, trg_block_id)) - - def clear_cows(self) -> List[Tuple[BlockId, BlockId]]: - """Clears the copy-on-write tracking information and returns the current - state. - - This method returns a list mapping source block indices to - destination block indices for the current copy-on-write operations. - It then clears the internal tracking information. - - Returns: - List[Tuple[BlockId, BlockId]]: A list mapping source - block indices to destination block indices for the - current copy-on-write operations. - """ - cows = self._copy_on_writes - self._copy_on_writes = [] - return cows - - -class BlockPool: - """Used to pre-allocate block objects, in order to avoid excessive python - object allocations/deallocations. - The pool starts from "pool_size" objects and will increase to more objects - if necessary - - Note that multiple block objects may point to the same physical block id, - which is why this pool is needed, so that it will be easier to support - prefix caching and more complicated sharing of physical blocks. 
- """ - - def __init__(self, block_size: int, create_block: Block.Factory, - allocator: BlockAllocator, pool_size: int): - self._block_size = block_size - self._create_block = create_block - self._allocator = allocator - self._pool_size = pool_size - assert self._pool_size >= 0 - - self._free_ids: Deque[int] = deque(range(self._pool_size)) - self._pool = [] - for i in range(self._pool_size): - self._pool.append( - self._create_block(prev_block=None, - token_ids=[], - block_size=self._block_size, - allocator=self._allocator, - block_id=None, - extra_hash=None)) - - def increase_pool(self): - """Doubles the internal pool size - """ - cur_pool_size = self._pool_size - new_pool_size = cur_pool_size * 2 - self._pool_size = new_pool_size - - self._free_ids += deque(range(cur_pool_size, new_pool_size)) - - for i in range(cur_pool_size, new_pool_size): - self._pool.append( - self._create_block(prev_block=None, - token_ids=[], - block_size=self._block_size, - allocator=self._allocator, - block_id=None, - extra_hash=None)) - - def init_block(self, - prev_block: Optional[Block], - token_ids: List[int], - block_size: int, - physical_block_id: Optional[int], - extra_hash: Optional[int] = None) -> Block: - if len(self._free_ids) == 0: - self.increase_pool() - assert len(self._free_ids) > 0 - - pool_id = self._free_ids.popleft() - - block = self._pool[pool_id] - block.__init__( # type: ignore[misc] - prev_block=prev_block, - token_ids=token_ids, - block_size=block_size, - allocator=block._allocator, # type: ignore[attr-defined] - block_id=physical_block_id, - extra_hash=extra_hash) - block.pool_id = pool_id # type: ignore[attr-defined] - return block - - def free_block(self, block: Block) -> None: - self._free_ids.appendleft(block.pool_id) # type: ignore[attr-defined] - - -class BlockList: - """This class is an optimization to allow fast-access to physical - block ids. 
It maintains a block id list that is updated with the - block list and this avoids the need to reconstruct the block id - list on every iteration of the block manager - """ - - def __init__(self, blocks: List[Block]): - self._blocks: List[Block] = [] - self._block_ids: List[int] = [] - - self.update(blocks) - - def _add_block_id(self, block_id: Optional[BlockId]) -> None: - assert block_id is not None - self._block_ids.append(block_id) - - def _update_block_id(self, block_index: int, - new_block_id: Optional[BlockId]) -> None: - assert new_block_id is not None - self._block_ids[block_index] = new_block_id - - def update(self, blocks: List[Block]): - self._blocks = blocks - - # Cache block ids for fast query - self._block_ids = [] - for block in self._blocks: - self._add_block_id(block.block_id) - - def append_token_ids(self, block_index: int, token_ids: List[int]) -> None: - block = self._blocks[block_index] - prev_block_id = block.block_id - - block.append_token_ids(token_ids) - - # CoW or promotion may update the internal block_id - if prev_block_id != block.block_id: - self._update_block_id(block_index, block.block_id) - - def append(self, new_block: Block): - self._blocks.append(new_block) - self._add_block_id(new_block.block_id) - - def __len__(self) -> int: - return len(self._blocks) - - def __getitem__(self, block_index: int) -> Block: - return self._blocks[block_index] - - def __setitem__(self, block_index: int, new_block: Block) -> None: - self._blocks[block_index] = new_block - self._update_block_id(block_index, new_block.block_id) - - def reset(self): - self._blocks = [] - self._block_ids = [] - - def list(self) -> List[Block]: - return self._blocks - - def ids(self) -> List[int]: - return self._block_ids - - -@dataclass -class CacheMetricData: - """A utility dataclass to maintain cache metric. - To avoid overflow, we maintain the hit rate in block granularity, so that - we can maintain a single hit rate for n_completed_block x block_size, - and calculate the real time hit rate by the following: - BS = The number of queries per block. - nB = The number of completed blocks. - HR = hit rate of (nB x BS) queries. - Q = current number of queries (< BS). - H = current number of hits (< BS). - hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS) - """ - num_completed_blocks: int = 0 - completed_block_cache_hit_rate: float = 0.0 - num_incompleted_block_queries: int = 0 - num_incompleted_block_hit: int = 0 - block_size: int = 1000 - - def query(self, hit: bool): - self.num_incompleted_block_queries += 1 - self.num_incompleted_block_hit += 1 if hit else 0 - - # When a block is completed, update the cache hit rate - # and reset the incomplete numbers. 
- if self.num_incompleted_block_queries == self.block_size: - hit_rate = (self.num_incompleted_block_hit / - self.num_incompleted_block_queries) - self.completed_block_cache_hit_rate = ( - self.completed_block_cache_hit_rate * self.num_completed_blocks - + hit_rate) / (self.num_completed_blocks + 1) - self.num_incompleted_block_queries = 0 - self.num_incompleted_block_hit = 0 - self.num_completed_blocks += 1 - - def get_hit_rate(self): - incomplete_ratio = self.num_incompleted_block_queries / self.block_size - total_blocks = self.num_completed_blocks + incomplete_ratio - if total_blocks == 0: - return 0.0 - - completed_block_hit, incompleted_block_hit = 0.0, 0.0 - if self.num_completed_blocks > 0: - completed_block_hit = (self.completed_block_cache_hit_rate * - self.num_completed_blocks) - if self.num_incompleted_block_queries > 0: - incompleted_hit_rate = (self.num_incompleted_block_hit / - self.num_incompleted_block_queries) - incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio) - return (completed_block_hit + incompleted_block_hit) / total_blocks - - -def get_all_blocks_recursively(last_block: Block) -> List[Block]: - """Retrieves all the blocks in a sequence starting from the last block. - - This function recursively traverses the sequence of blocks in reverse order, - starting from the given last block, and returns a list of all the blocks in - the sequence. - - Args: - last_block (Block): The last block in the sequence. - - Returns: - List[Block]: A list of all the blocks in the sequence, in the order they - appear. - """ - - def recurse(block: Block, lst: List[Block]) -> None: - if block.prev_block is not None: - recurse(block.prev_block, lst) - lst.append(block) - - all_blocks: List[Block] = [] - recurse(last_block, all_blocks) - return all_blocks diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py deleted file mode 100644 index 92bc5e157e148..0000000000000 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ /dev/null @@ -1,439 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Dict, FrozenSet, List, Optional, Tuple - -from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, - DeviceAwareBlockAllocator) -from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator -from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator -from vllm.utils import Device - - -class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): - """A block allocator that can allocate blocks on both CPU and GPU memory. - - This class implements the `DeviceAwareBlockAllocator` interface and provides - functionality for allocating and managing blocks of memory on both CPU and - GPU devices. - - The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU - blocks, and allows for allocation, deallocation, forking, and swapping of - blocks across these memory pools. - """ - - @staticmethod - def create( - allocator_type: str, - num_gpu_blocks: int, - num_cpu_blocks: int, - block_size: int, - ) -> DeviceAwareBlockAllocator: - """Creates a CpuGpuBlockAllocator instance with the specified - configuration. - - This static method creates and returns a CpuGpuBlockAllocator instance - based on the provided parameters. It initializes the CPU and GPU block - allocators with the specified number of blocks, block size, and - allocator type. 
- - Args: - allocator_type (str): The type of block allocator to use for CPU - and GPU blocks. Currently supported values are "naive" and - "prefix_caching". - num_gpu_blocks (int): The number of blocks to allocate for GPU - memory. - num_cpu_blocks (int): The number of blocks to allocate for CPU - memory. - block_size (int): The size of each block in number of tokens. - - Returns: - DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the - specified configuration. - - Notes: - - The block IDs are assigned contiguously, with GPU block IDs coming - before CPU block IDs. - """ - reserved_blocks = 0 - block_ids = list( - range(reserved_blocks, num_gpu_blocks + num_cpu_blocks)) - num_gpu_blocks -= reserved_blocks - gpu_block_ids = block_ids[:num_gpu_blocks] - cpu_block_ids = block_ids[num_gpu_blocks:] - - if allocator_type == "naive": - gpu_allocator: BlockAllocator = NaiveBlockAllocator( - create_block=NaiveBlock, # type: ignore - num_blocks=num_gpu_blocks, - block_size=block_size, - block_ids=gpu_block_ids, - ) - - cpu_allocator: BlockAllocator = NaiveBlockAllocator( - create_block=NaiveBlock, # type: ignore - num_blocks=num_cpu_blocks, - block_size=block_size, - block_ids=cpu_block_ids, - ) - elif allocator_type == "prefix_caching": - gpu_allocator = PrefixCachingBlockAllocator( - num_blocks=num_gpu_blocks, - block_size=block_size, - block_ids=gpu_block_ids, - ) - - cpu_allocator = PrefixCachingBlockAllocator( - num_blocks=num_cpu_blocks, - block_size=block_size, - block_ids=cpu_block_ids, - ) - else: - raise ValueError(f"Unknown allocator type {allocator_type=}") - - return CpuGpuBlockAllocator( - cpu_block_allocator=cpu_allocator, - gpu_block_allocator=gpu_allocator, - ) - - def __init__(self, cpu_block_allocator: BlockAllocator, - gpu_block_allocator: BlockAllocator): - assert not ( - cpu_block_allocator.all_block_ids - & gpu_block_allocator.all_block_ids - ), "cpu and gpu block allocators can't have intersection of block ids" - - self._allocators = { - Device.CPU: cpu_block_allocator, - Device.GPU: gpu_block_allocator, - } - - self._swap_mapping: Dict[int, int] = {} - self._null_block: Optional[Block] = None - - self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} - for _, allocator in self._allocators.items(): - for block_id in allocator.all_block_ids: - self._block_ids_to_allocator[block_id] = allocator - - def allocate_or_get_null_block(self) -> Block: - if self._null_block is None: - self._null_block = NullBlock( - self.allocate_mutable_block(None, Device.GPU)) - return self._null_block - - def allocate_mutable_block(self, - prev_block: Optional[Block], - device: Device, - extra_hash: Optional[int] = None) -> Block: - """Allocates a new mutable block on the specified device. - - Args: - prev_block (Optional[Block]): The previous block to in the sequence. - Used for prefix hashing. - device (Device): The device on which to allocate the new block. - extra_hash (Optional[int]): The hash value of additional - factors, such as adapters, that influence the block hash - in the prefix caching block. - - Returns: - Block: The newly allocated mutable block. - """ - return self._allocators[device].allocate_mutable_block( - prev_block, extra_hash=extra_hash) - - def allocate_immutable_blocks( - self, - prev_block: Optional[Block], - block_token_ids: List[List[int]], - device: Device, - extra_hash: Optional[int] = None) -> List[Block]: - """Allocates a new group of immutable blocks with the provided block - token IDs on the specified device. 
- - Args: - prev_block (Optional[Block]): The previous block in the sequence. - Used for prefix hashing. - block_token_ids (List[int]): The list of block token IDs to be - stored in the new blocks. - device (Device): The device on which to allocate the new block. - extra_hash (Optional[int]): The hash value of additional - factors, such as adapters, that influence the block hash - in the prefix caching block. - - Returns: - List[Block]: The newly allocated list of immutable blocks - containing the provided block token IDs. - """ - return self._allocators[device].allocate_immutable_blocks( - prev_block, block_token_ids, extra_hash=extra_hash) - - def allocate_immutable_block(self, - prev_block: Optional[Block], - token_ids: List[int], - device: Device, - extra_hash: Optional[int] = None) -> Block: - """Allocates a new immutable block with the provided token IDs on the - specified device. - - Args: - prev_block (Optional[Block]): The previous block in the sequence. - Used for prefix hashing. - token_ids (List[int]): The list of token IDs to be stored in the new - block. - device (Device): The device on which to allocate the new block. - extra_hash (Optional[int]): The hash value of additional - factors, such as adapters, that influence the block hash - in the prefix caching block. - - Returns: - Block: The newly allocated immutable block containing the provided - token IDs. - """ - return self._allocators[device].allocate_immutable_block( - prev_block, token_ids, extra_hash=extra_hash) - - def free(self, block: Block) -> None: - """Frees the memory occupied by the given block. - - Args: - block (Block): The block to be freed. - """ - # Null block should never be freed - if isinstance(block, NullBlock): - return - block_id = block.block_id - assert block_id is not None - allocator = self._block_ids_to_allocator[block_id] - allocator.free(block) - - def fork(self, last_block: Block) -> List[Block]: - """Creates a new sequence of blocks that shares the same underlying - memory as the original sequence. - - Args: - last_block (Block): The last block in the original sequence. - - Returns: - List[Block]: A new list of blocks that shares the same memory as the - original sequence. - """ - # do not attempt to fork the null block - assert not isinstance(last_block, NullBlock) - block_id = last_block.block_id - assert block_id is not None - allocator = self._block_ids_to_allocator[block_id] - return allocator.fork(last_block) - - def get_num_free_blocks(self, device: Device) -> int: - """Returns the number of free blocks available on the specified device. - - Args: - device (Device): The device for which to query the number of free - blocks. AssertionError is raised if None is passed. - - Returns: - int: The number of free blocks available on the specified device. - """ - return self._allocators[device].get_num_free_blocks() - - def get_num_total_blocks(self, device: Device) -> int: - return self._allocators[device].get_num_total_blocks() - - def get_physical_block_id(self, device: Device, absolute_id: int) -> int: - """Returns the zero-offset block id on certain device given the - absolute block id. - - Args: - device (Device): The device for which to query relative block id. - absolute_id (int): The absolute block id for the block in - whole allocator. - - Returns: - int: The zero-offset block id on certain device. 
- """ - return self._allocators[device].get_physical_block_id(absolute_id) - - def swap(self, blocks: List[Block], src_device: Device, - dst_device: Device) -> Dict[int, int]: - """Execute the swap for the given blocks from source_device - on to dest_device, save the current swap mapping and append - them to the accumulated `self._swap_mapping` for each - scheduling move. - - Args: - blocks: List of blocks to be swapped. - src_device (Device): Device to swap the 'blocks' from. - dst_device (Device): Device to swap the 'blocks' to. - - Returns: - Dict[int, int]: Swap mapping from source_device - on to dest_device. - """ - src_block_ids = [block.block_id for block in blocks] - self._allocators[src_device].swap_out(blocks) - self._allocators[dst_device].swap_in(blocks) - dst_block_ids = [block.block_id for block in blocks] - - current_swap_mapping: Dict[int, int] = {} - for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids): - if src_block_id is not None and dst_block_id is not None: - self._swap_mapping[src_block_id] = dst_block_id - current_swap_mapping[src_block_id] = dst_block_id - return current_swap_mapping - - def get_num_full_blocks_touched(self, blocks: List[Block], - device: Device) -> int: - """Returns the number of full blocks that will be touched by - swapping in/out the given blocks on to the 'device'. - - Args: - blocks: List of blocks to be swapped. - device (Device): Device to swap the 'blocks' on. - - Returns: - int: the number of full blocks that will be touched by - swapping in/out the given blocks on to the 'device'. - Non full blocks are ignored when deciding the number - of blocks to touch. - """ - return self._allocators[device].get_num_full_blocks_touched(blocks) - - def clear_copy_on_writes(self) -> List[Tuple[int, int]]: - """Clears the copy-on-write (CoW) state and returns the mapping of - source to destination block IDs. - - Returns: - List[Tuple[int, int]]: A list mapping source block IDs to - destination block IDs. - """ - # CoW only supported on GPU - device = Device.GPU - return self._allocators[device].clear_copy_on_writes() - - def mark_blocks_as_accessed(self, block_ids: List[int], - now: float) -> None: - """Mark blocks as accessed, only use for prefix caching.""" - # Prefix caching only supported on GPU. - device = Device.GPU - return self._allocators[device].mark_blocks_as_accessed(block_ids, now) - - def mark_blocks_as_computed(self, block_ids: List[int]) -> None: - """Mark blocks as accessed, only use for prefix caching.""" - # Prefix caching only supported on GPU. - device = Device.GPU - return self._allocators[device].mark_blocks_as_computed(block_ids) - - def get_common_computed_block_ids( - self, computed_seq_block_ids: List[List[int]]) -> List[int]: - # Prefix caching only supported on GPU. - device = Device.GPU - return self._allocators[device].get_common_computed_block_ids( - computed_seq_block_ids) - - @property - def all_block_ids(self) -> FrozenSet[int]: - return frozenset(self._block_ids_to_allocator.keys()) - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - """Prefix cache hit rate. 
-    def get_prefix_cache_hit_rate(self, device: Device) -> float:
-        """Prefix cache hit rate. -1 means not supported or disabled."""
-        assert device in self._allocators
-        return self._allocators[device].get_prefix_cache_hit_rate()
-
-    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
-        """Reset the prefix cache for the specified device or all devices."""
-        if device:
-            return self._allocators[device].reset_prefix_cache()
-        success = True
-        for allocator in self._allocators.values():
-            success = success and allocator.reset_prefix_cache()
-        return success
-
-    def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
-        """Returns and clears the mapping of source to destination block IDs.
-        Will be called after every swapping operation for now, and after every
-        schedule once BlockManagerV2 becomes the default. Currently not useful.
-
-        Returns:
-            List[Tuple[int, int]]: A mapping of source to destination block IDs.
-        """
-        mapping = self._swap_mapping.copy()
-        self._swap_mapping.clear()
-        return list(mapping.items())
-
-    def find_cached_blocks_prefix(
-        self,
-        block_hashes: List[int],
-        device: Device = Device.GPU,
-    ) -> List[int]:
-        return self._allocators[device].find_cached_blocks_prefix(block_hashes)
-
-
-class NullBlock(Block):
-    """
-    Null blocks are used as placeholders for KV cache blocks that have
-    been dropped due to sliding window.
-    This implementation just wraps an ordinary block and prevents it from
-    being modified. It also allows for testing if a block is NullBlock
-    via isinstance().
-    """
-
-    def __init__(self, proxy: Block):
-        super().__init__()
-        self._proxy = proxy
-
-    def append_token_ids(self, token_ids: List[BlockId]):
-        raise ValueError("null block should not be modified")
-
-    @property
-    def block_id(self):
-        return self._proxy.block_id
-
-    @block_id.setter
-    def block_id(self, value: Optional[BlockId]):
-        raise ValueError("null block should not be modified")
-
-    @property
-    def token_ids(self) -> List[BlockId]:
-        return self._proxy.token_ids
-
-    @property
-    def num_tokens_total(self) -> int:
-        raise NotImplementedError(
-            "num_tokens_total is not used for null block")
-
-    @property
-    def num_empty_slots(self) -> BlockId:
-        return self._proxy.num_empty_slots
-
-    @property
-    def is_full(self):
-        return self._proxy.is_full
-
-    @property
-    def prev_block(self):
-        return self._proxy.prev_block
-
-    @property
-    def extra_hash(self):
-        return None
-
-    @property
-    def computed(self):
-        return self._proxy.computed
-
-    @computed.setter
-    def computed(self, value):
-        self._proxy.computed = value
-
-    @property
-    def last_accessed(self) -> float:
-        return self._proxy.last_accessed
-
-    @last_accessed.setter
-    def last_accessed(self, last_accessed_ts: float):
-        self._proxy.last_accessed = last_accessed_ts
-
-    @property
-    def content_hash(self):
-        return self._proxy.content_hash
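NullBlock is a textbook null-object/proxy: reads delegate to the wrapped block, writes raise. A self-contained sketch of the same pattern, reduced to a toy block with a single field (hypothetical names, not the removed vLLM classes):

class ToyBlock:
    def __init__(self, token_ids):
        self.token_ids = list(token_ids)

    def append_token_ids(self, token_ids):
        self.token_ids.extend(token_ids)


class ToyNullBlock:
    """Read-only proxy: delegates reads, rejects writes."""

    def __init__(self, proxy: ToyBlock):
        self._proxy = proxy

    @property
    def token_ids(self):
        return self._proxy.token_ids

    def append_token_ids(self, token_ids):
        raise ValueError("null block should not be modified")


null = ToyNullBlock(ToyBlock([1, 2, 3]))
assert null.token_ids == [1, 2, 3]
try:
    null.append_token_ids([4])
except ValueError:
    pass  # writes are rejected, testable via isinstance(), as with NullBlock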
diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py
deleted file mode 100644
index 1a05881f7c005..0000000000000
--- a/vllm/core/block/interfaces.py
+++ /dev/null
@@ -1,319 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from abc import ABC, abstractmethod
-from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple
-
-from vllm.utils import Device
-
-BlockId = int
-
-
-class Block(ABC):
-
-    @abstractmethod
-    def append_token_ids(self, token_ids: List[int]) -> None:
-        pass
-
-    @property
-    @abstractmethod
-    def block_id(self) -> Optional[int]:
-        pass
-
-    @block_id.setter
-    @abstractmethod
-    def block_id(self, value: Optional[int]) -> None:
-        """NOTE: Do not use this API outside Block."""
-        self._block_id = value
-
-    @property
-    @abstractmethod
-    def token_ids(self) -> List[int]:
-        pass
-
-    @property
-    @abstractmethod
-    def num_tokens_total(self) -> int:
-        """The number of tokens up to and including the current block.
-        """
-        pass
-
-    @property
-    @abstractmethod
-    def num_empty_slots(self) -> int:
-        pass
-
-    @property
-    @abstractmethod
-    def is_full(self) -> bool:
-        pass
-
-    @property
-    @abstractmethod
-    def prev_block(self) -> Optional["Block"]:
-        pass
-
-    @property
-    @abstractmethod
-    def extra_hash(self) -> Optional[int]:
-        return None
-
-    @property
-    @abstractmethod
-    def computed(self) -> bool:
-        raise NotImplementedError
-
-    @computed.setter
-    @abstractmethod
-    def computed(self, value) -> bool:
-        """Should only be used by PrefixCachingAllocator"""
-        raise NotImplementedError
-
-    @property
-    @abstractmethod
-    def last_accessed(self) -> float:
-        raise NotImplementedError
-
-    @last_accessed.setter
-    @abstractmethod
-    def last_accessed(self, last_accessed_ts: float):
-        raise NotImplementedError
-
-    class Factory(Protocol):
-
-        @abstractmethod
-        def __call__(
-            self,
-            prev_block: Optional["Block"],
-            token_ids: List[int],
-            block_size: int,
-            allocator: "BlockAllocator",
-            block_id: Optional[int] = None,
-            computed: bool = False,
-            extra_hash: Optional[int] = None,
-        ) -> "Block":
-            pass
-
-    @property
-    @abstractmethod
-    def content_hash(self) -> Optional[int]:
-        """Return the content-based hash of the current block, or None if it is
-        not yet defined or not supported.
-
-        For the content-based hash to be defined, the current block must be
-        full.
-        """
-        return None
-
-
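The Block interface chains blocks through prev_block, and content_hash is only defined once a block is full, so equal hashes imply equal prefixes rather than just equal blocks. A toy sketch of that contract (illustrative only, not the removed implementation):

from typing import List, Optional


class ChainedBlock:
    """Toy block: fixed capacity, hash defined only when full."""

    def __init__(self, prev_block: Optional["ChainedBlock"], block_size: int):
        self.prev_block = prev_block
        self.block_size = block_size
        self.token_ids: List[int] = []

    @property
    def is_full(self) -> bool:
        return len(self.token_ids) == self.block_size

    @property
    def content_hash(self) -> Optional[int]:
        # Undefined until the block is full; chains the parent's hash so
        # that the hash covers the whole prefix, not just this block.
        if not self.is_full:
            return None
        prev_hash = self.prev_block.content_hash if self.prev_block else None
        if self.prev_block is not None and prev_hash is None:
            return None
        return hash((prev_hash, tuple(self.token_ids)))


first = ChainedBlock(None, block_size=2)
first.token_ids.extend([1, 2])
second = ChainedBlock(first, block_size=2)
assert second.content_hash is None  # not full yet
second.token_ids.extend([3, 4])
assert second.content_hash is not None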
-class BlockAllocator(ABC):
-
-    @abstractmethod
-    def allocate_mutable_block(self, prev_block: Optional[Block],
-                               extra_hash: Optional[int]) -> Block:
-        pass
-
-    @abstractmethod
-    def allocate_immutable_block(self, prev_block: Optional[Block],
-                                 token_ids: List[int],
-                                 extra_hash: Optional[int]) -> Block:
-        pass
-
-    @abstractmethod
-    def allocate_immutable_blocks(self, prev_block: Optional[Block],
-                                  block_token_ids: List[List[int]],
-                                  extra_hash: Optional[int]) -> List[Block]:
-        pass
-
-    @abstractmethod
-    def free(self, block: Block) -> None:
-        pass
-
-    @abstractmethod
-    def fork(self, last_block: Block) -> List[Block]:
-        pass
-
-    @abstractmethod
-    def get_num_total_blocks(self) -> int:
-        pass
-
-    @abstractmethod
-    def get_num_free_blocks(self) -> int:
-        pass
-
-    @abstractmethod
-    def get_physical_block_id(self, absolute_id: int) -> int:
-        pass
-
-    @abstractmethod
-    def swap_out(self, blocks: List[Block]) -> None:
-        pass
-
-    @abstractmethod
-    def swap_in(self, blocks: List[Block]) -> None:
-        pass
-
-    @property
-    @abstractmethod
-    def all_block_ids(self) -> FrozenSet[int]:
-        pass
-
-    @abstractmethod
-    def clear_copy_on_writes(self) -> List[Tuple[int, int]]:
-        pass
-
-    @abstractmethod
-    def mark_blocks_as_accessed(self, block_ids: List[int],
-                                now: float) -> None:
-        pass
-
-    @abstractmethod
-    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
-        pass
-
-    @abstractmethod
-    def get_common_computed_block_ids(
-            self, computed_seq_block_ids: List[List[int]]) -> List[int]:
-        pass
-
-    @abstractmethod
-    def cow_block_if_not_appendable(self, block: Block) -> BlockId:
-        """NOTE: This should not be used besides Block"""
-        pass
-
-    @abstractmethod
-    def promote_to_immutable_block(self, block: Block) -> BlockId:
-        """NOTE: This should not be used besides Block"""
-        pass
-
-    @abstractmethod
-    def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
-        pass
-
-    @abstractmethod
-    def get_prefix_cache_hit_rate(self) -> float:
-        """Prefix cache hit rate. -1 means not supported or disabled."""
-        pass
-
-    @abstractmethod
-    def reset_prefix_cache(self) -> bool:
-        """Reset prefix cache."""
-        pass
-
-    class NoFreeBlocksError(ValueError):
-        pass
-
-    @abstractmethod
-    def find_cached_blocks_prefix(
-        self,
-        block_hashes: List[int],
-    ) -> List[int]:
-        pass
-
-
-class DeviceAwareBlockAllocator(ABC):
-
-    @abstractmethod
-    def allocate_mutable_block(self,
-                               prev_block: Optional[Block],
-                               device: Device,
-                               extra_hash: Optional[int] = None) -> Block:
-        pass
-
-    @abstractmethod
-    def allocate_immutable_block(self,
-                                 prev_block: Optional[Block],
-                                 token_ids: List[int],
-                                 device: Device,
-                                 extra_hash: Optional[int] = None) -> Block:
-        pass
-
-    @abstractmethod
-    def allocate_immutable_blocks(
-        self,
-        prev_block: Optional[Block],
-        block_token_ids: List[List[int]],
-        device: Device,
-        extra_hash: Optional[int] = None,
-    ) -> List[Block]:
-        pass
-
-    @abstractmethod
-    def get_num_free_blocks(self, device: Device) -> int:
-        pass
-
-    @abstractmethod
-    def get_num_total_blocks(self, device: Device) -> int:
-        pass
-
-    @abstractmethod
-    def free(self, block: Block) -> None:
-        pass
-
-    @abstractmethod
-    def fork(self, last_block: Block) -> List[Block]:
-        pass
-
-    @property
-    @abstractmethod
-    def all_block_ids(self) -> FrozenSet[int]:
-        pass
-
-    @abstractmethod
-    def clear_copy_on_writes(self) -> List[Tuple[int, int]]:
-        pass
-
-    @abstractmethod
-    def mark_blocks_as_accessed(self, block_ids: List[int],
-                                now: float) -> None:
-        pass
-
-    @abstractmethod
-    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
-        pass
-
-    @abstractmethod
-    def get_common_computed_block_ids(
-            self, computed_seq_block_ids: List[List[int]]) -> List[int]:
-        pass
-
-    @abstractmethod
-    def get_num_full_blocks_touched(self, blocks: List[Block],
-                                    device: Device) -> int:
-        pass
-
-    @abstractmethod
-    def swap(self, blocks: List[Block], src_device: Device,
-             dst_device: Device) -> Dict[int, int]:
-        pass
-
-    @abstractmethod
-    def get_physical_block_id(self, device: Device, absolute_id: int) -> int:
-        pass
-
-    @abstractmethod
-    def allocate_or_get_null_block(self) -> Block:
-        """
-        Null blocks are used as placeholders for KV cache blocks that have
-        been dropped due to sliding window.
-        There is at most one null block per allocator.
-        """
-        pass
-
-    @abstractmethod
-    def get_prefix_cache_hit_rate(self, device: Device) -> float:
-        """Prefix cache hit rate.
-1 means not supported or disabled.""" - pass - - @abstractmethod - def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: - """Reset prefix cache.""" - pass - - @abstractmethod - def find_cached_blocks_prefix( - self, - block_hashes: List[int], - device: Device = Device.GPU, - ) -> List[int]: - pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py deleted file mode 100644 index ae876d131eb66..0000000000000 --- a/vllm/core/block/naive_block.py +++ /dev/null @@ -1,466 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from collections import deque -from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union - -from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, - get_all_blocks_recursively) -from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device - -Refcount = int - - -class NaiveBlockAllocator(BlockAllocator): - """A simple block allocator that manages blocks of memory without prefix - caching. - - Args: - create_block (Block.Factory): A factory function for creating new - blocks. This is used when a NaiveBlockAllocator is composed within - a prefix caching allocator -- the naive block allocator must - construct prefix caching blocks (but shouldn't know anything else - about them). - num_blocks (int): The total number of blocks to manage. - block_size (int): The size of each block in tokens. - block_ids (Optional[Iterable[int]], optional): An optional iterable of - block IDs. If not provided, block IDs will be assigned sequentially - from 0 to num_blocks - 1. - """ - - def __init__( - self, - create_block: Block.Factory, - num_blocks: int, - block_size: int, - block_ids: Optional[Iterable[int]] = None, - block_pool: Optional[BlockPool] = None, - ): - if block_ids is None: - block_ids = range(num_blocks) - - self._free_block_indices: Deque[BlockId] = deque(block_ids) - self._all_block_indices = frozenset(block_ids) - assert len(self._all_block_indices) == num_blocks - - self._refcounter = RefCounter( - all_block_indices=self._free_block_indices) - self._block_size = block_size - - self._cow_tracker = CopyOnWriteTracker( - refcounter=self._refcounter.as_readonly()) - - if block_pool is None: - extra_factor = 4 - # Pre-allocate "num_blocks * extra_factor" block objects. - # The "* extra_factor" is a buffer to allow more block objects - # than physical blocks - self._block_pool = BlockPool(self._block_size, create_block, self, - num_blocks * extra_factor) - else: - # In this case, the block pool is provided by the caller, - # which means that there is most likely a need to share - # a block pool between allocators - self._block_pool = block_pool - - def allocate_immutable_block(self, - prev_block: Optional[Block], - token_ids: List[int], - extra_hash: Optional[int] = None, - device: Optional[Device] = None) -> Block: - """Allocates a new immutable block with the given token IDs, linked to - the previous block. - - Args: - prev_block (Optional[Block]): The previous block in the sequence. If - None, then the block to be allocated is the first block in the - sequence. - token_ids (List[int]): The token IDs to be stored in the new block. - - Returns: - Block: The newly allocated immutable block. 
- """ - assert device is None - block = self.allocate_mutable_block(prev_block=prev_block) - block.append_token_ids(token_ids) - return block - - def allocate_immutable_blocks( - self, - prev_block: Optional[Block], - block_token_ids: List[List[int]], - extra_hash: Optional[int] = None, - device: Optional[Device] = None) -> List[Block]: - assert device is None - num_blocks = len(block_token_ids) - - block_ids = [] - for i in range(num_blocks): - block_ids.append(self._allocate_block_id()) - - blocks = [] - for i in range(num_blocks): - prev_block = self._block_pool.init_block( - prev_block=prev_block, - token_ids=block_token_ids[i], - block_size=self._block_size, - physical_block_id=block_ids[i]) - blocks.append(prev_block) - - return blocks - - def allocate_mutable_block(self, - prev_block: Optional[Block], - extra_hash: Optional[int] = None, - device: Optional[Device] = None) -> Block: - """Allocates a new mutable block, linked to the previous block. - - Args: - prev_block (Optional[Block]): The previous block in the sequence. If - None, then the block to be allocated is the first block in the - sequence. - - Returns: - Block: The newly allocated mutable block. - """ - assert device is None - block_id = self._allocate_block_id() - block = self._block_pool.init_block(prev_block=prev_block, - token_ids=[], - block_size=self._block_size, - physical_block_id=block_id) - return block - - def _allocate_block_id(self) -> BlockId: - if not self._free_block_indices: - raise BlockAllocator.NoFreeBlocksError() - - block_id = self._free_block_indices.popleft() - self._refcounter.incr(block_id) - return block_id - - def _free_block_id(self, block: Union[Block, BlockId]) -> None: - if isinstance(block, Block): - block_id = block.block_id - block.block_id = None - else: - block_id = block - assert block_id is not None - - refcount = self._refcounter.decr(block_id) - if refcount == 0: - self._free_block_indices.appendleft(block_id) - - def free(self, block: Block, keep_block_object: bool = False) -> None: - # Release the physical block id - self._free_block_id(block) - - # Release the block object - if not keep_block_object: - self._block_pool.free_block(block) - - def free_block_id(self, block_id: BlockId) -> None: - self._free_block_id(block_id) - - def fork(self, last_block: Block) -> List[Block]: - """Creates a new sequence of blocks that shares the same underlying - memory as the original sequence. - - Args: - last_block (Block): The last block in the original sequence. - - Returns: - List[Block]: The new sequence of blocks that shares the same memory - as the original sequence. - """ - source_blocks = get_all_blocks_recursively(last_block) - - forked_blocks: List[Block] = [] - prev_block = None - for block in source_blocks: - - # Increment refcount for each block. - assert block.block_id is not None - refcount = self._refcounter.incr(block.block_id) - assert refcount != 1, "can't fork freed block" - - forked_block = self._block_pool.init_block( - prev_block=prev_block, - token_ids=block.token_ids, - block_size=self._block_size, - physical_block_id=block.block_id) - - forked_blocks.append(forked_block) - prev_block = forked_blocks[-1] - - return forked_blocks - - def get_num_free_blocks(self) -> int: - return len(self._free_block_indices) - - def get_num_total_blocks(self) -> int: - return len(self._all_block_indices) - - def get_physical_block_id(self, absolute_id: int) -> int: - """Returns the zero-offset block id on certain block allocator - given the absolute block id. 
- - Args: - absolute_id (int): The absolute block id for the block - in whole allocator. - - Returns: - int: The zero-offset block id on certain device. - """ - return sorted(self._all_block_indices).index(absolute_id) - - @property - def refcounter(self): - return self._refcounter - - @property - def all_block_ids(self) -> FrozenSet[int]: - return self._all_block_indices - - def cow_block_if_not_appendable(self, block: Block) -> BlockId: - """Performs a copy-on-write operation on the given block if it is not - appendable. - - Args: - block (Block): The block to check for copy-on-write. - - Returns: - BlockId: The block index of the new block if a copy-on-write - operation was performed, or the original block index if - no copy-on-write was necessary. - """ - src_block_id = block.block_id - assert src_block_id is not None - - if self._cow_tracker.is_appendable(block): - return src_block_id - - self._free_block_id(block) - trg_block_id = self._allocate_block_id() - - self._cow_tracker.record_cow(src_block_id, trg_block_id) - - return trg_block_id - - def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: - """Returns the copy-on-write source->destination mapping and clears it. - - Returns: - List[Tuple[BlockId, BlockId]]: A list mapping source - block indices to destination block indices. - """ - return self._cow_tracker.clear_cows() - - def mark_blocks_as_accessed(self, block_ids: List[int], - now: float) -> None: - """Mark blocks as accessed, used in prefix caching. - - Since the naive allocator does not implement prefix caching, we do - nothing. - """ - pass - - def mark_blocks_as_computed(self, block_ids: List[int]) -> None: - """Mark blocks as computed, used in prefix caching. - - Since the naive allocator does not implement prefix caching, we do - nothing. - """ - pass - - def get_common_computed_block_ids( - self, computed_seq_block_ids: List[List[int]]) -> List[int]: - """Determine blocks that can be skipped in prefill. - - Since the naive allocator does not support prefix caching, always return - an empty list. - """ - return [] - - def promote_to_immutable_block(self, block: Block) -> BlockId: - raise NotImplementedError("There is no promotion for naive blocks") - - def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: - """Returns the number of full blocks that will be touched by - swapping in/out. - - Args: - blocks: List of blocks to be swapped. - Returns: - int: the number of full blocks that will be touched by - swapping in/out the given blocks. Non full blocks are ignored - when deciding the number of blocks to touch. - """ - # NOTE: for naive block, we use set to eliminate common blocks among - # seqs, also we compare the empty slots in the mutable blocks with - # lookahead slots to get the number of unique new block that are - # needed. - old_block_set = set() - for block in blocks: - if block.is_full: - old_block_set.add(block) - return len(old_block_set) - - def swap_out(self, blocks: List[Block]) -> None: - for block in blocks: - self._free_block_id(block) - - def swap_in(self, blocks: List[Block]) -> None: - for block in blocks: - # Here we allocate either immutable or mutable block and then - # extract its block_id. 
Note that the block object is released - # and the block_id is assigned to "block" to allow reusing the - # existing "block" object - if block.is_full: - tmp_block = self.allocate_immutable_block( - prev_block=block.prev_block, token_ids=block.token_ids) - else: - tmp_block = self.allocate_mutable_block( - prev_block=block.prev_block) - tmp_block.append_token_ids(block.token_ids) - - block_id = tmp_block.block_id - tmp_block.block_id = None - self._block_pool.free_block(tmp_block) - - block.block_id = block_id # Assign block_id - - def get_prefix_cache_hit_rate(self) -> float: - return -1 - - def reset_prefix_cache(self) -> bool: - """No prefix cache for naive block allocator.""" - return True - - def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: - # Not applicable for naive block allocator. - return [] - - -class NaiveBlock(Block): - """An implementation of the Block class that does not support prefix - caching. - - The NaiveBlock class represents a block of token IDs with a fixed size. It - provides methods for appending token IDs to the block and manages copy-on - -write operations when necessary. - - Args: - prev_block (Block): The previous block in the sequence. - token_ids (List[int]): The initial token IDs to be stored in the block. - block_size (int): The maximum number of token IDs that can be stored in - the block. - allocator (BlockAllocator): The block allocator associated with this - block. - block_id (Optional[int], optional): The physical block index - of this block. Defaults to None, which means no allocation has been - made. - _cow_target (Optional[Block], optional): The copy-on-write target block. - If not provided, it defaults to self. - """ - - def __init__(self, - prev_block: Optional[Block], - token_ids: List[int], - block_size: int, - allocator: BlockAllocator, - block_id: Optional[int] = None, - _cow_target: Optional[Block] = None, - extra_hash: Optional[int] = None): - self._token_ids: List[int] = [] - self._block_size = block_size - self._prev_block = prev_block - self._block_id = block_id - self._allocator = allocator - self._cow_target = _cow_target if _cow_target is not None else self - - self._append_token_ids_no_cow(token_ids) - - def append_token_ids(self, token_ids: List[int]) -> None: - """Appends the given token IDs to the block and performs a - copy-on-write if necessary. - - Args: - token_ids (Optional[List[int]]): The token IDs to be appended - to the block. - """ - self._append_token_ids_no_cow(token_ids) - - if self._block_id is not None: - self._block_id = (self._allocator.cow_block_if_not_appendable( - self._cow_target)) - - def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: - """Appends the given token IDs to the block - - Args: - token_ids (List[int]): The token IDs to be appended to the block. 
- """ - if len(token_ids) == 0: - return - - assert len(token_ids) <= self.num_empty_slots - - self._token_ids.extend(token_ids) - - @property - def computed(self) -> bool: - raise NotImplementedError - - @computed.setter - def computed(self, value) -> None: - raise NotImplementedError - - @property - def last_accessed(self) -> float: - raise NotImplementedError - - @last_accessed.setter - def last_accessed(self, last_accessed_ts: float): - raise NotImplementedError - - @property - def block_id(self) -> Optional[int]: - return self._block_id - - @block_id.setter - def block_id(self, value: Optional[int]) -> None: - self._block_id = value - - @property - def is_full(self) -> bool: - return self.num_empty_slots == 0 - - @property - def num_empty_slots(self) -> int: - return self._block_size - len(self.token_ids) - - @property - def token_ids(self) -> List[int]: - return self._token_ids - - @property - def num_tokens_total(self) -> int: - raise NotImplementedError( - "num_tokens_total is not used for naive block") - - @property - def block_size(self) -> int: - return self._block_size - - @property - def prev_block(self) -> Optional["Block"]: - return self._prev_block - - @property - def extra_hash(self): - return None - - @property - def content_hash(self) -> Optional[int]: - return None diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py deleted file mode 100644 index a21d69323abbc..0000000000000 --- a/vllm/core/block/prefix_caching_block.py +++ /dev/null @@ -1,1135 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Token blocks.""" -import sys -from bisect import bisect_left -from os.path import commonprefix -from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set, - Tuple) - -from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, - get_all_blocks_recursively) -from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device, - DeviceAwareBlockAllocator) -from vllm.core.block.naive_block import (BlockPool, NaiveBlock, - NaiveBlockAllocator) -from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor -from vllm.logger import init_logger -from vllm.sequence import Sequence - -PrefixHash = int - -# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME -# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME, -# then we know this block hasn't been accessed yet. -_DEFAULT_LAST_ACCESSED_TIME = -1 - -logger = init_logger(__name__) - - -class BlockTracker: - """Used to track the status of a block inside the prefix caching allocator - """ - __slots__ = ("active", "last_accessed", "computed") - - def reset(self): - self.last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME - self.computed: bool = False - - def __init__(self): - self.active: bool = False - self.reset() - - def enable(self): - assert not self.active - self.active = True - self.reset() - - def disable(self): - assert self.active - self.active = False - self.reset() - - -class PrefixCachingBlockAllocator(BlockAllocator): - """A block allocator that implements prefix caching. - - The PrefixCachingBlockAllocator maintains a cache of blocks based on their - content hash. It reuses blocks with the same content hash to avoid redundant - memory allocation. The allocator also supports copy-on-write operations. - - Args: - num_blocks (int): The total number of blocks to manage. - block_size (int): The size of each block in tokens. 
-        block_ids (Optional[Iterable[int]], optional): An optional iterable of
-            block IDs. If not provided, block IDs will be assigned sequentially
-            from 0 to num_blocks - 1.
-    """
-
-    # Note that we use 'None' as a string here instead of None because
-    # as of Python 3.12, hash(None) returns a constant predictable value.
-    # This could possibly make it easier to find and exploit hash
-    # collisions. 'None' as a string will be hashed differently per process,
-    # but consistently within the same process. This is the same as the
-    # behavior of None prior to Python 3.12.
-    _none_hash: int = hash('None')
-
-    # Implements Block.Factory.
-    def __init__(
-        self,
-        num_blocks: int,
-        block_size: int,
-        block_ids: Optional[Iterable[int]] = None,
-        eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
-    ):
-        if block_ids is None:
-            block_ids = range(num_blocks)
-
-        self._block_size = block_size
-
-        # A mapping of prefix hash to block index. All blocks which have a
-        # prefix hash will be in this dict, even if they have refcount 0.
-        self._cached_blocks: Dict[PrefixHash, BlockId] = {}
-
-        # A list of immutable block IDs that have been touched by scheduler
-        # and should be marked as computed after an entire batch of sequences
-        # are scheduled.
-        self._touched_blocks: Set[BlockId] = set()
-
-        # Used to track status of each physical block id
-        self._block_tracker: Dict[BlockId, BlockTracker] = {}
-        for block_id in block_ids:
-            self._block_tracker[block_id] = BlockTracker()
-
-        # Pre-allocate "num_blocks * extra_factor" block objects.
-        # The "* extra_factor" is a buffer to allow more block objects
-        # than physical blocks
-        extra_factor = 4
-        self._block_pool = BlockPool(self._block_size, self._create_block,
-                                     self, num_blocks * extra_factor)
-
-        # An allocator for blocks that do not have prefix hashes.
-        self._hashless_allocator = NaiveBlockAllocator(
-            create_block=self._create_block,  # type: ignore
-            num_blocks=num_blocks,
-            block_size=block_size,
-            block_ids=block_ids,
-            block_pool=self._block_pool,  # Share block pool here
-        )
-
-        # Evictor used to maintain how we want to handle those computed blocks
-        # if we find memory pressure is high.
-        self.eviction_policy = eviction_policy
-        self.evictor: Evictor = make_evictor(self.eviction_policy)
-
-        # We share the refcounter between allocators. This allows us to promote
-        # blocks originally allocated in the hashless allocator to immutable
-        # blocks.
-        self._refcounter = self._hashless_allocator.refcounter
-
-        self._cow_tracker = CopyOnWriteTracker(
-            refcounter=self._refcounter.as_readonly())
-
-        self.metric_data = CacheMetricData()
-
-    def _create_block(
-        self,
-        prev_block: Optional[Block],
-        token_ids: List[int],
-        block_size: int,
-        allocator: BlockAllocator,
-        block_id: Optional[int] = None,
-        computed: bool = False,
-        extra_hash: Optional[int] = None,
-    ) -> Block:
-        # Bind block to self.
-        allocator = self
-
-        return PrefixCachingBlock(
-            prev_block=prev_block,
-            token_ids=token_ids,
-            block_size=block_size,
-            block_id=block_id,
-            allocator=allocator,
-            computed=computed,
-            extra_hash=extra_hash,
-        )
-
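allocate_immutable_block below boils down to: hash the candidate block, reuse a cached physical id on a hit, otherwise fall back to a fresh mutable block. A dictionary-based sketch of that lookup flow (toy types; the real method additionally manages refcounts and the block pool):

from typing import Dict, List, Tuple


def lookup_or_allocate(cached_blocks: Dict[int, int],
                       free_ids: List[int],
                       content_hash: int) -> Tuple[int, bool]:
    """Return (block_id, hit). On a miss, consume a free id and cache it."""
    cached_id = cached_blocks.get(content_hash)
    if cached_id is not None:
        return cached_id, True  # prefix cache hit: reuse the physical block
    block_id = free_ids.pop()   # miss: take a fresh physical block
    cached_blocks[content_hash] = block_id
    return block_id, False


cache: Dict[int, int] = {}
free = [0, 1, 2]
h = hash((None, (1, 2, 3, 4)))
assert lookup_or_allocate(cache, free, h) == (2, False)
assert lookup_or_allocate(cache, free, h) == (2, True)  # second call hits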
-    def allocate_immutable_block(self,
-                                 prev_block: Optional[Block],
-                                 token_ids: List[int],
-                                 extra_hash: Optional[int] = None,
-                                 device: Optional[Device] = None) -> Block:
-        """Allocates an immutable block with the given token IDs, reusing cached
-        blocks if possible.
-
-        Args:
-            prev_block (Optional[Block]): The previous block in the sequence.
-            token_ids (List[int]): The token IDs to be stored in the block.
-
-        Returns:
-            Block: The allocated immutable block.
-        """
-        assert device is None
-        assert_prefix_caching_block_or_none(prev_block)
-
-        # First, try to create a block that points to cached data
-        block = self._block_pool.init_block(prev_block=prev_block,
-                                            token_ids=token_ids,
-                                            block_size=self._block_size,
-                                            physical_block_id=None,
-                                            extra_hash=extra_hash)
-        assert block.content_hash is not None
-
-        cached_block_id = self._cached_blocks.get(block.content_hash, None)
-        if cached_block_id is not None:
-            self.metric_data.query(hit=True)
-            block.block_id = cached_block_id
-            self._incr_refcount_cached_block(block)
-            return block
-        self.metric_data.query(hit=False)
-        self._block_pool.free_block(block)
-
-        # No cached block => Allocate a new block
-        block = self.allocate_mutable_block(prev_block, extra_hash=extra_hash)
-        block.append_token_ids(token_ids)
-        return block
-
-    def allocate_immutable_blocks(
-            self,
-            prev_block: Optional[Block],
-            block_token_ids: List[List[int]],
-            extra_hash: Optional[int] = None,
-            device: Optional[Device] = None) -> List[Block]:
-        blocks = []
-        for token_ids in block_token_ids:
-            prev_block = self.allocate_immutable_block(prev_block=prev_block,
-                                                       token_ids=token_ids,
-                                                       device=device,
-                                                       extra_hash=extra_hash)
-            blocks.append(prev_block)
-        return blocks
-
-    def allocate_mutable_block(self,
-                               prev_block: Optional[Block],
-                               extra_hash: Optional[int] = None,
-                               device: Optional[Device] = None) -> Block:
-        """Allocates a mutable block. If there are no free blocks, this will
-        evict unused cached blocks.
-
-        Args:
-            prev_block (Block): The previous block in the sequence.
-                Unlike in the superclass, None is not allowed here.
-
-        Returns:
-            Block: The allocated mutable block.
-        """
-        assert device is None
-        assert_prefix_caching_block_or_none(prev_block)
-
-        block_id = self._allocate_block_id()
-        block = self._block_pool.init_block(prev_block=prev_block,
-                                            token_ids=[],
-                                            block_size=self._block_size,
-                                            physical_block_id=block_id,
-                                            extra_hash=extra_hash)
-        assert not block.computed
-        assert block.content_hash is None
-        return block
-
-    def _incr_refcount_cached_block(self, block: Block) -> None:
-        # Set this block to be "computed" since it is pointing to a
-        # cached block id (which was already computed)
-        block.computed = True
-
-        block_id = block.block_id
-        assert block_id is not None
-
-        refcount = self._refcounter.incr(block_id)
-        if refcount == 1:
-            # In case a cached block was evicted, restore its tracking
-            if block_id in self.evictor:
-                self.evictor.remove(block_id)
-
-            self._track_block_id(block_id, computed=True)
-
-    def _decr_refcount_cached_block(self, block: Block) -> None:
-        # Ensure this is immutable/cached block
-        assert block.content_hash is not None
-
-        block_id = block.block_id
-        assert block_id is not None
-
-        refcount = self._refcounter.decr(block_id)
-        if refcount > 0:
-            block.block_id = None
-            return
-        else:
-            assert refcount == 0
-
-        # No longer used
-        assert block.content_hash in self._cached_blocks
-
-        # Add the cached block to the evictor
-        # (This keeps the cached block around so it can be reused)
-        self.evictor.add(block_id, block.content_hash, block.num_tokens_total,
-                         self._block_tracker[block_id].last_accessed)
-
-        # Stop tracking the block
-        self._untrack_block_id(block_id)
-
-        block.block_id = None
-
-    def _decr_refcount_hashless_block(self, block: Block) -> None:
-        block_id = block.block_id
-        assert block_id is not None
-
-        # We may have a fork case where block is shared,
-        # in which case, we cannot remove it from tracking
-        refcount =
self._refcounter.get(block_id) - if refcount == 1: - self._untrack_block_id(block_id) - - # Decrement refcount of the block_id, but do not free the block object - # itself (will be handled by the caller) - self._hashless_allocator.free(block, keep_block_object=True) - - def _allocate_block_id(self) -> BlockId: - """First tries to allocate a block id from the hashless allocator, - and if there are no blocks, then tries to evict an unused cached block. - """ - hashless_block_id = self._maybe_allocate_hashless_block_id() - if hashless_block_id is not None: - return hashless_block_id - - evicted_block_id = self._maybe_allocate_evicted_block_id() - if evicted_block_id is not None: - return evicted_block_id - - # No block available in hashless allocator, nor in unused cache blocks. - raise BlockAllocator.NoFreeBlocksError() - - def _maybe_allocate_hashless_block_id(self) -> Optional[BlockId]: - try: - # Allocate mutable block and extract its block_id - block = self._hashless_allocator.allocate_mutable_block( - prev_block=None) - block_id = block.block_id - self._block_pool.free_block(block) - - self._track_block_id(block_id, computed=False) - return block_id - except BlockAllocator.NoFreeBlocksError: - return None - - def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]: - if self.evictor.num_blocks == 0: - return None - - # Here we get an evicted block, which is only added - # into evictor if its ref counter is 0 - # and since its content would be changed, we need - # to remove it from _cached_blocks's tracking list - block_id, content_hash_to_evict = self.evictor.evict() - - # Sanity checks - assert content_hash_to_evict in self._cached_blocks - _block_id = self._cached_blocks[content_hash_to_evict] - assert self._refcounter.get(_block_id) == 0 - assert _block_id == block_id - - self._cached_blocks.pop(content_hash_to_evict) - - self._refcounter.incr(block_id) - self._track_block_id(block_id, computed=False) - - return block_id - - def _free_block_id(self, block: Block) -> None: - """Decrements the refcount of the block. The block may be in two - possible states: (1) immutable/cached or (2) mutable/hashless. - In the first case, the refcount is decremented directly and the block - may be possibly added to the evictor. In other case, hashless - allocator free(..) with keep_block_object=True is called to only free - the block id (since the block object may be reused by the caller) - """ - block_id = block.block_id - assert block_id is not None, "Freeing unallocated block is undefined" - - if block.content_hash is not None: - # Immutable: This type of block is always cached, and we want to - # keep it in the evictor for future reuse - self._decr_refcount_cached_block(block) - else: - # Mutable: This type of block is not cached, so we release it - # directly to the hashless allocator - self._decr_refcount_hashless_block(block) - - assert block.block_id is None - - def free(self, block: Block, keep_block_object: bool = False) -> None: - """Release the block (look at free_block_id(..) docs) - """ - # Release the physical block index - self._free_block_id(block) - - # Release the block object to the pool - if not keep_block_object: - self._block_pool.free_block(block) - - def fork(self, last_block: Block) -> List[Block]: - """Creates a new sequence of blocks that shares the same underlying - memory as the original sequence. - - Args: - last_block (Block): The last block in the original sequence. 
-
-        Returns:
-            List[Block]: The new sequence of blocks that shares the same memory
-                as the original sequence.
-        """
-        source_blocks = get_all_blocks_recursively(last_block)
-
-        forked_blocks: List[Block] = []
-        prev_block = None
-        for block in source_blocks:
-            block_id = block.block_id
-            assert block_id is not None
-
-            refcount = self._refcounter.incr(block_id)
-            assert refcount != 1, "can't fork free'd block_id = {}".format(
-                block_id)
-
-            forked_block = self._block_pool.init_block(
-                prev_block=prev_block,
-                token_ids=block.token_ids,
-                block_size=self._block_size,
-                physical_block_id=block_id,
-                extra_hash=block.extra_hash)
-
-            forked_blocks.append(forked_block)
-            prev_block = forked_blocks[-1]
-
-        return forked_blocks
-
-    def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
-        assert device is None
-        # The number of free blocks is the number of hashless free blocks
-        # plus the number of blocks evictor could free from its list.
-        return self._hashless_allocator.get_num_free_blocks(
-        ) + self.evictor.num_blocks
-
-    def get_num_total_blocks(self) -> int:
-        return self._hashless_allocator.get_num_total_blocks()
-
-    def get_physical_block_id(self, absolute_id: int) -> int:
-        """Returns the zero-offset block id on a certain block allocator
-        given the absolute block id.
-
-        Args:
-            absolute_id (int): The absolute block id for the block
-                in the whole allocator.
-
-        Returns:
-            int: The zero-offset block id on a certain device.
-        """
-        return sorted(self.all_block_ids).index(absolute_id)
-
-    @property
-    def all_block_ids(self) -> FrozenSet[int]:
-        return self._hashless_allocator.all_block_ids
-
-    def get_prefix_cache_hit_rate(self) -> float:
-        return self.metric_data.get_hit_rate()
-
-    def reset_prefix_cache(self) -> bool:
-        """Reset prefix cache. This function may be used in RLHF
-        flows to invalidate prefix caching after the weights are updated,
-        or used for resetting prefix caching status for benchmarking.
-
-        Returns:
-            bool: True if the prefix cache is successfully reset,
-            False otherwise.
-        """
-        num_used_blocks = (self.get_num_total_blocks() -
-                           self.get_num_free_blocks())
-        if num_used_blocks > 0:
-            logger.warning(
-                "Failed to reset prefix cache because some "
-                "blocks (%d) are not freed yet", num_used_blocks)
-            return False
-
-        # Free all blocks in the evictor.
-        while (block_id :=
-               self._maybe_allocate_evicted_block_id()) is not None:
-            self._hashless_allocator.free_block_id(block_id)
-
-        # Should not have any cached blocks because all blocks are evicted.
-        assert not self._cached_blocks
-
-        # Reset the evictor.
-        self.evictor = make_evictor(self.eviction_policy)
-
-        # Reset the block tracker.
-        for block_id in self._block_tracker:
-            self._block_tracker[block_id] = BlockTracker()
-
-        # Reset the metrics.
-        self.metric_data = CacheMetricData()
-
-        logger.info("Successfully reset prefix cache")
-        return True
-
-    def is_block_cached(self, block: Block) -> bool:
-        assert block.content_hash is not None
-        return block.content_hash in self._cached_blocks
-
-    def promote_to_immutable_block(self, block: Block) -> BlockId:
-        """Once a mutable block is full, it can be promoted to an immutable
-        block. This means that its content can be referenced by future blocks
-        having the same prefix.
-
-        Note that if we already have a cached block with the same content, we
-        will replace the newly-promoted block's mapping with the existing cached
-        block id.
-
-        Args:
-            block: The mutable block to be promoted.
-
-        Returns:
-            BlockId: Either the original block index, or the block index of
-                the previously cached block matching the same content.
-        """
-        # Ensure block can be promoted
-        assert block.content_hash is not None
-        assert block.block_id is not None
-        assert self._refcounter.get(block.block_id) > 0
-
-        if block.content_hash not in self._cached_blocks:
-            # No cached content hash => Set this block as cached.
-            # Note that this block cannot be marked as computed yet
-            # because other sequences in the same batch cannot reuse
-            # this block.
-            self._cached_blocks[block.content_hash] = block.block_id
-            # Mark this block as touched so that it can be marked as
-            # computed after the entire batch of sequences are scheduled.
-            self._touched_blocks.add(block.block_id)
-            return block.block_id
-
-        # Reuse the cached content hash
-        self._decr_refcount_hashless_block(block)
-        block.block_id = self._cached_blocks[block.content_hash]
-
-        # Increment refcount of the cached block and (possibly) restore
-        # it from the evictor.
-        # Note that in this case, the block is marked as computed
-        self._incr_refcount_cached_block(block)
-
-        return block.block_id
-
-    def cow_block_if_not_appendable(self, block: Block) -> BlockId:
-        """Performs a copy-on-write operation on the given block if it is not
-        appendable.
-
-        Args:
-            block (Block): The block to check for copy-on-write.
-
-        Returns:
-            BlockId: The block index of the new block if a copy-on-write
-                operation was performed, or the original block index if
-                no copy-on-write was necessary.
-        """
-        src_block_id = block.block_id
-        assert src_block_id is not None
-
-        if self._cow_tracker.is_appendable(block):
-            return src_block_id
-
-        self._free_block_id(block)
-        trg_block_id = self._allocate_block_id()
-
-        self._cow_tracker.record_cow(src_block_id, trg_block_id)
-
-        return trg_block_id
-
-    def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]:
-        """Returns the copy-on-write source->destination mapping and clears it.
-
-        Returns:
-            List[Tuple[BlockId, BlockId]]: A list mapping source
-                block indices to destination block indices.
-        """
-        return self._cow_tracker.clear_cows()
-
-    def mark_blocks_as_accessed(self, block_ids: List[int],
-                                now: float) -> None:
-        """Mark blocks as accessed, used in prefix caching.
-
-        If the block has been added to the evictor, we need to update the
-        corresponding info in the evictor's metadata.
-        """
-
-        for block_id in block_ids:
-            if self._block_tracker[block_id].active:
-                self._block_tracker[block_id].last_accessed = now
-            elif block_id in self.evictor:
-                self.evictor.update(block_id, now)
-            else:
-                raise ValueError(
-                    "Marking a block as accessed that does not belong to the "
-                    "GPU allocator")
-
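mark_blocks_as_computed below drains the _touched_blocks set: a block is only recorded as "touched" when it is promoted, and flipped to computed once the whole batch has been scheduled. A reduced sketch of that two-phase pattern (toy tracker, illustrative only):

from typing import Dict, Set


class ToyComputedTracker:
    def __init__(self) -> None:
        self.computed: Dict[int, bool] = {}
        self.touched: Set[int] = set()

    def touch(self, block_id: int) -> None:
        # Phase 1: a promoted block is only remembered as "touched" so other
        # sequences in the same batch cannot reuse it prematurely.
        self.touched.add(block_id)
        self.computed.setdefault(block_id, False)

    def mark_touched_as_computed(self) -> None:
        # Phase 2: after the whole batch is scheduled, flip them at once.
        for block_id in self.touched:
            self.computed[block_id] = True
        self.touched.clear()


tracker = ToyComputedTracker()
tracker.touch(5)
assert not tracker.computed[5]
tracker.mark_touched_as_computed()
assert tracker.computed[5] and not tracker.touched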
-    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
-        # Mark all touched blocks as computed.
-        for block_id in self._touched_blocks:
-            self._block_tracker[block_id].computed = True
-        self._touched_blocks.clear()
-
-    def _track_block_id(self, block_id: Optional[BlockId],
-                        computed: bool) -> None:
-        assert block_id is not None
-        self._block_tracker[block_id].enable()
-        self._block_tracker[block_id].computed = computed
-
-    def _untrack_block_id(self, block_id: Optional[BlockId]) -> None:
-        assert block_id is not None
-        self._block_tracker[block_id].disable()
-
-    def block_is_computed(self, block_id: int) -> bool:
-        if self._block_tracker[block_id].active:
-            return self._block_tracker[block_id].computed
-        else:
-            return block_id in self.evictor
-
-    def get_common_computed_block_ids(
-            self, computed_seq_block_ids: List[List[int]]) -> List[int]:
-        """Return the block ids that are common for a given sequence group.
-
-        Only those blocks that are immutable and already marked as computed
-        are taken into consideration.
-        """
-
-        # NOTE We exclude the last block to avoid the case where the entire
-        # prompt is cached. This would cause erroneous behavior in model
-        # runner.
-
-        # It returns a list of ints although the type annotation says list of
-        # strings.
-        if len(computed_seq_block_ids) == 1:
-            return computed_seq_block_ids[0]
-
-        return commonprefix([
-            ids for ids in computed_seq_block_ids  # type: ignore
-            if ids
-        ])
-
-    def get_num_full_blocks_touched(self, blocks: List[Block]) -> int:
-        """Returns the number of full blocks that will be touched by
-        swapping in/out.
-
-        Args:
-            blocks: List of blocks to be swapped.
-        Returns:
-            int: the number of full blocks that will be touched by
-                swapping in/out the given blocks. Non-full blocks are ignored
-                when deciding the number of blocks to touch.
-        """
-        num_touched_blocks: int = 0
-        for block in blocks:
-            # If the block has a match in the cache and the cached
-            # block is not referenced, then we still count it as a
-            # touched block
-            if block.is_full and (not self.is_block_cached(block) or \
-                (block.content_hash is not None and \
-                    self._cached_blocks[block.content_hash] in \
-                        self.evictor)):
-                num_touched_blocks += 1
-        return num_touched_blocks
-
-    def swap_out(self, blocks: List[Block]) -> None:
-        """Execute the swap out actions. Basically just free the
-        given blocks.
-
-        Args:
-            blocks: List of blocks to be swapped out.
-        """
-        for block in blocks:
-            self._free_block_id(block)
-
-    def swap_in(self, blocks: List[Block]) -> None:
-        """Execute the swap in actions. Change the block id from the
-        old allocator to the current allocator for each block to finish
-        the block table update.
-
-        Args:
-            blocks: List of blocks to be swapped in.
-        """
-        for block in blocks:
-            # Here we allocate either an immutable or mutable block and then
-            # extract its block_id. Note that the block object is released
-            # and the block_id is assigned to "block" to allow reusing the
-            # existing "block" object
-            if block.is_full:
-                tmp_block = self.allocate_immutable_block(
-                    prev_block=block.prev_block,
-                    token_ids=block.token_ids,
-                    extra_hash=block.extra_hash)
-            else:
-                tmp_block = self.allocate_mutable_block(
-                    prev_block=block.prev_block, extra_hash=block.extra_hash)
-                tmp_block.append_token_ids(block.token_ids)
-
-            block_id = tmp_block.block_id
-            self._block_pool.free_block(tmp_block)
-
-            block.block_id = block_id  # Assign block_id
-
-    def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
-        """
-        Given a list of block hashes, return the prefix of the block hashes that
-        are all cached.
-
-        Since a block's block hash includes the hashes of all previous blocks,
-        and we only allocate/deallocate blocks for the entire sequence, if a
-        block is cached then all previous blocks are also cached. With this
-        property, we can use binary search to find the prefix of cached blocks.
-
-        Args:
-            block_hashes (List[int]): The list of block hashes.
-
-        Returns:
-            List[int]: The prefix of the `block_hashes` that are cached.
-        """
-
-        def _block_is_cached(block_hash: PrefixHash) -> bool:
-            if block_hash not in self._cached_blocks:
-                return False
-
-            cached_block_id = self._cached_blocks[block_hash]
-            # We only consider the blocks that are marked as computed.
-            return self.block_is_computed(cached_block_id)
-
-        def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int:
-
-            # Python < 3.10 doesn't have the key argument
-            if sys.version_info < (3, 10):
-                a = [key(e) for e in a]
-                return bisect_left(a, x)
-            else:
-                return bisect_left(a, x, key=key)
-
-        # Look for the first block that's not cached, and return the prefix,
-        # i.e. the blocks that are cached.
-        idx = _bisect_left(block_hashes,
-                           True,
-                           key=lambda x: not _block_is_cached(x))
-        return block_hashes[:idx]
-
-
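find_cached_blocks_prefix above relies on a monotone property: if a block is cached, every earlier block is too, so the per-block "not cached" predicate looks like [False, ..., False, True, ..., True] and the boundary can be found with bisect. A standalone sketch of that search over a toy cache set (illustrative only):

from bisect import bisect_left
from typing import List, Set


def cached_prefix(block_hashes: List[int], cached: Set[int]) -> List[int]:
    # Map each hash to "not cached" (False < True), so the first True marks
    # the first non-cached block; everything before it is the cached prefix.
    idx = bisect_left(block_hashes, True,
                      key=lambda h: h not in cached)  # key requires >= 3.10
    return block_hashes[:idx]


hashes = [101, 102, 103, 104]
assert cached_prefix(hashes, cached={101, 102}) == [101, 102]
assert cached_prefix(hashes, cached=set()) == []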
-class PrefixCachingBlock(Block):
-    """A block implementation that supports prefix caching.
-
-    The PrefixCachingBlock class represents a block of token IDs with prefix
-    caching capabilities. It wraps a NaiveBlock internally and provides
-    additional functionality for content hashing and promoting immutable blocks
-    with the prefix caching allocator.
-
-    Args:
-        prev_block (Optional[PrefixCachingBlock]): The previous block in the
-            sequence.
-        token_ids (List[int]): The initial token IDs to be stored in the block.
-        block_size (int): The maximum number of token IDs that can be stored in
-            the block.
-        allocator (BlockAllocator): The prefix
-            caching block allocator associated with this block.
-        block_id (Optional[int], optional): The physical block index
-            of this block. Defaults to None.
-        extra_hash (Optional[int]): The hash value of additional factors
-            such as adapters that influence the block, apart from the token_ids.
-    """
-
-    # Note that we use 'None' as a string here instead of None because
-    # as of Python 3.12, hash(None) returns a constant predictable value.
-    # This could possibly make it easier to find and exploit hash
-    # collisions. 'None' as a string will be hashed differently per process,
-    # but consistently within the same process. This is the same as the
-    # behavior of None prior to Python 3.12.
-    _none_hash: int = hash('None')
-
-    def __init__(
-        self,
-        prev_block: Optional[Block],
-        token_ids: List[int],
-        block_size: int,
-        allocator: BlockAllocator,
-        block_id: Optional[int] = None,
-        computed: bool = False,
-        extra_hash: Optional[int] = None,
-    ):
-        assert isinstance(allocator, PrefixCachingBlockAllocator), (
-            "Currently this class is only tested with "
-            "PrefixCachingBlockAllocator. Got instead allocator = {}".format(
-                allocator))
-        assert_prefix_caching_block_or_none(prev_block)
-
-        self._prev_block = prev_block
-        self._cached_content_hash: Optional[int] = None
-        self._cached_num_tokens_total: int = 0
-        self._allocator = allocator
-        self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
-        self._computed = computed
-        self._extra_hash = extra_hash
-
-        # On the first time, we create the block object, and next we only
-        # reinitialize it
-        if hasattr(self, "_block"):
-            self._block.__init__(  # type: ignore[has-type]
-                prev_block=prev_block,
-                token_ids=token_ids,
-                block_size=block_size,
-                block_id=block_id,
-                allocator=self._allocator)
-        else:
-            self._block = NaiveBlock(prev_block=prev_block,
-                                     token_ids=token_ids,
-                                     block_size=block_size,
-                                     block_id=block_id,
-                                     allocator=self._allocator)
-
-        self._update_num_tokens_total()
-
-    def _update_num_tokens_total(self):
-        """Incrementally computes the number of tokens up to and including
-        the current block.
-        """
-        res = 0
-
-        # Add all previous blocks
-        if self._prev_block is not None:
-            res += self._prev_block.num_tokens_total
-
-        # Add current block
-        res += len(self.token_ids)
-
-        self._cached_num_tokens_total = res
-
-    @property
-    def computed(self) -> bool:
-        return self._computed
-
-    @computed.setter
-    def computed(self, value) -> None:
-        self._computed = value
-
-    @property
-    def last_accessed(self) -> float:
-        return self._last_accessed
-
-    @last_accessed.setter
-    def last_accessed(self, last_accessed_ts: float):
-        self._last_accessed = last_accessed_ts
-
-    def append_token_ids(self, token_ids: List[int]) -> None:
-        """Appends the given token IDs to the block and registers the block as
-        immutable if the block becomes full.
-
-        Args:
-            token_ids (List[int]): The token IDs to be appended to the block.
-        """
-        # Ensure this is mutable block (not promoted)
-        assert self.content_hash is None
-        assert not self.computed
-
-        if len(token_ids) == 0:
-            return
-
-        # Ensure there are input tokens
-        assert token_ids, "Got token_ids = {}".format(token_ids)
-
-        # Naive block handles CoW.
-        self._block.append_token_ids(token_ids)
-        self._update_num_tokens_total()
-
-        # If the content hash is present, then the block can be made immutable.
-        # Register ourselves with the allocator, potentially replacing the
-        # physical block index.
-        if self.content_hash is not None:
-            self.block_id = self._allocator.promote_to_immutable_block(self)
-
-    @property
-    def block_id(self) -> Optional[int]:
-        return self._block.block_id
-
-    @block_id.setter
-    def block_id(self, value) -> None:
-        self._block.block_id = value
-
-    @property
-    def is_full(self) -> bool:
-        return self._block.is_full
-
-    @property
-    def num_empty_slots(self) -> int:
-        return self._block.num_empty_slots
-
-    @property
-    def num_tokens_total(self) -> int:
-        return self._cached_num_tokens_total
-
-    @property
-    def block_size(self) -> int:
-        return self._block.block_size
-
-    @property
-    def token_ids(self) -> List[int]:
-        return self._block.token_ids
-
-    @property
-    def prev_block(self) -> Optional[Block]:
-        return self._prev_block
-
-    @property
-    def extra_hash(self) -> Optional[int]:
-        return self._extra_hash
-
-    @property
-    def content_hash(self) -> Optional[int]:
-        """Return the content-based hash of the current block, or None if it is
-        not yet defined.
-
-        For the content-based hash to be defined, the current block must be
-        full.
-        """
-        # If the hash is already computed, return it.
- if self._cached_content_hash is not None: - return self._cached_content_hash - - # We cannot compute a hash for the current block because it is not full. - if not self.is_full: - return None - - is_first_block = self._prev_block is None - prev_block_hash = ( - self._none_hash if is_first_block else - self._prev_block.content_hash # type: ignore - ) - - # Previous block exists but does not yet have a hash. - # Return no hash in this case. - if prev_block_hash == self._none_hash and not is_first_block: - return None - - self._cached_content_hash = PrefixCachingBlock.hash_block_tokens( - is_first_block, - prev_block_hash, - cur_block_token_ids=self.token_ids, - extra_hash=self._extra_hash) - return self._cached_content_hash - - @classmethod - def hash_block_tokens(cls, - is_first_block: bool, - prev_block_hash: Optional[int], - cur_block_token_ids: List[int], - extra_hash: Optional[int] = None) -> int: - """Computes a hash value corresponding to the contents of a block and - the contents of the preceding block(s). The hash value is used for - prefix caching. - - Parameters: - - is_first_block (bool): A flag indicating if the block is the first in - the sequence. - - prev_block_hash (Optional[int]): The hash of the previous block. None - if this is the first block. - - cur_block_token_ids (List[int]): A list of token ids in the current - block. The current block is assumed to be full. - - extra_hash (Optional[int]): The hash value of additional factors - such as adapters that influence the block, apart from the token_ids. - - Returns: - - int: The computed hash value for the block. - """ - if is_first_block and prev_block_hash is None: - prev_block_hash = cls._none_hash - return hash((is_first_block, prev_block_hash, *cur_block_token_ids, - extra_hash)) - - -class ComputedBlocksTracker: - """ - Tracks the computed blocks for each sequence. - - Internally, it maintains a map from sequence id to the list of block hashes - for the sequence. We cache the hashes of the full blocks for each sequence, - and make sure the hash is calculated in the same way as the allocator. - When a sequence is being decoded, we also update the sequence's hash - accordingly and incrementally. - - From the sequence hash, with prefix caching enabled, we could also calculate - the number of cached tokens for the sequence by looking up the number of - cached block hashes in the allocator. - """ - - # Note that we use 'None' as a string here instead of None because - # as of Python 3.12, hash(None) returns a constant predictable value. - # This could possibly make it easier to find and exploit hash - # collisions. 'None' as a string will be hashed differently per process, - # but consistently within the same process. This is the same as the - # behavior of None prior to Python 3.12. - _none_hash: int = hash('None') - - def __init__( - self, - allocator: DeviceAwareBlockAllocator, - block_size: int, - enable_caching: bool, - ): - self._allocator = allocator - self._block_size = block_size - self._enable_caching = enable_caching - - # A map from seq_id to the list of block hashes for the - # sequence. This is so that we don't have to recompute the block hashes - # for the sequence when we need to check if the sequence is cached. - # Note a block that's not full will not have its hash calculated and - # recorded. - self._seq_id_to_blocks_hashes: Dict[int, List[int]] = {} - - # A map from seq_id to the number of tokens that are cached for the - # sequence. 
- # We need this so that a sequence in continuous prefill doesn't - # accidentally see its cached token count change. See comments in - # `get_num_cached_tokens` for more details. - self._seq_id_to_num_tokens_computed: Dict[int, int] = {} - - def _update_seq_hashes(self, seq: Sequence) -> None: - """Incrementally update the sequence's block hashes and record them.""" - assert self._enable_caching - - block_hashes_recorded = self._seq_id_to_blocks_hashes.get( - seq.seq_id, []) - cur_num_blocks_recorded = len(block_hashes_recorded) - token_ids = seq.get_token_ids() - assert len(token_ids) >= cur_num_blocks_recorded * self._block_size, ( - f"The sequence has {len(token_ids)} tokens, but" - f" already recorded {cur_num_blocks_recorded} blocks. " - "This should not happen since we assume blocks are " - "only appended other than recomputation. When the sequence is " - "recomputed, we should have removed the info of the old blocks.") - # Update the computed block hashes for the sequence. Since only full - # blocks are considered as "computed", we take floor here. - num_computed_blocks = len(token_ids) // self._block_size - - # We need to know the hash of the previous block to compute the hash of - # the current block so that blocks could be uniquely identified across - # sequences of prefixes. - prev_block_hash = (self._none_hash if cur_num_blocks_recorded == 0 else - block_hashes_recorded[-1]) - # Only update the computed block hashes for the new blocks - for i in range(cur_num_blocks_recorded, num_computed_blocks): - assert len(token_ids) >= (i + 1) * self._block_size - block_token_ids = token_ids[i * self._block_size:(i + 1) * - self._block_size] - - # NOTE: If there are any factors affecting the block besides - # token_ids, they should be added as input to extra_hash. - extra_hash = seq.extra_hash() - - # This has to be kept in sync with the allocator's hash - # calculation. - block_hash = PrefixCachingBlock.hash_block_tokens( - is_first_block=prev_block_hash == self._none_hash, - prev_block_hash=prev_block_hash, - cur_block_token_ids=block_token_ids, - extra_hash=extra_hash, - ) - block_hashes_recorded.append(block_hash) - prev_block_hash = block_hash - - self._seq_id_to_blocks_hashes[seq.seq_id] = block_hashes_recorded - - def get_num_cached_tokens(self, seq: Sequence) -> int: - if not self._enable_caching: - return 0 - - # We always try to update the sequence hashes on the fly. - # This is to ensure that we don't miss any cached tokens for the - # sequence during decode. - # This routine should only update hash for any new blocks too. - self._update_seq_hashes(seq) - - num_computed_tokens_prev = self._seq_id_to_num_tokens_computed.get( - seq.seq_id, None) - - # TODO(rickyx): This hack could be removed once we mark blocks as - # computed correctly with chunked prefills. - if num_computed_tokens_prev is not None and seq.is_prefill(): - # For a sequence that is still in prefill, we don't - # recompute the number of cached tokens. - # This also handles correctly chunked prefill since currently - # we mark blocks as computed even if the sequence is still partially - # prefilled. So a continuously prefilled sequence should not - # see its cached token count change while running. - return num_computed_tokens_prev - - block_hashes = self._seq_id_to_blocks_hashes[seq.seq_id] - - # This is O(logN), where N is the number of blocks. 
- num_cached_blocks = len( - self._allocator.find_cached_blocks_prefix(block_hashes)) - num_cached_tokens = num_cached_blocks * self._block_size - self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens - return num_cached_tokens - - def remove_seq(self, seq_id: int) -> None: - """Stop tracking the sequence.""" - if not self._enable_caching: - return - assert seq_id in self._seq_id_to_blocks_hashes - del self._seq_id_to_blocks_hashes[seq_id] - - assert seq_id in self._seq_id_to_num_tokens_computed - del self._seq_id_to_num_tokens_computed[seq_id] - - -class LastAccessBlocksTracker: - """Manages the last access time of the tracked sequences, in order to allow - an efficient update of allocator's block last access times - """ - - def __init__(self, allocator): - self._allocator = allocator - self._seq_last_access: Dict[int, Optional[float]] = {} - - def add_seq(self, seq_id: int) -> None: - """Start tracking seq_id - """ - assert seq_id not in self._seq_last_access - self._seq_last_access[seq_id] = None - - def remove_seq(self, seq_id: int) -> None: - """Stop tracking seq_id - """ - assert seq_id in self._seq_last_access - del self._seq_last_access[seq_id] - - def update_last_access(self, seq_id: int, time: float) -> None: - assert seq_id in self._seq_last_access - self._seq_last_access[seq_id] = time - - def update_seq_blocks_last_access(self, seq_id: int, - block_ids: List[int]) -> None: - assert seq_id in self._seq_last_access - - ts = self._seq_last_access[seq_id] - - if ts is None: - # No last access was recorded, no need to update. - return - - self._allocator.mark_blocks_as_accessed(block_ids, ts) - - -def assert_prefix_caching_block_or_none(block: Optional[Block]): - if block is None: - return - assert isinstance(block, - PrefixCachingBlock), "Got block = {}".format(block) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py deleted file mode 100644 index e933c6ee7c8bd..0000000000000 --- a/vllm/core/block/utils.py +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Block manager utils.""" -from vllm.sequence import SequenceGroup -from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - STR_NOT_IMPL_ENC_DEC_SWA) - - -def check_no_caching_or_swa_for_blockmgr_encdec( - block_mgr, seq_group: SequenceGroup) -> None: - ''' - Enforce that prefix caching & sliding-window attention (SWA) - are currently unsupported *specifically* for encoder/decoder models. - - Raises NotImplementedError if unsupported scenario is detected. 
- - Arguments: - - * block_mgr: BlockSpaceManager instance - * seq_group: SequenceGroup passed to block_mgr - ''' - - if seq_group.is_encoder_decoder(): - if block_mgr.max_block_sliding_window is not None: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) - - if block_mgr.enable_caching: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py deleted file mode 100644 index cbfa4d7ff3c4c..0000000000000 --- a/vllm/core/block_manager.py +++ /dev/null @@ -1,523 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A block manager that manages token blocks.""" -from typing import Dict, List, Optional -from typing import Sequence as GenericSequence -from typing import Tuple - -from vllm.core.block.block_table import BlockTable -from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.core.block.interfaces import Block -from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, - LastAccessBlocksTracker) -from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec -from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.sequence import Sequence, SequenceGroup, SequenceStatus -from vllm.utils import Device - -SeqId = int -EncoderSeqId = str - - -class SelfAttnBlockSpaceManager(BlockSpaceManager): - """BlockSpaceManager which manages the allocation of KV cache. - - It owns responsibility for allocation, swapping, allocating memory for - autoregressively-generated tokens, and other advanced features such as - prefix caching, forking/copy-on-write, and sliding-window memory allocation. - - This class implements the design described in - https://github.com/vllm-project/vllm/pull/3492. - - Lookahead slots - The block manager has the notion of a "lookahead slot". These are slots - in the KV cache that are allocated for a sequence. Unlike the other - allocated slots, the content of these slots is undefined -- the worker - may use the memory allocations in any way. - - In practice, a worker could use these lookahead slots to run multiple - forward passes for a single scheduler invocation. Each successive - forward pass would write KV activations to the corresponding lookahead - slot. This allows low inter-token latency use-cases, where the overhead - of continuous batching scheduling is amortized over >1 generated tokens. - - Speculative decoding uses lookahead slots to store KV activations of - proposal tokens. - - See https://github.com/vllm-project/vllm/pull/3250 for more information - on lookahead scheduling. - - Args: - block_size (int): The size of each memory block. - num_gpu_blocks (int): The number of memory blocks allocated on GPU. - num_cpu_blocks (int): The number of memory blocks allocated on CPU. - watermark (float, optional): The threshold used for memory swapping. - Defaults to 0.01. - sliding_window (Optional[int], optional): The size of the sliding - window. Defaults to None. - enable_caching (bool, optional): Flag indicating whether caching is - enabled. Defaults to False. 
- """ - - def __init__( - self, - block_size: int, - num_gpu_blocks: int, - num_cpu_blocks: int, - watermark: float = 0.01, - sliding_window: Optional[int] = None, - enable_caching: bool = False, - ) -> None: - self.block_size = block_size - self.num_total_gpu_blocks = num_gpu_blocks - self.num_total_cpu_blocks = num_cpu_blocks - - self.sliding_window = sliding_window - # max_block_sliding_window is the max number of blocks that need to be - # allocated - self.max_block_sliding_window = None - if sliding_window is not None: - # +1 here because // rounds down - num_blocks = sliding_window // block_size + 1 - # +1 here because the last block may not be full, - # and so the sequence stretches one more block at the beginning - # For example, if sliding_window is 3 and block_size is 4, - # we may need 2 blocks when the second block only holds 1 token. - self.max_block_sliding_window = num_blocks + 1 - - self.watermark = watermark - assert watermark >= 0.0 - - self.enable_caching = enable_caching - - self.watermark_blocks = int(watermark * num_gpu_blocks) - - self.block_allocator = CpuGpuBlockAllocator.create( - allocator_type="prefix_caching" if enable_caching else "naive", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - block_size=block_size, - ) - - self.block_tables: Dict[SeqId, BlockTable] = {} - self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} - - self._computed_blocks_tracker = ComputedBlocksTracker( - self.block_allocator, self.block_size, self.enable_caching) - self._last_access_blocks_tracker = LastAccessBlocksTracker( - self.block_allocator) - - def can_allocate(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> AllocStatus: - # FIXME(woosuk): Here we assume that all sequences in the group share - # the same prompt. This may not be true for preempted sequences. - - check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) - - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - num_required_blocks = BlockTable.get_num_required_blocks( - seq.get_token_ids(), - block_size=self.block_size, - num_lookahead_slots=num_lookahead_slots, - ) - - if seq_group.is_encoder_decoder(): - encoder_seq = seq_group.get_encoder_seq() - assert encoder_seq is not None - num_required_blocks += BlockTable.get_num_required_blocks( - encoder_seq.get_token_ids(), - block_size=self.block_size, - ) - - if self.max_block_sliding_window is not None: - num_required_blocks = min(num_required_blocks, - self.max_block_sliding_window) - - num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( - device=Device.GPU) - - # Use watermark to avoid frequent cache eviction. - if (self.num_total_gpu_blocks - num_required_blocks - < self.watermark_blocks): - return AllocStatus.NEVER - if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: - return AllocStatus.OK - else: - return AllocStatus.LATER - - def _allocate_sequence(self, seq: Sequence) -> BlockTable: - block_table = BlockTable( - block_size=self.block_size, - block_allocator=self.block_allocator, - max_block_sliding_window=self.max_block_sliding_window, - ) - if seq.get_token_ids(): - # NOTE: If there are any factors affecting the block besides - # token_ids, they should be added as input to extra_hash. - extra_hash = seq.extra_hash() - - # Add blocks to the block table only if the sequence is non empty. 
- block_table.allocate(token_ids=seq.get_token_ids(), - extra_hash=extra_hash) - - return block_table - - def allocate(self, seq_group: SequenceGroup) -> None: - - # Allocate self-attention block tables for decoder sequences - waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) - assert not (set(seq.seq_id for seq in waiting_seqs) - & self.block_tables.keys()), "block table already exists" - - # NOTE: Here we assume that all sequences in the group have the same - # prompt. - seq = waiting_seqs[0] - block_table: BlockTable = self._allocate_sequence(seq) - self.block_tables[seq.seq_id] = block_table - - # Track seq - self._last_access_blocks_tracker.add_seq(seq.seq_id) - - # Assign the block table for each sequence. - for seq in waiting_seqs[1:]: - self.block_tables[seq.seq_id] = block_table.fork() - - # Track seq - self._last_access_blocks_tracker.add_seq(seq.seq_id) - - # Allocate cross-attention block table for encoder sequence - # - # NOTE: Here we assume that all sequences in the group have the same - # encoder prompt. - request_id = seq_group.request_id - - assert (request_id - not in self.cross_block_tables), \ - "block table already exists" - - check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) - - if seq_group.is_encoder_decoder(): - encoder_seq = seq_group.get_encoder_seq() - assert encoder_seq is not None - block_table = self._allocate_sequence(encoder_seq) - self.cross_block_tables[request_id] = block_table - - def can_append_slots(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: - """Determine if there is enough space in the GPU KV cache to continue - generation of the specified sequence group. - - We use a worst-case heuristic: assume each touched block will require a - new allocation (either via CoW or new block). We can append slots if the - number of touched blocks is less than the number of free blocks. - - "Lookahead slots" are slots that are allocated in addition to the slots - for known tokens. The contents of the lookahead slots are not defined. - This is used by speculative decoding when speculating future tokens. - """ - - num_touched_blocks = 0 - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - block_table = self.block_tables[seq.seq_id] - - num_touched_blocks += ( - block_table.get_num_blocks_touched_by_append_slots( - token_ids=block_table.get_unseen_token_ids( - seq.get_token_ids()), - num_lookahead_slots=num_lookahead_slots, - )) - - num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( - Device.GPU) - return num_touched_blocks <= num_free_gpu_blocks - - def append_slots( - self, - seq: Sequence, - num_lookahead_slots: int, - ) -> List[Tuple[int, int]]: - - block_table = self.block_tables[seq.seq_id] - - block_table.append_token_ids( - token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), - num_lookahead_slots=num_lookahead_slots, - num_computed_slots=seq.data.get_num_computed_tokens(), - extra_hash=seq.extra_hash(), - ) - # Return any new copy-on-writes. - new_cows = self.block_allocator.clear_copy_on_writes() - return new_cows - - def free(self, seq: Sequence) -> None: - seq_id = seq.seq_id - - if seq_id not in self.block_tables: - # Already freed or haven't been scheduled yet. 
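The worst case that `can_append_slots` budgets for is that every block touched by the append needs a fresh allocation, whether via copy-on-write or a brand-new block. Counting touched blocks then reduces to counting the blocks spanned by the new slots, lookahead included; a sketch with an illustrative helper:

```python
def blocks_touched_by_append(seq_len: int, num_new_slots: int,
                             block_size: int) -> int:
    # Blocks spanned by slots [seq_len, seq_len + num_new_slots): the
    # partially filled tail block plus any newly opened blocks.
    if num_new_slots == 0:
        return 0
    first = seq_len // block_size
    last = (seq_len + num_new_slots - 1) // block_size
    return last - first + 1

# 13 tokens already sit in 16-slot blocks; one decode token plus 4
# lookahead slots spills into a second block, so the worst case is 2:
assert blocks_touched_by_append(13, 1 + 4, block_size=16) == 2
assert blocks_touched_by_append(13, 1 + 2, block_size=16) == 1
```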
- return - - # Update seq block ids with the latest access time - self._last_access_blocks_tracker.update_seq_blocks_last_access( - seq_id, self.block_tables[seq.seq_id].physical_block_ids) - - # Untrack seq - self._last_access_blocks_tracker.remove_seq(seq_id) - self._computed_blocks_tracker.remove_seq(seq_id) - - # Free table/blocks - self.block_tables[seq_id].free() - del self.block_tables[seq_id] - - def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None: - seq_id = seq.seq_id - self._computed_blocks_tracker.remove_seq(seq_id) - - def free_cross(self, seq_group: SequenceGroup) -> None: - request_id = seq_group.request_id - if request_id not in self.cross_block_tables: - # Already freed or hasn't been scheduled yet. - return - self.cross_block_tables[request_id].free() - del self.cross_block_tables[request_id] - - def get_block_table(self, seq: Sequence) -> List[int]: - block_ids = self.block_tables[seq.seq_id].physical_block_ids - return block_ids # type: ignore - - def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: - request_id = seq_group.request_id - assert request_id in self.cross_block_tables - block_ids = self.cross_block_tables[request_id].physical_block_ids - assert all(b is not None for b in block_ids) - return block_ids # type: ignore - - def access_all_blocks_in_seq(self, seq: Sequence, now: float): - if self.enable_caching: - # Record the latest access time for the sequence. The actual update - # of the block ids is deferred to the sequence free(..) call, since - # the blocks are only added to the evictor when their block ids are - # freed (which is when the most up-to-date access time is required) - # (This avoids expensive calls to mark_blocks_as_accessed(..)) - self._last_access_blocks_tracker.update_last_access( - seq.seq_id, now) - - def mark_blocks_as_computed(self, seq_group: SequenceGroup, - token_chunk_size: int): - # If prefix caching is enabled, mark immutable blocks as computed - # right after they have been scheduled (for prefill). This assumes - # the scheduler is synchronous so blocks are actually computed when - # scheduling the next batch. - self.block_allocator.mark_blocks_as_computed([]) - - def get_common_computed_block_ids( - self, seqs: List[Sequence]) -> GenericSequence[int]: - """Determine the blocks for which we can skip prefill. - - With prefix caching we can skip prefill for previously-generated blocks. - Currently, the attention implementation only supports skipping cached - blocks if they form a contiguous prefix. - - This method determines which blocks can be safely skipped for all - sequences in the sequence group. - """ - computed_seq_block_ids = [] - for seq in seqs: - all_blocks = self.block_tables[seq.seq_id].physical_block_ids - num_cached_tokens = ( - self._computed_blocks_tracker.get_num_cached_tokens(seq)) - assert num_cached_tokens % self.block_size == 0 - num_cached_blocks = num_cached_tokens // self.block_size - computed_block_ids = all_blocks[:num_cached_blocks] - computed_seq_block_ids.append(computed_block_ids) - - # NOTE(sang): This assumes seq_block_ids doesn't contain any None. - return self.block_allocator.get_common_computed_block_ids( - computed_seq_block_ids) # type: ignore - - def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: - if parent_seq.seq_id not in self.block_tables: - # Parent sequence has either been freed or never existed.
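`get_common_computed_block_ids` above reduces to a longest-common-prefix computation: each sequence can only skip a contiguous prefix of cached blocks, and the group as a whole can only skip what every sequence has cached. A minimal sketch of that reduction (the allocator performs the real version):

```python
from typing import List

def common_computed_prefix(per_seq_block_ids: List[List[int]]) -> List[int]:
    common: List[int] = []
    for ids in zip(*per_seq_block_ids):  # stops at the shortest sequence
        if len(set(ids)) != 1:           # ids diverge: prefix ends here
            break
        common.append(ids[0])
    return common

# Two forks share blocks 7 and 8, then diverge, so only [7, 8] is skippable:
assert common_computed_prefix([[7, 8, 9], [7, 8, 4]]) == [7, 8]
```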
- return - src_block_table = self.block_tables[parent_seq.seq_id] - self.block_tables[child_seq.seq_id] = src_block_table.fork() - - # Track child seq - self._last_access_blocks_tracker.add_seq(child_seq.seq_id) - - def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> AllocStatus: - """Returns the AllocStatus for the given sequence_group - with num_lookahead_slots. - - Args: - seq_group (SequenceGroup): The sequence group to swap in. - num_lookahead_slots (int): Number of lookahead slots used in - speculative decoding, default to 0. - - Returns: - AllocStatus: The AllocStatus for the given sequence group. - """ - return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED, - num_lookahead_slots) - - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - """Returns the block id mapping (from CPU to GPU) generated by - swapping in the given seq_group with num_lookahead_slots. - - Args: - seq_group (SequenceGroup): The sequence group to swap in. - - Returns: - List[Tuple[int, int]]: The mapping of swapping block from CPU - to GPU. - """ - physical_block_id_mapping = [] - for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - blocks = self.block_tables[seq.seq_id].blocks - if len(blocks) == 0: - continue - - seq_swap_mapping = self.block_allocator.swap(blocks=blocks, - src_device=Device.CPU, - dst_device=Device.GPU) - - # Refresh the block ids of the table (post-swap) - self.block_tables[seq.seq_id].update(blocks) - - seq_physical_block_id_mapping = { - self.block_allocator.get_physical_block_id( - Device.CPU, cpu_block_id): - self.block_allocator.get_physical_block_id( - Device.GPU, gpu_block_id) - for cpu_block_id, gpu_block_id in seq_swap_mapping.items() - } - - physical_block_id_mapping.extend( - list(seq_physical_block_id_mapping.items())) - - return physical_block_id_mapping - - def can_swap_out(self, seq_group: SequenceGroup) -> bool: - """Returns whether we can swap out the given sequence_group - with num_lookahead_slots. - - Args: - seq_group (SequenceGroup): The sequence group to swap out. - - Returns: - bool: Whether it's possible to swap out current sequence group. - """ - alloc_status = self._can_swap(seq_group, Device.CPU, - SequenceStatus.RUNNING) - return alloc_status == AllocStatus.OK - - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - """Returns the block id mapping (from GPU to CPU) generated by - swapping out the given sequence_group with num_lookahead_slots. - - Args: - seq_group (SequenceGroup): The sequence group to swap out. - - Returns: - List[Tuple[int, int]]: The mapping of swapping block from - GPU to CPU. 
- """ - physical_block_id_mapping = [] - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - blocks = self.block_tables[seq.seq_id].blocks - if len(blocks) == 0: - continue - - seq_swap_mapping = self.block_allocator.swap(blocks=blocks, - src_device=Device.GPU, - dst_device=Device.CPU) - - # Refresh the block ids of the table (post-swap) - self.block_tables[seq.seq_id].update(blocks) - - seq_physical_block_id_mapping = { - self.block_allocator.get_physical_block_id( - Device.GPU, gpu_block_id): - self.block_allocator.get_physical_block_id( - Device.CPU, cpu_block_id) - for gpu_block_id, cpu_block_id in seq_swap_mapping.items() - } - - physical_block_id_mapping.extend( - list(seq_physical_block_id_mapping.items())) - - return physical_block_id_mapping - - def get_num_free_gpu_blocks(self) -> int: - return self.block_allocator.get_num_free_blocks(Device.GPU) - - def get_num_free_cpu_blocks(self) -> int: - return self.block_allocator.get_num_free_blocks(Device.CPU) - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - return self.block_allocator.get_prefix_cache_hit_rate(device) - - def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: - return self.block_allocator.reset_prefix_cache(device) - - def _can_swap(self, - seq_group: SequenceGroup, - device: Device, - status: SequenceStatus, - num_lookahead_slots: int = 0) -> AllocStatus: - """Returns the AllocStatus for swapping in/out the given sequence_group - on to the 'device'. - - Args: - seq_group (SequenceGroup): The sequence group to swap in/out. - device (Device): device to swap the 'seq_group' on. - status (SequenceStatus): The status of sequence which is needed - for action. RUNNING for swap out and SWAPPED for swap in - num_lookahead_slots (int): Number of lookahead slots used in - speculative decoding, default to 0. - - Returns: - AllocStatus: The AllocStatus for swapping in/out the given - sequence_group on to the 'device'. - """ - # First determine the number of blocks that will be touched by this - # swap. Then verify if there are available blocks in the device - # to perform the swap. - num_blocks_touched = 0 - blocks: List[Block] = [] - for seq in seq_group.get_seqs(status=status): - block_table = self.block_tables[seq.seq_id] - if block_table.blocks is not None: - # Compute the number blocks to touch for the tokens to be - # appended. This does NOT include the full blocks that need - # to be touched for the swap. - num_blocks_touched += \ - block_table.get_num_blocks_touched_by_append_slots( - block_table.get_unseen_token_ids(seq.get_token_ids()), - num_lookahead_slots=num_lookahead_slots) - blocks.extend(block_table.blocks) - # Compute the number of full blocks to touch and add it to the - # existing count of blocks to touch. - num_blocks_touched += self.block_allocator.get_num_full_blocks_touched( - blocks, device=device) - - watermark_blocks = 0 - if device == Device.GPU: - watermark_blocks = self.watermark_blocks - - if self.block_allocator.get_num_total_blocks( - device) < num_blocks_touched: - return AllocStatus.NEVER - elif self.block_allocator.get_num_free_blocks( - device) - num_blocks_touched >= watermark_blocks: - return AllocStatus.OK - else: - return AllocStatus.LATER - - def get_num_cached_tokens(self, seq: Sequence) -> int: - """Get the number of tokens in blocks that are already computed and - cached in the block manager for the sequence. 
- """ - return self._computed_blocks_tracker.get_num_cached_tokens(seq) diff --git a/vllm/core/evictor.py b/vllm/core/evictor.py deleted file mode 100644 index 85ff6bc9ca610..0000000000000 --- a/vllm/core/evictor.py +++ /dev/null @@ -1,157 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import enum -import heapq -from abc import ABC, abstractmethod -from typing import Dict, List, Tuple - - -class EvictionPolicy(enum.Enum): - """Enum for eviction policy used by make_evictor to instantiate the correct - Evictor subclass. - """ - LRU = enum.auto() - - -class Evictor(ABC): - """The Evictor subclasses should be used by the BlockAllocator class to - handle eviction of freed Blocks. - """ - - @abstractmethod - def __init__(self): - pass - - @abstractmethod - def __contains__(self, block_id: int) -> bool: - pass - - @abstractmethod - def evict(self) -> Tuple[int, int]: - """Runs the eviction algorithm and returns the evicted block's - content hash along with physical block id along with physical block id - """ - pass - - @abstractmethod - def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, - last_accessed: float): - """Adds block to the evictor, making it a candidate for eviction""" - pass - - @abstractmethod - def update(self, block_id: int, last_accessed: float): - """Update corresponding block's access time in metadata""" - pass - - @abstractmethod - def remove(self, block_id: int): - """Remove a given block id from the cache.""" - pass - - @property - @abstractmethod - def num_blocks(self) -> int: - pass - - -class BlockMetaData: - """Data structure for storing key data describe cached block, so that - evictor could use to make its decision which one to choose for eviction - - Here we use physical block id as the dict key, as there maybe several - blocks with the same content hash, but their physical id is unique. - """ - - def __init__(self, content_hash: int, num_hashed_tokens: int, - last_accessed: float): - self.content_hash = content_hash - self.num_hashed_tokens = num_hashed_tokens - self.last_accessed = last_accessed - - -class LRUEvictor(Evictor): - """Evicts in a least-recently-used order using the last_accessed timestamp - that's recorded in the Block. If there are multiple blocks with - the same last_accessed time, then the one with the largest num_hashed_tokens - will be evicted. If two blocks each have the lowest last_accessed time and - highest num_hashed_tokens value, then one will be chosen arbitrarily - """ - - # CLEANUP_THRESHOLD determines the maximum allowable size of the priority - # queue relative to the free table size. When this threshold is exceeded, - # a cleanup operation is triggered to reduce memory usage. - CLEANUP_THRESHOLD = 50 - - def __init__(self): - self.free_table: Dict[int, BlockMetaData] = {} - self.priority_queue = [] - - def __contains__(self, block_id: int) -> bool: - return block_id in self.free_table - - def evict(self) -> Tuple[int, int]: - if len(self.free_table) == 0: - raise ValueError("No usable cache memory left") - - while self.priority_queue: - # We do not remove outdated entries from the priority queue at the - # time of updating the last_accessed timestamp. Instead, outdated - # entries are filtered out here during eviction. Outdated entries - # would either not in the free table, or have older last accessed - # time. 
- last_accessed, _, block_id, content_hash = heapq.heappop( - self.priority_queue) - if (block_id in self.free_table and - self.free_table[block_id].last_accessed == last_accessed): - self.free_table.pop(block_id) - return block_id, content_hash - - raise ValueError("No usable cache memory left") - - def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, - last_accessed: float): - self.free_table[block_id] = BlockMetaData(content_hash, - num_hashed_tokens, - last_accessed) - heapq.heappush( - self.priority_queue, - (last_accessed, -num_hashed_tokens, block_id, content_hash)) - self._cleanup_if_necessary() - - def update(self, block_id: int, last_accessed: float): - self.free_table[block_id].last_accessed = last_accessed - - def _cleanup_if_necessary(self): - if len(self.priority_queue) > LRUEvictor.CLEANUP_THRESHOLD * len( - self.free_table): - self._cleanup() - - def _cleanup(self): - new_priority_queue: List[Tuple[float, int, int, int]] = [] - - for block_id, block in self.free_table.items(): - new_priority_queue.append( - (block.last_accessed, -block.num_hashed_tokens, block_id, - block.content_hash)) - heapq.heapify(new_priority_queue) - - self.priority_queue = new_priority_queue - - def remove(self, block_id: int): - if block_id not in self.free_table: - raise ValueError( - "Attempting to remove a block that's not in the evictor") - self.free_table.pop(block_id) - - @property - def num_blocks(self) -> int: - return len(self.free_table) - - -def make_evictor(eviction_policy: EvictionPolicy) -> Evictor: - if eviction_policy == EvictionPolicy.LRU: - return LRUEvictor() - else: - raise ValueError(f"Unknown cache eviction policy: {eviction_policy}") diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py deleted file mode 100644 index 69b9169ddd8a9..0000000000000 --- a/vllm/core/interfaces.py +++ /dev/null @@ -1,139 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import enum -from abc import ABC, abstractmethod -from typing import List, Optional -from typing import Sequence as GenericSequence -from typing import Tuple - -from vllm.sequence import Sequence, SequenceGroup -from vllm.utils import Device - - -class AllocStatus(enum.Enum): - """Result for BlockSpaceManager.can_allocate - - 1. Ok: seq_group can be allocated now. - 2. Later: seq_group cannot be allocated now, but the allocator has - enough total capacity for it, so it may be allocated later. - 3. Never: seq_group can never be allocated. - The seq_group is too large to be allocated in GPU memory.
- """ - OK = enum.auto() - LATER = enum.auto() - NEVER = enum.auto() - - -class BlockSpaceManager(ABC): - - @staticmethod - def get_block_space_manager_class(version: str): - version = version.lower() - - if version == "selfattn": - from vllm.core.block_manager import SelfAttnBlockSpaceManager - return SelfAttnBlockSpaceManager - - if version == "placeholder": - from vllm.core.placeholder_block_space_manager import ( - PlaceholderBlockSpaceManager) - return PlaceholderBlockSpaceManager - - raise ValueError(f"Unknown version {version=}") - - @abstractmethod - def can_allocate(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> AllocStatus: - pass - - @abstractmethod - def allocate(self, seq_group: SequenceGroup) -> None: - pass - - @abstractmethod - def can_append_slots(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: - pass - - @abstractmethod - def append_slots( - self, - seq: Sequence, - num_lookahead_slots: int, - ) -> List[Tuple[int, int]]: - pass - - @abstractmethod - def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: - pass - - @abstractmethod - def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> AllocStatus: - pass - - @abstractmethod - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - pass - - @abstractmethod - def can_swap_out(self, seq_group: SequenceGroup) -> bool: - pass - - @abstractmethod - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - pass - - @abstractmethod - def free(self, seq: Sequence) -> None: - pass - - @abstractmethod - def get_block_table(self, seq: Sequence) -> List[int]: - pass - - @abstractmethod - def get_num_free_gpu_blocks(self) -> int: - pass - - @abstractmethod - def get_num_free_cpu_blocks(self) -> int: - pass - - @abstractmethod - def access_all_blocks_in_seq( - self, - seq: Sequence, - access_time: float, - ) -> None: - pass - - @abstractmethod - def get_common_computed_block_ids( - self, seqs: List[Sequence]) -> GenericSequence[int]: - pass - - @abstractmethod - def mark_blocks_as_computed(self, seq_group: SequenceGroup, - token_chunk_size: int): - pass - - @abstractmethod - def get_prefix_cache_hit_rate(self, device: Device) -> float: - """Prefix cache hit rate. -1 means not supported or disabled.""" - pass - - @abstractmethod - def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: - """Reset prefix cache for specified or all devices.""" - pass - - @abstractmethod - def get_num_cached_tokens(self, seq: Sequence) -> int: - pass - - @abstractmethod - def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None: - pass \ No newline at end of file diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py deleted file mode 100644 index 679515924e85d..0000000000000 --- a/vllm/core/placeholder_block_space_manager.py +++ /dev/null @@ -1,103 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional, Tuple - -from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.sequence import Sequence, SequenceGroup -from vllm.utils import Device - - -class PlaceholderBlockSpaceManager(BlockSpaceManager): - """A version of BlockSpaceManager for use in environments - where block management is not required. - For example: pooling models or attention-free models like Mamba. 
- - This class provides the same interface as BlockSpaceManager, but its - methods perform no actions or return simple values like True in specific - actions. It's designed to be used in scenarios where the overhead of - block management is unnecessary, such as in an embedding environment. - """ - - def __init__( - self, - **kwargs, - ) -> None: - pass - - def can_allocate(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> AllocStatus: - # Always return OK for dummy purposes - return AllocStatus.OK - - def allocate(self, seq_group: SequenceGroup) -> None: - # No actual allocation logic needed - pass - - def can_append_slots(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: - return True - - def append_slots( - self, - seq: Sequence, - num_lookahead_slots: int, - ) -> List[Tuple[int, int]]: - return [] - - def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: - pass - - def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> AllocStatus: - return AllocStatus.OK - - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - return None # type: ignore - - def can_swap_out(self, seq_group: SequenceGroup) -> bool: - return True - - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - return None # type: ignore - - def free(self, seq: Sequence) -> None: - # No operation on free - return - - def get_block_table(self, seq: Sequence) -> List[int]: - return None # type: ignore - - def get_num_free_gpu_blocks(self) -> int: - return 1 - - def get_num_free_cpu_blocks(self) -> int: - return 1 - - def access_all_blocks_in_seq( - self, - seq: Sequence, - access_time: float, - ) -> None: - pass - - def get_common_computed_block_ids(self, - seq_group: List[Sequence]) -> List[int]: - return [] - - def mark_blocks_as_computed(self, seq_group: SequenceGroup, - token_chunk_size: int): - pass - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - return -1 - - def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: - return True - - def get_num_cached_tokens(self, seq: Sequence) -> int: - return 0 - - def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None: - return diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py deleted file mode 100644 index 92ebad778ea4b..0000000000000 --- a/vllm/core/scheduler.py +++ /dev/null @@ -1,2028 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import enum -import os -import random -import time -from collections import deque -from dataclasses import dataclass, field -from typing import Callable, Deque, Dict, Iterable, List, Optional -from typing import Sequence as GenericSequence -from typing import Set, Tuple, Union - -from vllm.config import CacheConfig, SchedulerConfig -from vllm.config.lora import LoRAConfig -from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupBase, SequenceGroupMetadata, - SequenceGroupMetadataDelta, SequenceStage, - SequenceStatus) -from vllm.utils import Device, PyObjectCache - -logger = init_logger(__name__) - -# Test-only. If configured, decode is preempted with -# ARTIFICIAL_PREEMPTION_PROB% probability. 
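One subtlety in the flag parsing that follows: `os.getenv` returns a string whenever the variable is set, and `bool()` of any non-empty string is `True`, so even `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=0` switches the feature on. A stricter, hypothetical parse for comparison:

```python
import os

def env_flag(name: str, default: bool = False) -> bool:
    # Compare against known truthy spellings instead of relying on bool(str).
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

os.environ["DEMO_FLAG"] = "0"
assert bool(os.getenv("DEMO_FLAG", False)) is True  # surprising but correct Python
assert env_flag("DEMO_FLAG") is False               # the intended reading
```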
-ENABLE_ARTIFICIAL_PREEMPT = bool( - os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa -ARTIFICIAL_PREEMPTION_PROB = 0.5 -ARTIFICIAL_PREEMPTION_MAX_CNT = 500 - - -class PreemptionMode(enum.Enum): - """Preemption modes. - - 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory - and swap them back in when the sequences are resumed. - 2. Recomputation: Discard the blocks of the preempted sequences and - recompute them when the sequences are resumed, treating the sequences as - new prompts. - """ - - SWAP = enum.auto() - RECOMPUTE = enum.auto() - - -@dataclass -class SchedulingBudget: - """The available slots for scheduling. - - TODO(sang): Right now, the budget is request_id-aware meaning it can ignore - budget update from the same request_id. It is because in normal scheduling - path, we update RUNNING num_seqs ahead of time, meaning it could be - updated more than once when scheduling RUNNING requests. Since this won't - happen if we only have chunked prefill scheduling, we can remove this - feature from the API when chunked prefill is enabled by default. - """ - - token_budget: int - max_num_seqs: int - _request_ids_num_batched_tokens: Set[str] = field(default_factory=set) - _request_ids_num_curr_seqs: Set[str] = field(default_factory=set) - # Number of cached tokens in the batch. - _num_cached_tokens: int = 0 - # Number of actual non-cached tokens in the batch. - _num_batched_tokens: int = 0 - _num_curr_seqs: int = 0 - - def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): - # We allow num_new_tokens to be 0 when the entire sequence has - # been cached. - assert num_new_tokens >= 0 - assert num_new_seqs != 0 - return (self.num_batched_tokens + num_new_tokens <= self.token_budget - and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) - - def remaining_token_budget(self): - return self.token_budget - self.num_batched_tokens - - def add_num_batched_tokens(self, - req_id: str, - num_batched_tokens: int, - num_cached_tokens: int = 0): - if req_id in self._request_ids_num_batched_tokens: - return - assert num_cached_tokens >= 0 - assert num_batched_tokens >= 0 - - self._request_ids_num_batched_tokens.add(req_id) - self._num_batched_tokens += num_batched_tokens - self._num_cached_tokens += num_cached_tokens - - def subtract_num_batched_tokens(self, req_id: str, - num_batched_tokens: int): - if req_id in self._request_ids_num_batched_tokens: - self._request_ids_num_batched_tokens.remove(req_id) - self._num_batched_tokens -= num_batched_tokens - - def add_num_seqs(self, req_id: str, num_curr_seqs: int): - if req_id in self._request_ids_num_curr_seqs: - return - - self._request_ids_num_curr_seqs.add(req_id) - self._num_curr_seqs += num_curr_seqs - - def subtract_num_seqs(self, req_id: str, num_curr_seqs: int): - if req_id in self._request_ids_num_curr_seqs: - self._request_ids_num_curr_seqs.remove(req_id) - self._num_curr_seqs -= num_curr_seqs - - @property - def num_batched_tokens(self): - return self._num_batched_tokens - - @property - def num_curr_seqs(self): - return self._num_curr_seqs - - @property - def num_cached_tokens(self): - return self._num_cached_tokens - - -@dataclass -class ScheduledSequenceGroup: - # A sequence group that's scheduled. - seq_group: SequenceGroup - # The total chunk size (number of tokens) to process for next iteration. - # 1 for decoding. Same as prompt tokens for prefill, but if prefill is - # chunked, it can be smaller than that. 
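A usage sketch for the `SchedulingBudget` dataclass defined above, with illustrative numbers: the per-request id sets make the accounting idempotent, so re-adding the same request is a no-op rather than a double-count.

```python
budget = SchedulingBudget(token_budget=2048, max_num_seqs=8)

if budget.can_schedule(num_new_tokens=512, num_new_seqs=1):
    budget.add_num_batched_tokens("req-0", 512)
    budget.add_num_seqs("req-0", 1)

budget.add_num_batched_tokens("req-0", 512)  # same request id: ignored
assert budget.num_batched_tokens == 512
assert budget.remaining_token_budget() == 1536
assert budget.num_curr_seqs == 1
```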
- token_chunk_size: int - - -@dataclass -class SchedulerOutputs: - """The scheduling decision made from a scheduler.""" - - # Scheduled sequence groups. - scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup] - # Number of prefill groups scheduled. - num_prefill_groups: int - # Total number of batched tokens. - num_batched_tokens: int - # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: List[Tuple[int, int]] - # Blocks to swap out. List of GPU -> CPU block number. - blocks_to_swap_out: List[Tuple[int, int]] - # Blocks to copy. Source to dest block. - blocks_to_copy: List[Tuple[int, int]] - # Sequence groups that are going to be ignored. - ignored_seq_groups: List[SequenceGroup] - # The number of slots for lookahead decoding. - num_lookahead_slots: int - # The number of requests in the running queue - running_queue_size: int - preempted: int - - def __post_init__(self): - # Swap in and swap out should never happen at the same time. - assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) - - self.num_loras: int = len(self.lora_requests) - if self.num_loras > 0: - self._sort_by_lora_ids() - - def is_empty(self) -> bool: - # NOTE: We do not consider the ignored sequence groups. - return (not self.scheduled_seq_groups and not self.blocks_to_swap_in - and not self.blocks_to_swap_out and not self.blocks_to_copy) - - def _sort_by_lora_ids(self): - assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups) - - def key_fn(group: ScheduledSequenceGroup): - key = (group.seq_group.lora_int_id, group.seq_group.request_id) - if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups): - # Sort sequence groups so that all prefills come before all - # decodes as required by chunked prefill. - return (not group.seq_group.is_prefill(), *key) - return key - - self.scheduled_seq_groups = sorted(self.scheduled_seq_groups, - key=key_fn) - - @property - def lora_requests(self) -> Set[LoRARequest]: - return { - g.seq_group.lora_request - for g in self.scheduled_seq_groups - if g.seq_group.lora_request is not None - } - - -@dataclass -class SchedulerRunningOutputs: - """The requests that are scheduled from a running queue. - - Could contain prefill (prefill that's chunked) or decodes. If there's not - enough memory, it can be preempted (for recompute) or swapped out. - """ - - # Selected sequences that are running and in a decoding phase. - decode_seq_groups: List[ScheduledSequenceGroup] - # Selected sequences that are running and in a prefill phase. - # I.e., it means the prefill has been chunked. - prefill_seq_groups: List[ScheduledSequenceGroup] - # The preempted sequences. - preempted: List[SequenceGroup] - # Sequences that are swapped out. - swapped_out: List[SequenceGroup] - # The blocks to swap out. - blocks_to_swap_out: List[Tuple[int, int]] - # The blocks to copy. - blocks_to_copy: List[Tuple[int, int]] - # The number of slots for lookahead decoding. - num_lookahead_slots: int - - # Optimization for fast-access to seq_group lists - decode_seq_groups_list: List[SequenceGroup] - prefill_seq_groups_list: List[SequenceGroup] - - @classmethod - def create_empty(cls) -> "SchedulerRunningOutputs": - return SchedulerRunningOutputs( - decode_seq_groups=[], - prefill_seq_groups=[], - preempted=[], - swapped_out=[], - blocks_to_swap_out=[], - blocks_to_copy=[], - num_lookahead_slots=0, - decode_seq_groups_list=[], - prefill_seq_groups_list=[], - ) - - -@dataclass -class SchedulerSwappedInOutputs: - """The requests that are scheduled from a swap queue. 
- - Could contain prefill (prefill that's chunked) or decodes. - """ - - # Selected sequences that are going to be swapped in and are in a - # decoding phase. - decode_seq_groups: List[ScheduledSequenceGroup] - # Selected sequences that are going to be swapped in and are in a prefill - # phase. I.e., it means the prefill has been chunked. - prefill_seq_groups: List[ScheduledSequenceGroup] - # The blocks to swap in. - blocks_to_swap_in: List[Tuple[int, int]] - # The blocks to copy. - blocks_to_copy: List[Tuple[int, int]] - # The number of slots for lookahead decoding. - num_lookahead_slots: int - # Infeasible sequence groups. - infeasible_seq_groups: List[SequenceGroup] - - @classmethod - def create_empty(cls) -> "SchedulerSwappedInOutputs": - return SchedulerSwappedInOutputs( - decode_seq_groups=[], - prefill_seq_groups=[], - blocks_to_swap_in=[], - blocks_to_copy=[], - num_lookahead_slots=0, - infeasible_seq_groups=[], - ) - - -@dataclass -class SchedulerPrefillOutputs: - """The requests that are scheduled from a waiting queue. - - Could contain fresh prefill requests or preempted requests that need - to be recomputed from scratch. - """ - - # Selected sequences for prefill. - seq_groups: List[ScheduledSequenceGroup] - # Ignored sequence groups. - ignored_seq_groups: List[SequenceGroup] - num_lookahead_slots: int - - @classmethod - def create_empty(cls) -> "SchedulerPrefillOutputs": - return SchedulerPrefillOutputs( - seq_groups=[], - ignored_seq_groups=[], - num_lookahead_slots=0, - ) - - -def seq_group_metadata_builder(): - return SequenceGroupMetadata(request_id="", - is_prompt=False, - seq_data={}, - sampling_params=None, - block_tables={}) - - -def scheduler_running_outputs_builder(): - return SchedulerRunningOutputs(decode_seq_groups=[], - prefill_seq_groups=[], - preempted=[], - swapped_out=[], - blocks_to_swap_out=[], - blocks_to_copy=[], - num_lookahead_slots=0, - prefill_seq_groups_list=[], - decode_seq_groups_list=[]) - - -def scheduled_seq_group_builder(): - return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup), - token_chunk_size=0) - # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) - - -@dataclass -class PartialPrefillMetadata: - """Holds information about the partial prefills that are currently running - during a single iteration of the Scheduler. - When chunked prefill is enabled, we allow a certain number of seqs to be - partially prefilled during each iteration. Having multiple partial prefills - in flight allows us to minimize TTFT and avoid decode starvation in cases - where a single sequence group with a very large prompt blocks the queue for - too many iterations. - The number of long prefill requests is limited so that smaller - requests may jump the queue in front of them and get to the decode - phase faster.
- """ - - # A minimum bound on the total number of prefills to be scheduled during - # this iteration - schedulable_prefills: int - - # The number of long prefill requests currently running - long_prefills: int - - scheduler_config: SchedulerConfig - - def can_schedule(self, seq_group: SequenceGroup) -> bool: - """When concurrent partial prefills are enabled, - we limit the number of long requests and only accept - shorter requests from the queue while running them - concurrently""" - return not (seq_group.first_seq.get_num_new_tokens() - > self.scheduler_config.long_prefill_token_threshold - and self.long_prefills - >= self.scheduler_config.max_long_partial_prefills - and self.scheduler_config.max_num_partial_prefills > 1) - - def maybe_increment_partial_prefills(self, - seq_group: SequenceGroup) -> None: - # When a new prefill is scheduled, we need to know if it is a - # long request - if (seq_group.first_seq.get_num_new_tokens() - > self.scheduler_config.long_prefill_token_threshold): - self.long_prefills += 1 - - @classmethod - def from_queues( - cls, - running: Deque[SequenceGroup], - waiting: Deque[SequenceGroup], - scheduler_config: SchedulerConfig, - ) -> "PartialPrefillMetadata": - """Create a PartialPrefillMetadata object from the current state of - the scheduler's queues. - This accounts for the currently running prefill requests, and peeks into - the waiting queue to see if there are more prefills to potentially be - scheduled during this iteration.""" - prefills = 0 - long_prefills = 0 - - waiting_long_prefills = 0 - - for sg in running: - if sg.first_seq.data.stage == SequenceStage.PREFILL: - prefills += 1 - if (sg.first_seq.get_num_new_tokens() - > scheduler_config.long_prefill_token_threshold): - long_prefills += 1 - - for sg in waiting: - # Don't bother looping through the rest of the queue if we know - # there are already at - # least max_partial_prefills requests to fill - if prefills >= scheduler_config.max_num_partial_prefills: - break - - # Don't count long requests from the waiting queue if we aren't - # going to schedule them anyway - if (sg.first_seq.get_num_new_tokens() - > scheduler_config.long_prefill_token_threshold): - if (long_prefills + waiting_long_prefills - >= scheduler_config.max_long_partial_prefills): - continue - waiting_long_prefills += 1 - prefills += 1 - - # NB: long_prefills and waiting_long_prefills are tracked separately. - # We don't account for the waiting requests here because we need to use - # this metadata to track how many have actually been scheduled. - return PartialPrefillMetadata( - schedulable_prefills=min( - prefills, scheduler_config.max_num_partial_prefills), - long_prefills=long_prefills, - scheduler_config=scheduler_config, - ) - - -class Scheduler: - - def __init__( - self, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - lora_config: Optional[LoRAConfig], - pipeline_parallel_size: int = 1, - output_proc_callback: Optional[Callable] = None, - ) -> None: - self.scheduler_config = scheduler_config - self.cache_config = cache_config - # Note for LoRA scheduling: the current policy is extremely - # simple and NOT fair. It can lead to starvation of some - # LoRAs. This should be improved in the future. 
- self.lora_config = lora_config - - version = "selfattn" - if (self.scheduler_config.runner_type == "pooling" - or self.cache_config.is_attention_free): - version = "placeholder" - - BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( - version) - - num_gpu_blocks = cache_config.num_gpu_blocks - if num_gpu_blocks: - num_gpu_blocks //= pipeline_parallel_size - - num_cpu_blocks = cache_config.num_cpu_blocks - if num_cpu_blocks: - num_cpu_blocks //= pipeline_parallel_size - - # Create the block space manager. - self.block_manager = BlockSpaceManagerImpl( - block_size=self.cache_config.block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - sliding_window=self.cache_config.sliding_window, - enable_caching=self.cache_config.enable_prefix_caching, - ) - - # Sequence groups in the WAITING state. - # Contain new prefill or preempted requests. - self.waiting: Deque[SequenceGroup] = deque() - # Sequence groups in the RUNNING state. - # Contain decode requests. - self.running: Deque[SequenceGroup] = deque() - # Sequence groups in the SWAPPED state. - # Contain decode requests that are swapped out. - self.swapped: Deque[SequenceGroup] = deque() - # Sequence groups finished requests ids since last step iteration. - # It lets the model know that any state associated with these requests - # can and must be released after the current step. - # This is used to evict the finished requests from the Mamba cache. - self._finished_requests_ids: List[str] = list() - # Time at previous scheduling step - self.prev_time = 0.0 - # Did we schedule a prompt at previous step? - self.prev_prompt = False - # Latency of the last prompt step - self.last_prompt_latency = 0.0 - # preemption mode, RECOMPUTE or SWAP - self.user_specified_preemption_mode = scheduler_config.preemption_mode - - # The following field is test-only. It is used to inject artificial - # preemption. - self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT - self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT - if self.enable_artificial_preemption - else 0) - self.num_cumulative_preemption: int = 0 - - # Used to cache python objects - self._seq_group_metadata_cache: List[PyObjectCache] = [] - self._scheduler_running_outputs_cache: List[PyObjectCache] = [] - self._scheduled_seq_group_cache: List[PyObjectCache] = [] - - # For async output processing, we need to swap cache buffers between - # iterations. I.e. since the output processing is lagged one step, - # we cannot reuse the cached objects immediately when the schedule() - # is called again, but only when schedule() is called the second time. - self.output_proc_callback = output_proc_callback - self.use_async_output_proc = self.output_proc_callback is not None - self.num_cache_iters = 2 if self.use_async_output_proc else 1 - - self.cache_id = 0 - for i in range(self.num_cache_iters): - self._seq_group_metadata_cache.append( - PyObjectCache(seq_group_metadata_builder)) - self._scheduler_running_outputs_cache.append( - PyObjectCache(scheduler_running_outputs_builder)) - self._scheduled_seq_group_cache.append( - PyObjectCache(scheduled_seq_group_builder)) - - # For async postprocessor, the extra decode run cannot be done - # when the request reaches max_model_len. 
In this case, the request - # will be stopped during schedule() call and added to this stop list - # for processing and deallocation by the free_finished_seq_groups() - self._async_stopped: List[SequenceGroup] = [] - - # List with the chunk sizes to hand out to each sequence depending - # on how many partial prefills are running. This is slightly faster than - # running an integer division every time a prefill is scheduled. - # This splits the budget evenly among all prefills. - self.partial_prefill_budget_lookup_list = [0] * ( - self.scheduler_config.max_num_partial_prefills + 1) - self.partial_prefill_budget_lookup_list[0] = ( - scheduler_config.max_num_batched_tokens) - for i in range(1, self.scheduler_config.max_num_partial_prefills + 1): - self.partial_prefill_budget_lookup_list[i] = ( - scheduler_config.max_num_batched_tokens // i) - - @property - def next_cache_id(self): - return (self.cache_id + 1) % self.num_cache_iters - - @property - def lora_enabled(self) -> bool: - return bool(self.lora_config) - - @property - def num_decoding_tokens_per_seq(self) -> int: - """The number of new tokens.""" - return 1 - - def add_seq_group(self, seq_group: SequenceGroup) -> None: - # Add sequence groups to the waiting queue. - self.waiting.append(seq_group) - - def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None: - # Add sequence groups to the running queue. - # Only for testing purposes. - self.running.append(seq_group) - - def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None: - # Add sequence groups to the swapped queue. - # Only for testing purposes. - self.swapped.append(seq_group) - - def abort_seq_group( - self, - request_id: Union[str, Iterable[str]], - seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None, - ) -> None: - """Aborts a sequence group with the given ID. - - Check if the sequence group with the given ID - is present in any of the state queue. - If present, remove the sequence group from the state queue. - Also, if any of the sequences in the sequence group is not finished, - free the sequence with status `FINISHED_ABORTED`. - Otherwise, do nothing. - - Args: - request_id: The ID(s) of the sequence group to abort. - seq_id_to_seq_group: helper for groups with n>1 - """ - if isinstance(request_id, str): - request_id = (request_id, ) - request_ids = set(request_id) - seq_id_to_seq_group = seq_id_to_seq_group or {} - for state_queue in [self.waiting, self.running, self.swapped]: - aborted_groups: List[SequenceGroup] = [] - for seq_group in state_queue: - # When n>1, seq_group.request_id looks like - # foo_parallel_sample_0, while request_ids is just foo, and we - # should resolve it as real_request_id to match. - if seq_group.request_id in seq_id_to_seq_group: - real_request_id = seq_id_to_seq_group[ - seq_group.request_id].group_id - else: - real_request_id = seq_group.request_id - if real_request_id in request_ids: - # Appending aborted group into pending list. - aborted_groups.append(seq_group) - # We can't remove real_request_id in request_ids here, - # because there may be other seq groups sharing the same - # real_request_id - for aborted_group in aborted_groups: - # Remove the sequence group from the state queue. - state_queue.remove(aborted_group) - # Remove the aborted request from the Mamba cache. 
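Worked numbers for the partial-prefill budget table built in `__init__` above: entry `i` holds the chunk size each prefill receives when `i` of them run concurrently, so an integer division per scheduled prefill is replaced by a list index. Assuming `max_num_batched_tokens=8192` and `max_num_partial_prefills=4`:

```python
max_num_batched_tokens = 8192
max_num_partial_prefills = 4

table = [0] * (max_num_partial_prefills + 1)
table[0] = max_num_batched_tokens          # degenerate case: no prefills running
for i in range(1, max_num_partial_prefills + 1):
    table[i] = max_num_batched_tokens // i  # budget split evenly across i prefills

assert table == [8192, 8192, 4096, 2730, 2048]
```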
-                self._finished_requests_ids.append(aborted_group.request_id)
-                for seq in aborted_group.get_seqs():
-                    if seq.is_finished():
-                        continue
-                    seq.status = SequenceStatus.FINISHED_ABORTED
-                    self.free_seq(seq)
-                if aborted_group.request_id in seq_id_to_seq_group:
-                    del seq_id_to_seq_group[aborted_group.request_id]
-
-                self._free_seq_group_cross_attn_blocks(aborted_group)
-
-    def _free_seq_group_cross_attn_blocks(
-        self,
-        seq_group: SequenceGroup,
-    ) -> None:
-        """
-        Free a sequence group from a cross-attention block table.
-        Has no effect on decoder-only models.
-        """
-        if seq_group.is_encoder_decoder():
-            self.block_manager.free_cross(seq_group)
-
-    def has_unfinished_seqs(self) -> bool:
-        return (len(self.waiting) != 0 or len(self.running) != 0
-                or len(self.swapped) != 0)
-
-    def get_prefix_cache_hit_rate(self, device: Device) -> float:
-        return self.block_manager.get_prefix_cache_hit_rate(device)
-
-    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
-        return self.block_manager.reset_prefix_cache(device)
-
-    def get_num_unfinished_seq_groups(self) -> int:
-        return len(self.waiting) + len(self.running) + len(self.swapped)
-
-    def get_and_reset_finished_requests_ids(self) -> List[str]:
-        """Flushes the list of request ids of previously finished seq_groups."""
-        finished_requests_ids = self._finished_requests_ids
-        self._finished_requests_ids = list()
-        return finished_requests_ids
-
-    def _schedule_running(
-        self,
-        budget: SchedulingBudget,
-        curr_loras: Optional[Set[int]],
-        enable_chunking: bool = False,
-        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
-    ) -> SchedulerRunningOutputs:
-        """Schedule sequence groups that are running.
-
-        The running queue should include decode and chunked prefill requests.
-
-        Args:
-            budget: The scheduling budget. The argument is in-place updated
-                when any decodes are preempted.
-            curr_loras: Currently batched lora request ids. The argument is
-                in-place updated when any decodes are preempted.
-            enable_chunking: If True, the seq group can be chunked and only a
-                chunked number of tokens are scheduled if
-                `budget.num_batched_tokens` does not have enough capacity to
-                schedule all tokens.
-            partial_prefill_metadata: information about the partial prefills
-                that are currently running
-
-        Returns:
-            SchedulerRunningOutputs.
-        """
-        ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[
-            self.cache_id].get_object()
-        ret.blocks_to_swap_out.clear()
-        ret.blocks_to_copy.clear()
-        ret.decode_seq_groups.clear()
-        ret.prefill_seq_groups.clear()
-        ret.preempted.clear()
-        ret.swapped_out.clear()
-
-        ret.num_lookahead_slots = self._get_num_lookahead_slots(
-            is_prefill=False, enable_chunking=enable_chunking)
-
-        ret.decode_seq_groups_list.clear()
-        ret.prefill_seq_groups_list.clear()
-
-        # Blocks that need to be swapped or copied before model execution.
-        blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out
-        blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy
-
-        decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups
-        prefill_seq_groups: List[
-            ScheduledSequenceGroup] = ret.prefill_seq_groups
-        preempted: List[SequenceGroup] = ret.preempted
-        swapped_out: List[SequenceGroup] = ret.swapped_out
-
-        running_queue = self.running
-        assert len(self._async_stopped) == 0
-        while running_queue:
-            seq_group = running_queue[0]
-            # We discard the cached tokens info here because we don't need it
-            # for running sequences:
-            # 1. If a sequence is running with chunked prefill, the cached
-            #    tokens info was already used for the first prefill.
-            # 2. If a sequence is running with non-chunked prefill, then
-            #    it's a decoding sequence, and the cached tokens info is
-            #    irrelevant.
-            num_uncached_new_tokens, _ = \
-                self._get_num_new_uncached_and_cached_tokens(
-                    seq_group,
-                    SequenceStatus.RUNNING,
-                    enable_chunking,
-                    budget,
-                    partial_prefill_metadata,
-                )
-
-            num_running_tokens = num_uncached_new_tokens
-            if num_running_tokens == 0:
-                # No budget => Stop
-                break
-
-            running_queue.popleft()
-
-            # With async postprocessor, an extra decode run is done
-            # to process the final tokens. The check below avoids this extra
-            # decode run when the model max len is reached, in order to avoid
-            # a memory overflow.
-            if (self.use_async_output_proc and seq_group.seqs[0].get_len()
-                    > self.scheduler_config.max_model_len):
-                self._async_stopped.append(seq_group)
-                continue
-
-            # NOTE(woosuk): Preemption happens only when there is no available
-            # slot to keep all the sequence groups in the RUNNING state.
-            while not self._can_append_slots(seq_group, enable_chunking):
-                budget.subtract_num_batched_tokens(seq_group.request_id,
-                                                   num_running_tokens)
-                num_running_seqs = seq_group.get_max_num_running_seqs()
-                budget.subtract_num_seqs(seq_group.request_id,
-                                         num_running_seqs)
-
-                if (curr_loras is not None and seq_group.lora_int_id > 0
-                        and seq_group.lora_int_id in curr_loras):
-                    curr_loras.remove(seq_group.lora_int_id)
-
-                # Determine victim sequence
-                cont_loop = True
-                if running_queue:
-                    # Preempt the lowest-priority sequence group.
-                    victim_seq_group = running_queue.pop()
-                else:
-                    # No other sequence group can be preempted.
-                    # Preempt the current sequence group.
-                    # Note: This is also where we stop this loop
-                    # (since there is nothing else to preempt)
-                    victim_seq_group = seq_group
-                    cont_loop = False
-
-                # With async postprocessor, before preempting a sequence
-                # we need to ensure it has no pending async postprocessor
-                do_preempt = True
-                if self.use_async_output_proc:
-                    assert self.output_proc_callback is not None
-                    self.output_proc_callback(
-                        request_id=victim_seq_group.request_id)
-
-                    # It may be that the async pending "victim_seq_group"
-                    # becomes finished, in which case we simply free it.
-                    if victim_seq_group.is_finished():
-                        self._free_finished_seq_group(victim_seq_group)
-                        do_preempt = False
-
-                # Do preemption
-                if do_preempt:
-                    preempted_mode = self._preempt(victim_seq_group,
-                                                   blocks_to_swap_out)
-                    if preempted_mode == PreemptionMode.RECOMPUTE:
-                        preempted.append(victim_seq_group)
-                    else:
-                        swapped_out.append(victim_seq_group)
-
-                if not cont_loop:
-                    break
-            else:
-                self._append_slots(seq_group, blocks_to_copy, enable_chunking)
-                is_prefill = seq_group.is_prefill()
-
-                scheduled_seq_group: ScheduledSequenceGroup = (
-                    self._scheduled_seq_group_cache[
-                        self.cache_id].get_object())
-                scheduled_seq_group.seq_group = seq_group
-                if is_prefill:
-                    scheduled_seq_group.token_chunk_size = num_running_tokens
-                    prefill_seq_groups.append(scheduled_seq_group)
-                    ret.prefill_seq_groups_list.append(seq_group)
-                else:
-                    scheduled_seq_group.token_chunk_size = 1
-                    decode_seq_groups.append(scheduled_seq_group)
-                    ret.decode_seq_groups_list.append(seq_group)
-
-                budget.add_num_batched_tokens(seq_group.request_id,
-                                              num_running_tokens)
-                # OPTIMIZATION: Note that get_max_num_running_seqs is
-                # expensive. For the default scheduling case where
-                # enable_chunking is False, num_seqs are updated before running
-                # this method, so we don't have to update it again here.
-                if enable_chunking:
-                    num_running_seqs = seq_group.get_max_num_running_seqs()
-                    budget.add_num_seqs(seq_group.request_id, num_running_seqs)
-                if curr_loras is not None and seq_group.lora_int_id > 0:
-                    curr_loras.add(seq_group.lora_int_id)
-
-        self._scheduler_running_outputs_cache[self.next_cache_id].reset()
-        self._scheduled_seq_group_cache[self.next_cache_id].reset()
-
-        return ret
-
-    def _schedule_swapped(
-        self,
-        budget: SchedulingBudget,
-        curr_loras: Optional[Set[int]],
-        enable_chunking: bool = False,
-    ) -> SchedulerSwappedInOutputs:
-        """Schedule sequence groups that are swapped out.
-
-        It schedules swapped requests as long as they fit within `budget` and
-        curr_loras <= max_lora from the scheduling config. The input arguments
-        `budget` and `curr_loras` are updated based on scheduled seq_groups.
-
-        Args:
-            budget: The scheduling budget. The argument is in-place updated
-                when any requests are swapped in.
-            curr_loras: Currently batched lora request ids. The argument is
-                in-place updated when any requests are swapped in.
-            enable_chunking: If True, the seq group can be chunked and only a
-                chunked number of tokens are scheduled if
-                `budget.num_batched_tokens` does not have enough capacity to
-                schedule all tokens.
-
-        Returns:
-            SchedulerSwappedInOutputs.
-        """
-        # Blocks that need to be swapped or copied before model execution.
-        blocks_to_swap_in: List[Tuple[int, int]] = []
-        blocks_to_copy: List[Tuple[int, int]] = []
-        decode_seq_groups: List[ScheduledSequenceGroup] = []
-        prefill_seq_groups: List[ScheduledSequenceGroup] = []
-        infeasible_seq_groups: List[SequenceGroup] = []
-
-        swapped_queue = self.swapped
-
-        leftover_swapped: Deque[SequenceGroup] = deque()
-        while swapped_queue:
-            seq_group = swapped_queue[0]
-
-            # If the sequence group cannot be swapped in, stop.
-            is_prefill = seq_group.is_prefill()
-            alloc_status = self.block_manager.can_swap_in(
-                seq_group,
-                self._get_num_lookahead_slots(is_prefill, enable_chunking))
-            if alloc_status == AllocStatus.LATER:
-                break
-            elif alloc_status == AllocStatus.NEVER:
-                logger.warning(
-                    "Failing the request %s because there's not enough kv "
-                    "cache blocks to run the entire sequence.",
-                    seq_group.request_id,
-                )
-                for seq in seq_group.get_seqs():
-                    seq.status = SequenceStatus.FINISHED_IGNORED
-                infeasible_seq_groups.append(seq_group)
-                swapped_queue.popleft()
-                continue
-
-            lora_int_id = 0
-            if self.lora_enabled:
-                lora_int_id = seq_group.lora_int_id
-                assert curr_loras is not None
-                assert self.lora_config is not None
-                if (lora_int_id > 0 and (lora_int_id not in curr_loras)
-                        and len(curr_loras) >= self.lora_config.max_loras):
-                    # We don't have space for another LoRA, so
-                    # we ignore this request for now.
-                    leftover_swapped.appendleft(seq_group)
-                    swapped_queue.popleft()
-                    continue
-
-            # The total number of sequences in the RUNNING state should not
-            # exceed the maximum number of sequences.
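Before the loop continues below: the `budget.can_schedule(...)` / `add_*` / `subtract_*` calls used throughout these methods follow a simple reserve-and-roll-back pattern. A toy sketch of that pattern (the real `SchedulingBudget` also tracks request ids and cached tokens; this simplified class is an assumption for illustration only):

    class ToyBudget:
        def __init__(self, token_budget: int, max_num_seqs: int):
            self.token_budget = token_budget
            self.max_num_seqs = max_num_seqs
            self.num_batched_tokens = 0
            self.num_curr_seqs = 0

        def can_schedule(self, num_new_tokens: int, num_new_seqs: int) -> bool:
            # Both the token budget and the sequence budget must hold.
            return (self.num_batched_tokens + num_new_tokens <= self.token_budget
                    and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)

        def add(self, num_new_tokens: int, num_new_seqs: int) -> None:
            self.num_batched_tokens += num_new_tokens
            self.num_curr_seqs += num_new_seqs

    budget = ToyBudget(token_budget=2048, max_num_seqs=256)
    if budget.can_schedule(num_new_tokens=512, num_new_seqs=1):
        budget.add(512, 1)  # reserve; a later preemption would subtract it back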
-            num_new_seqs = seq_group.get_max_num_running_seqs()
-            num_new_tokens_uncached, num_new_tokens_cached = (
-                self._get_num_new_uncached_and_cached_tokens(
-                    seq_group, SequenceStatus.SWAPPED, enable_chunking,
-                    budget))
-
-            if num_new_tokens_uncached == 0 or not budget.can_schedule(
-                    num_new_tokens=num_new_tokens_uncached,
-                    num_new_seqs=num_new_seqs,
-            ):
-                self.remove_seq_from_computed_blocks_tracker(
-                    seq_group, SequenceStatus.SWAPPED)
-                break
-
-            if lora_int_id > 0 and curr_loras is not None:
-                curr_loras.add(lora_int_id)
-            swapped_queue.popleft()
-            self._swap_in(seq_group, blocks_to_swap_in)
-            self._append_slots(seq_group, blocks_to_copy, enable_chunking)
-            if is_prefill:
-                prefill_seq_groups.append(
-                    ScheduledSequenceGroup(
-                        seq_group,
-                        token_chunk_size=num_new_tokens_uncached +
-                        num_new_tokens_cached,
-                    ))
-            else:
-                decode_seq_groups.append(
-                    ScheduledSequenceGroup(seq_group, token_chunk_size=1))
-            budget.add_num_batched_tokens(
-                seq_group.request_id,
-                num_batched_tokens=num_new_tokens_uncached,
-                num_cached_tokens=num_new_tokens_cached,
-            )
-            budget.add_num_seqs(seq_group.request_id, num_new_seqs)
-
-        swapped_queue.extendleft(leftover_swapped)
-
-        return SchedulerSwappedInOutputs(
-            decode_seq_groups=decode_seq_groups,
-            prefill_seq_groups=prefill_seq_groups,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_copy=blocks_to_copy,
-            num_lookahead_slots=self._get_num_lookahead_slots(
-                is_prefill=False, enable_chunking=enable_chunking),
-            infeasible_seq_groups=infeasible_seq_groups,
-        )
-
-    def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
-        if self.scheduler_config.chunked_prefill_enabled:
-            prompt_limit = self.scheduler_config.max_model_len
-        else:
-            prompt_limit = min(
-                self.scheduler_config.max_model_len,
-                self.scheduler_config.max_num_batched_tokens,
-            )
-
-        # Model is fine-tuned with long context. Return the fine-tuned max_len.
-        if seq_group.lora_request and seq_group.lora_request.long_lora_max_len:
-            assert prompt_limit <= seq_group.lora_request.long_lora_max_len
-            return seq_group.lora_request.long_lora_max_len
-        else:
-            return prompt_limit
-
-    def _get_priority(self,
-                      seq_group: SequenceGroup) -> Tuple[Optional[int], float]:
-        """Get the priority of the sequence group.
-        Highest preference to user-defined priority, followed by arrival time.
-        Args:
-            seq_group: The sequence group input.
-        Returns:
-            The priority of the sequence group.
-        """
-        return seq_group.priority, seq_group.arrival_time
-
-    def _schedule_priority_preemption(
-        self,
-        budget: SchedulingBudget,
-    ) -> int:
-        """Sorts the waiting and running queues. Also force-preempts requests
-        from the running queue if their priority is lower.
-        Priority-based preemption is used with the priority policy.
-        Args:
-            budget: The scheduling budget. The argument is in-place updated
-                when any requests are scheduled.
-        Returns:
-            A count of priority-based preemptions.
-        """
-
-        waiting_queue = self.waiting
-
-        running_queue = deque(sorted(self.running, key=self._get_priority))
-
-        blocks_to_swap_out: List[Tuple[int, int]] = []
-        force_preemption_count = 0
-
-        if waiting_queue:
-            seq_group = waiting_queue.popleft()
-            num_new_seqs = seq_group.get_max_num_running_seqs()
-            num_new_tokens_uncached, _ = \
-                self._get_num_new_uncached_and_cached_tokens(
-                    seq_group, SequenceStatus.WAITING, False, budget)
-
-            # Only preempt if a priority inversion exists
-            while running_queue and self._get_priority(
-                    running_queue[-1]) > self._get_priority(seq_group):
-                # Only preempt if the waiting sequence cannot be allocated
-                can_allocate = self.block_manager.can_allocate(seq_group)
-                if (num_new_tokens_uncached > 0
-                        and can_allocate == AllocStatus.OK
-                        and budget.can_schedule(
-                            num_new_tokens=num_new_tokens_uncached,
-                            num_new_seqs=num_new_seqs,
-                        )):
-                    break
-
-                # Adjust the budget to remove the victim sequence group
-                vseq_group = running_queue.pop()
-                num_running_tokens_uncached, _ = (
-                    self._get_num_new_uncached_and_cached_tokens(
-                        vseq_group, SequenceStatus.RUNNING, False, budget))
-                budget.subtract_num_batched_tokens(
-                    vseq_group.request_id, num_running_tokens_uncached)
-                num_running_seqs = vseq_group.get_max_num_running_seqs()
-                budget.subtract_num_seqs(vseq_group.request_id,
-                                         num_running_seqs)
-
-                # Preempt out the victim sequence group
-                self._preempt(vseq_group, blocks_to_swap_out)
-                waiting_queue.appendleft(vseq_group)
-                force_preemption_count += 1
-            # Put the sequence back into the waiting queue
-            waiting_queue.appendleft(seq_group)
-
-            self.remove_seq_from_computed_blocks_tracker(
-                seq_group, SequenceStatus.WAITING)
-
-        waiting_queue = deque(sorted(waiting_queue, key=self._get_priority))
-
-        self.waiting = waiting_queue
-        self.running = running_queue
-        return force_preemption_count
-
-    def _schedule_prefills(
-        self,
-        budget: SchedulingBudget,
-        curr_loras: Optional[Set[int]],
-        enable_chunking: bool = False,
-        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
-    ) -> SchedulerPrefillOutputs:
-        """Schedule sequence groups that are in prefill stage.
-
-        Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
-        as a new prefill (that starts from the beginning -> most recently
-        generated tokens).
-
-        It schedules waiting requests as long as they fit within `budget` and
-        curr_loras <= max_lora from the scheduling config. The input arguments
-        `budget` and `curr_loras` are updated based on scheduled seq_groups.
-
-        Args:
-            budget: The scheduling budget. The argument is in-place updated
-                when any requests are scheduled.
-            curr_loras: Currently batched lora request ids. The argument is
-                in-place updated when any requests are scheduled.
-            enable_chunking: If True, the seq group can be chunked and only a
-                chunked number of tokens are scheduled if
-                `budget.num_batched_tokens` does not have enough capacity to
-                schedule all tokens.
-            partial_prefill_metadata: information about the partial prefills
-                that are currently running
-
-        Returns:
-            SchedulerPrefillOutputs.
-        """
-        if budget.remaining_token_budget() == 0:
-            # Do nothing: Can't add any more prefill anyway
-            return SchedulerPrefillOutputs(
-                seq_groups=[],
-                ignored_seq_groups=[],
-                num_lookahead_slots=self._get_num_lookahead_slots(
-                    is_prefill=True, enable_chunking=enable_chunking),
-            )
-        ignored_seq_groups: List[SequenceGroup] = []
-        seq_groups: List[ScheduledSequenceGroup] = []
-        using_prompt_embeds: bool = False
-
-        waiting_queue = self.waiting
-
-        leftover_waiting_sequences: Deque[SequenceGroup] = deque()
-        while self._passed_delay(time.time()) and waiting_queue:
-            seq_group = waiting_queue[0]
-
-            waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
-            assert len(waiting_seqs) == 1, (
-                "Waiting sequence group should have only one prompt "
-                "sequence.")
-            if (partial_prefill_metadata is not None
-                    and not partial_prefill_metadata.can_schedule(seq_group)):
-                leftover_waiting_sequences.appendleft(seq_group)
-                waiting_queue.popleft()
-                continue
-            num_new_tokens_uncached, num_new_tokens_cached = (
-                self._get_num_new_uncached_and_cached_tokens(
-                    seq_group,
-                    SequenceStatus.WAITING,
-                    enable_chunking,
-                    budget,
-                    partial_prefill_metadata=partial_prefill_metadata,
-                ))
-            num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached
-
-            if not enable_chunking:
-                num_prompt_tokens = waiting_seqs[0].get_len()
-                assert num_new_tokens == num_prompt_tokens
-
-            prompt_limit = self._get_prompt_limit(seq_group)
-            if num_new_tokens > prompt_limit:
-                logger.warning(
-                    "Input prompt (%d tokens) is too long"
-                    " and exceeds limit of %d",
-                    num_new_tokens,
-                    prompt_limit,
-                )
-                for seq in waiting_seqs:
-                    seq.status = SequenceStatus.FINISHED_IGNORED
-                self.remove_seq_from_computed_blocks_tracker(
-                    seq_group, SequenceStatus.FINISHED_IGNORED)
-                ignored_seq_groups.append(seq_group)
-                waiting_queue.popleft()
-                continue
-
-            num_lookahead_slots: int = 0
-
-            # If the sequence group cannot be allocated, stop.
-            can_allocate = self.block_manager.can_allocate(
-                seq_group, num_lookahead_slots=num_lookahead_slots)
-            if can_allocate == AllocStatus.LATER:
-                self.remove_seq_from_computed_blocks_tracker(
-                    seq_group, SequenceStatus.WAITING)
-                break
-            elif can_allocate == AllocStatus.NEVER:
-                logger.warning(
-                    "Input prompt (%d tokens) + lookahead slots (%d) is "
-                    "too long and exceeds the capacity of block_manager",
-                    num_new_tokens,
-                    num_lookahead_slots,
-                )
-                for seq in waiting_seqs:
-                    seq.status = SequenceStatus.FINISHED_IGNORED
-                self.remove_seq_from_computed_blocks_tracker(
-                    seq_group, SequenceStatus.FINISHED_IGNORED)
-                ignored_seq_groups.append(seq_group)
-                waiting_queue.popleft()
-                continue
-
-            # We cannot mix sequence groups that use prompt embeds and
-            # those that do not.
-            if len(seq_groups) == 0:
-                using_prompt_embeds = seq_group.uses_prompt_embeds()
-            if using_prompt_embeds != seq_group.uses_prompt_embeds():
-                self.remove_seq_from_computed_blocks_tracker(
-                    seq_group, SequenceStatus.WAITING)
-                leftover_waiting_sequences.appendleft(seq_group)
-                waiting_queue.popleft()
-                continue
-
-            lora_int_id = 0
-            if self.lora_enabled:
-                lora_int_id = seq_group.lora_int_id
-                assert curr_loras is not None
-                assert self.lora_config is not None
-            if (self.lora_enabled and lora_int_id > 0
-                    and lora_int_id not in curr_loras
-                    and len(curr_loras) >= self.lora_config.max_loras):
-                # We don't have space for another LoRA, so
-                # we ignore this request for now.
-                self.remove_seq_from_computed_blocks_tracker(
-                    seq_group, SequenceStatus.WAITING)
-                leftover_waiting_sequences.appendleft(seq_group)
-                waiting_queue.popleft()
-                continue
-
-            if (budget.num_batched_tokens
-                    >= self.scheduler_config.max_num_batched_tokens):
-                # We've reached the budget limit - since there might be
-                # continuous prefills in the running queue, we should break
-                # to avoid scheduling any new prefills.
-                self.remove_seq_from_computed_blocks_tracker(
-                    seq_group, SequenceStatus.WAITING)
-                break
-
-            num_new_seqs = seq_group.get_max_num_running_seqs()
-            if num_new_tokens_uncached == 0 or not budget.can_schedule(
-                    num_new_tokens=num_new_tokens_uncached,
-                    num_new_seqs=num_new_seqs,
-            ):
-                self.remove_seq_from_computed_blocks_tracker(
-                    seq_group, SequenceStatus.WAITING)
-                break
-
-            # Can schedule this request.
-            if curr_loras is not None and lora_int_id > 0:
-                curr_loras.add(lora_int_id)
-            waiting_queue.popleft()
-            self._allocate_and_set_running(seq_group)
-
-            if partial_prefill_metadata is not None:
-                partial_prefill_metadata.maybe_increment_partial_prefills(
-                    seq_group)
-
-            seq_groups.append(
-                ScheduledSequenceGroup(seq_group=seq_group,
-                                       token_chunk_size=num_new_tokens))
-            budget.add_num_batched_tokens(
-                seq_group.request_id,
-                num_batched_tokens=num_new_tokens_uncached,
-                num_cached_tokens=num_new_tokens_cached,
-            )
-            budget.add_num_seqs(seq_group.request_id, num_new_seqs)
-
-        # Queue requests that couldn't be scheduled.
-        waiting_queue.extendleft(leftover_waiting_sequences)
-        if len(seq_groups) > 0:
-            self.prev_prompt = True
-
-        return SchedulerPrefillOutputs(
-            seq_groups=seq_groups,
-            ignored_seq_groups=ignored_seq_groups,
-            num_lookahead_slots=self._get_num_lookahead_slots(
-                is_prefill=True, enable_chunking=enable_chunking),
-        )
-
-    def _schedule_default(self) -> SchedulerOutputs:
-        """Schedule queued requests.
-
-        The current policy is designed to optimize throughput. First,
-        it batches as many prefill requests as possible. Then it schedules
-        decodes. If there is pressure on GPU memory, decode requests can
-        be swapped or preempted.
-        """
-        # Include running requests in the budget.
-        budget = SchedulingBudget(
-            token_budget=self.scheduler_config.max_num_batched_tokens,
-            max_num_seqs=self.scheduler_config.max_num_seqs,
-        )
-        # Make sure we include num running seqs before scheduling prefill,
-        # so that we don't schedule beyond max_num_seqs for prefill.
-        for seq_group in self.running:
-            budget.add_num_seqs(seq_group.request_id,
-                                seq_group.get_max_num_running_seqs())
-        curr_loras = (set(
-            seq_group.lora_int_id for seq_group in self.running
-            if seq_group.lora_int_id > 0) if self.lora_enabled else None)
-
-        prefills = SchedulerPrefillOutputs.create_empty()
-        running_scheduled = SchedulerRunningOutputs.create_empty()
-        swapped_in = SchedulerSwappedInOutputs.create_empty()
-
-        # If any requests are swapped, prioritize swapped requests.
-        if not self.swapped:
-            prefills = self._schedule_prefills(budget,
-                                               curr_loras,
-                                               enable_chunking=False)
-
-        if len(prefills.seq_groups
-               ) == 0 and self.scheduler_config.policy == "priority":
-            self._schedule_priority_preemption(budget)
-
-        # Don't schedule decodes if prefills are scheduled.
-        # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
-        # only contains decode requests, not chunked prefills.
-        if len(prefills.seq_groups) == 0:
-            running_scheduled = self._schedule_running(budget,
-                                                       curr_loras,
-                                                       enable_chunking=False)
-
-            # If any sequence group is preempted, do not swap in any sequence
-            # group, because it means there's no slot for new running requests.
-            if (len(running_scheduled.preempted) +
-                    len(running_scheduled.swapped_out) == 0):
-                swapped_in = \
-                    self._schedule_swapped(budget, curr_loras)
-
-        assert (budget.num_batched_tokens
-                <= self.scheduler_config.max_num_batched_tokens)
-        assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
-
-        # Update waiting requests.
-        self.waiting.extendleft(running_scheduled.preempted)
-        # Update new running requests.
-        if len(prefills.seq_groups) > 0:
-            self.running.extend([s.seq_group for s in prefills.seq_groups])
-
-        self.running.extend(running_scheduled.decode_seq_groups_list)
-
-        if len(swapped_in.decode_seq_groups) > 0:
-            self.running.extend(
-                [s.seq_group for s in swapped_in.decode_seq_groups])
-
-        # Update swapped requests.
-        self.swapped.extend(running_scheduled.swapped_out)
-        preempted = len(running_scheduled.preempted) + len(
-            running_scheduled.swapped_out)
-
-        # There should be no prefill from the running queue because this
-        # policy doesn't allow chunked prefills.
-        assert len(running_scheduled.prefill_seq_groups) == 0
-        assert len(swapped_in.prefill_seq_groups) == 0
-
-        # Merge lists
-        num_prefill_groups = len(prefills.seq_groups)
-        ignored_seq_groups_for_embeds = list[SequenceGroup]()
-        if num_prefill_groups > 0:
-            scheduled_seq_groups = prefills.seq_groups
-            scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
-            ignored_seq_groups_for_embeds.clear()
-        else:
-            scheduled_seq_groups = running_scheduled.decode_seq_groups
-            if len(scheduled_seq_groups) > 0:
-                using_prompt_embeds = scheduled_seq_groups[
-                    0].seq_group.uses_prompt_embeds()
-                ignored_seq_groups_for_embeds.clear()
-                indices_ignored = list[int]()
-                for i, schedule_seq_group in enumerate(scheduled_seq_groups):
-                    if using_prompt_embeds != \
-                            schedule_seq_group.seq_group.uses_prompt_embeds():
-                        ignored_seq_groups_for_embeds.append(
-                            schedule_seq_group.seq_group)
-                        indices_ignored.append(i)
-                if len(ignored_seq_groups_for_embeds) > 0:
-                    scheduled_seq_groups = [
-                        group for i, group in enumerate(scheduled_seq_groups)
-                        if i not in indices_ignored
-                    ]
-            else:
-                ignored_seq_groups_for_embeds.clear()
-
-        scheduled_seq_groups.extend(swapped_in.decode_seq_groups)
-
-        blocks_to_copy = running_scheduled.blocks_to_copy
-        blocks_to_copy.extend(swapped_in.blocks_to_copy)
-
-        ignored_seq_groups = prefills.ignored_seq_groups
-        ignored_seq_groups.extend(ignored_seq_groups_for_embeds)
-        ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)
-
-        return SchedulerOutputs(
-            scheduled_seq_groups=scheduled_seq_groups,
-            num_prefill_groups=num_prefill_groups,
-            num_batched_tokens=budget.num_batched_tokens +
-            budget.num_cached_tokens,
-            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
-            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-            ignored_seq_groups=ignored_seq_groups,
-            num_lookahead_slots=running_scheduled.num_lookahead_slots,
-            running_queue_size=len(self.running),
-            preempted=preempted,
-        )
-
-    def _schedule_chunked_prefill(self) -> SchedulerOutputs:
-        """Schedule queued requests.
-
-        Chunked prefill allows prefill requests to be chunked and batched
-        together with decode requests. This policy 1. schedules as many
-        decode requests as possible. 2. schedules chunked prefill requests
-        that are not finished. 3. schedules swapped requests. 4. schedules
-        new prefill requests.
-
-        The policy can sustain high GPU utilization because it can put
-        prefill and decode requests into the same batch, while it improves
-        inter-token latency because decode requests don't need to be blocked
-        by prefill requests.
-        """
-        budget = SchedulingBudget(
-            token_budget=self.scheduler_config.max_num_batched_tokens,
-            max_num_seqs=self.scheduler_config.max_num_seqs,
-        )
-        curr_loras: Set[int] = set()
-
-        prefills = SchedulerPrefillOutputs.create_empty()
-        swapped_in = SchedulerSwappedInOutputs.create_empty()
-
-        # Create partial prefill metadata
-        partial_prefill_metadata = PartialPrefillMetadata.from_queues(
-            running=self.running,
-            waiting=self.waiting,
-            scheduler_config=self.scheduler_config,
-        )
-
-        # Decoding should always be scheduled first by FCFS.
-        running_scheduled = self._schedule_running(
-            budget,
-            curr_loras,
-            enable_chunking=True,
-            partial_prefill_metadata=partial_prefill_metadata,
-        )
-
-        # Schedule swapped out requests.
-        # If preemption happens, it means we don't have space for swap-in.
-        if len(running_scheduled.preempted) + len(
-                running_scheduled.swapped_out) == 0:
-            swapped_in = self._schedule_swapped(budget, curr_loras)
-
-        prefills = self._schedule_prefills(
-            budget,
-            curr_loras,
-            enable_chunking=True,
-            partial_prefill_metadata=partial_prefill_metadata,
-        )
-
-        assert (budget.num_batched_tokens
-                <= self.scheduler_config.max_num_batched_tokens)
-        assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
-
-        # Update waiting requests.
-        self.waiting.extendleft(running_scheduled.preempted)
-
-        # Update new running requests.
-        # By default, the vLLM scheduler prioritizes prefills.
-        # Once chunked prefill is enabled,
-        # the policy is changed to prioritize decode requests.
-        self.running.extend(
-            [s.seq_group for s in swapped_in.decode_seq_groups])
-        self.running.extend(
-            [s.seq_group for s in swapped_in.prefill_seq_groups])
-        self.running.extend(
-            [s.seq_group for s in running_scheduled.decode_seq_groups])
-        # Because multiple prefills may be running concurrently, we need to
-        # make sure that prefills which are scheduled to finish are listed
-        # before those that won't. This is so that on the next scheduling
-        # iteration when they have transitioned to the decode stage, they are
-        # properly prioritized over sequences that are still in the prefill
-        # stage.
-        self.running.extend(
-            self._order_finishing_prefills_first(
-                running_scheduled.prefill_seq_groups))
-        self.running.extend([s.seq_group for s in prefills.seq_groups])
-
-        # Update swapped requests.
-        self.swapped.extend(running_scheduled.swapped_out)
-        # Put prefills first due to Attention backend ordering assumption.
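A worked illustration of why scheduling decodes first helps inter-token latency (all numbers invented): with a 512-token budget and 100 running decode sequences, the decodes are guaranteed their one token each before any prefill chunk is sized.

    token_budget = 512
    num_running_decodes = 100          # scheduled first, 1 token each
    decode_tokens = num_running_decodes * 1
    prefill_tokens = token_budget - decode_tokens
    print(prefill_tokens)  # 412 tokens left for chunked/new prefills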
-        scheduled_seq_groups = (prefills.seq_groups +
-                                running_scheduled.prefill_seq_groups +
-                                swapped_in.prefill_seq_groups +
-                                running_scheduled.decode_seq_groups +
-                                swapped_in.decode_seq_groups)
-        num_prefill_groups = (len(prefills.seq_groups) +
-                              len(swapped_in.prefill_seq_groups) +
-                              len(running_scheduled.prefill_seq_groups))
-        return SchedulerOutputs(
-            scheduled_seq_groups=scheduled_seq_groups,
-            num_prefill_groups=num_prefill_groups,
-            num_batched_tokens=budget.num_batched_tokens +
-            budget.num_cached_tokens,
-            blocks_to_swap_in=swapped_in.blocks_to_swap_in,
-            blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
-            blocks_to_copy=running_scheduled.blocks_to_copy +
-            swapped_in.blocks_to_copy,
-            ignored_seq_groups=prefills.ignored_seq_groups +
-            swapped_in.infeasible_seq_groups,
-            num_lookahead_slots=0,
-            running_queue_size=len(self.running),
-            preempted=(len(running_scheduled.preempted) +
-                       len(running_scheduled.swapped_out)),
-        )
-
-    def _order_finishing_prefills_first(
-        self, scheduled_prefill_seqs: List[ScheduledSequenceGroup]
-    ) -> List[SequenceGroup]:
-        """Returns a list of prefilling SequenceGroups where sequences that are
-        scheduled to finish prefilling are listed first"""
-        finishing = [
-            s.seq_group for s in scheduled_prefill_seqs
-            if s.seq_group.get_num_uncomputed_tokens() == s.token_chunk_size
-        ]
-        not_finishing = [
-            s.seq_group for s in scheduled_prefill_seqs
-            if s.seq_group.get_num_uncomputed_tokens() != s.token_chunk_size
-        ]
-        return finishing + not_finishing
-
-    def _schedule(self) -> SchedulerOutputs:
-        """Schedule queued requests."""
-        if self.scheduler_config.chunked_prefill_enabled:
-            return self._schedule_chunked_prefill()
-        else:
-            return self._schedule_default()
-
-    def _can_append_slots(self, seq_group: SequenceGroup,
-                          enable_chunking: bool) -> bool:
-        """Determine whether or not we have enough space in the KV cache to
-        continue generation of the sequence group.
-        """
-        # It is True only in the testing case, to trigger artificial
-        # preemption.
-        if (self.enable_artificial_preemption
-                and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
-                and self.artificial_preempt_cnt > 0):
-            self.artificial_preempt_cnt -= 1
-            return False
-
-        is_prefill = seq_group.is_prefill()
-        num_lookahead_slots = self._get_num_lookahead_slots(
-            is_prefill, enable_chunking)
-
-        return self.block_manager.can_append_slots(
-            seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
-
-    def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
-        # async_output_proc is allowed only when we have a single sequence
-        # in the sequence group
-        no_single_seq = seq_group.sampling_params is None or (
-            seq_group.sampling_params.n == 1)
-        return no_single_seq
-
-    def schedule(
-        self
-    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
-        # Schedule sequence groups.
-        # This function call changes the internal states of the scheduler
-        # such as self.running, self.swapped, and self.waiting.
-        scheduler_start_time = time.perf_counter()
-
-        scheduler_outputs: SchedulerOutputs = self._schedule()
-        now = time.time()
-
-        if not self.cache_config.enable_prefix_caching:
-            common_computed_block_nums = []
-
-        allow_async_output_proc: bool = self.use_async_output_proc
-
-        # Create input data structures.
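Before moving into the metadata construction, a worked example of `_order_finishing_prefills_first` above (invented numbers): a prefill finishes this step exactly when its scheduled chunk covers all of its remaining uncomputed tokens.

    # (seq id, uncomputed tokens, scheduled chunk size)
    scheduled = [("A", 96, 96), ("B", 512, 96)]
    finishing = [s for s in scheduled if s[1] == s[2]]
    not_finishing = [s for s in scheduled if s[1] != s[2]]
    # A finishes prefill this step, so it is listed first and will be
    # prioritized as a decode on the next iteration.
    print([s[0] for s in finishing + not_finishing])  # ['A', 'B']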
-        seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        for i, scheduled_seq_group in enumerate(
-                scheduler_outputs.scheduled_seq_groups):
-            seq_group = scheduled_seq_group.seq_group
-            token_chunk_size = scheduled_seq_group.token_chunk_size
-            seq_group.maybe_set_first_scheduled_time(now)
-
-            seq_group_metadata = self._seq_group_metadata_cache[
-                self.cache_id].get_object()
-            seq_group_metadata.seq_data.clear()
-            seq_group_metadata.block_tables.clear()
-
-            # seq_id -> SequenceData
-            seq_data: Dict[int, SequenceData] = {}
-            # seq_id -> physical block numbers
-            block_tables: Dict[int, List[int]] = {}
-
-            if seq_group.is_encoder_decoder():
-                # Encoder associated with SequenceGroup
-                encoder_seq = seq_group.get_encoder_seq()
-                assert encoder_seq is not None
-                encoder_seq_data = encoder_seq.data
-                # Block table for cross-attention
-                # Also managed at SequenceGroup level
-                cross_block_table = self.block_manager.get_cross_block_table(
-                    seq_group)
-            else:
-                encoder_seq_data = None
-                cross_block_table = None
-
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                seq_id = seq.seq_id
-                seq_data[seq_id] = seq.data
-                block_tables[seq_id] = self.block_manager.get_block_table(seq)
-                self.block_manager.access_all_blocks_in_seq(seq, now)
-
-            if self.cache_config.enable_prefix_caching:
-                common_computed_block_nums = (
-                    self.block_manager.get_common_computed_block_ids(
-                        seq_group.get_seqs(status=SequenceStatus.RUNNING)))
-
-            do_sample = True
-            is_prompt = seq_group.is_prefill()
-            # We should send the metadata to workers when the first prefill
-            # is sent. Subsequent requests could be chunked prefill or decode.
-            is_first_prefill = False
-            if is_prompt:
-                seqs = seq_group.get_seqs()
-                # Prefill has only 1 sequence.
-                assert len(seqs) == 1
-                num_computed_tokens = seqs[0].data.get_num_computed_tokens()
-                is_first_prefill = num_computed_tokens == 0
-                # If not all prompt tokens will have been computed after this
-                # iteration, the prefill is chunked and we don't need sampling.
-                # NOTE: We use get_len instead of get_prompt_len because when
-                # a sequence is preempted, prefill includes previous generated
-                # output tokens.
-                if (token_chunk_size + num_computed_tokens
-                        < seqs[0].data.get_len()):
-                    do_sample = False
-
-            # It assumes that scheduled_seq_groups is ordered by
-            # prefill < decoding.
-            if is_first_prefill or not self.scheduler_config.send_delta_data:
-                seq_group_metadata = SequenceGroupMetadata(
-                    request_id=seq_group.request_id,
-                    is_prompt=is_prompt,
-                    seq_data=seq_data,
-                    sampling_params=seq_group.sampling_params,
-                    block_tables=block_tables,
-                    do_sample=do_sample,
-                    pooling_params=seq_group.pooling_params,
-                    token_chunk_size=token_chunk_size,
-                    lora_request=seq_group.lora_request,
-                    computed_block_nums=common_computed_block_nums,
-                    encoder_seq_data=encoder_seq_data,
-                    cross_block_table=cross_block_table,
-                    state=seq_group.state,
-                    # `multi_modal_data` will only be present for the 1st comm
-                    # between engine and worker.
-                    # the subsequent comms can still use delta, but
-                    # `multi_modal_data` will be None.
-                    multi_modal_data=(seq_group.multi_modal_data
-                                      if scheduler_outputs.num_prefill_groups
-                                      > 0 else None),
-                    multi_modal_placeholders=(
-                        seq_group.multi_modal_placeholders
-                        if scheduler_outputs.num_prefill_groups > 0 else None),
-                )
-            else:
-                # When SPMD mode is enabled, we only send delta data except for
-                # the first request to reduce serialization cost.
-                seq_data_delta = {}
-                for id, data in seq_data.items():
-                    seq_data_delta[id] = data.get_delta_and_reset()
-                seq_group_metadata = SequenceGroupMetadataDelta(
-                    seq_data_delta,
-                    seq_group.request_id,
-                    block_tables,
-                    is_prompt,
-                    do_sample=do_sample,
-                    token_chunk_size=token_chunk_size,
-                    computed_block_nums=common_computed_block_nums,
-                )
-            seq_group_metadata_list.append(seq_group_metadata)
-
-            if allow_async_output_proc:
-                allow_async_output_proc = self._allow_async_output_proc(
-                    seq_group)
-
-        # Now that the batch has been created, we can assume all blocks in the
-        # batch will have been computed before the next scheduling invocation.
-        # This is because the engine assumes that a failure in model execution
-        # will crash the vLLM instance / will not retry.
-        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
-            self.block_manager.mark_blocks_as_computed(
-                scheduled_seq_group.seq_group,
-                scheduled_seq_group.token_chunk_size)
-
-        self._seq_group_metadata_cache[self.next_cache_id].reset()
-
-        scheduler_time = time.perf_counter() - scheduler_start_time
-        # Add this scheduler time to all the sequences that are currently
-        # running. This will help estimate if the scheduler is a significant
-        # component in the e2e latency.
-        for seq_group in self.running:
-            if seq_group is not None and seq_group.metrics is not None:
-                if seq_group.metrics.scheduler_time is not None:
-                    seq_group.metrics.scheduler_time += scheduler_time
-                else:
-                    seq_group.metrics.scheduler_time = scheduler_time
-
-        # Move to next cache (if exists)
-        self.cache_id = self.next_cache_id
-
-        # Return results
-        return (seq_group_metadata_list, scheduler_outputs,
-                allow_async_output_proc)
-
-    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
-        self.block_manager.fork(parent_seq, child_seq)
-
-    def free_seq(self, seq: Sequence) -> None:
-        """Free a sequence from a block table."""
-        self.block_manager.free(seq)
-
-    def remove_seq_from_computed_blocks_tracker(
-            self, seq_group: SequenceGroup,
-            status: Optional[SequenceStatus]) -> None:
-        seqs = seq_group.get_seqs(status=status)
-        for seq in seqs:
-            self._remove_seq_from_computed_blocks_tracker(seq)
-
-    def _remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
-        """
-        Free a sequence from the computed-blocks tracker's
-        _seq_id_to_blocks_hashes and _seq_id_to_num_tokens_computed.
-        """
-        self.block_manager.remove_seq_from_computed_blocks_tracker(seq)
-
-    def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
-        """Free finished seqs in a sequence group."""
-        for seq in seq_group.get_seqs():
-            if seq.is_finished():
-                self.free_seq(seq)
-
-    def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None:
-        if seq_group.is_finished():
-            # Free the cross-attention block table, if it exists.
-            self._free_seq_group_cross_attn_blocks(seq_group)
-
-            # Add the finished requests to the finished requests list.
-            # This list will be used to update the Mamba cache in the
-            # next step.
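The `_finished_requests_ids` list written here pairs with `get_and_reset_finished_requests_ids()` defined earlier: ids accumulate as requests finish, and a consumer (such as a model runner maintaining Mamba state) drains the list once per step. A minimal sketch of that flush handshake, with a hypothetical request id:

    finished_requests_ids: list[str] = []

    def get_and_reset_finished_requests_ids() -> list[str]:
        # Hand the accumulated ids to the caller and start a fresh list.
        global finished_requests_ids
        flushed, finished_requests_ids = finished_requests_ids, []
        return flushed

    finished_requests_ids.append("req-42")        # hypothetical id
    print(get_and_reset_finished_requests_ids())  # ['req-42']
    print(finished_requests_ids)                  # []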
-            self._finished_requests_ids.append(seq_group.request_id)
-
-        # Free finished seqs
-        self._free_finished_seqs(seq_group)
-
-    def free_finished_seq_groups(self) -> None:
-        remaining: Deque[SequenceGroup] = deque()
-        for seq_group in self.running:
-            self._free_finished_seq_group(seq_group)
-            if not seq_group.is_finished():
-                remaining.append(seq_group)
-
-        self.running = remaining
-
-        # Handle async stopped sequence groups
-        # (ones that reached max model len)
-        if self._async_stopped:
-            for seq_group in self._async_stopped:
-                self._free_seq_group_cross_attn_blocks(seq_group)
-                self._finished_requests_ids.append(seq_group.request_id)
-
-                # Free finished seqs
-                self._free_finished_seqs(seq_group)
-
-            self._async_stopped.clear()
-
-    def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
-        self.block_manager.allocate(seq_group)
-        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
-            seq.status = SequenceStatus.RUNNING
-
-    def _append_slots(
-        self,
-        seq_group: SequenceGroup,
-        blocks_to_copy: List[Tuple[int, int]],
-        enable_chunking: bool = False,
-    ) -> None:
-        """Appends new slots to the sequences in the given sequence group.
-
-        Args:
-            seq_group (SequenceGroup): The sequence group containing the
-                sequences to append slots to.
-            blocks_to_copy (List[Tuple[int, int]]): A list of tuples of two
-                ints, the first int is the source block index, and the second
-                int is the destination block index. This list is updated with
-                the new source and destination block indices for the appended
-                slots.
-            enable_chunking (bool): True if chunked prefill is enabled.
-        """
-        is_prefill: bool = seq_group.is_prefill()
-        num_lookahead_slots: int = self._get_num_lookahead_slots(
-            is_prefill, enable_chunking)
-
-        seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING
-        for seq in seq_group.get_seqs(status=seq_status):
-            cows = self.block_manager.append_slots(seq, num_lookahead_slots)
-            if len(cows) > 0:
-                blocks_to_copy.extend(cows)
-
-    def _preempt(self, seq_group: SequenceGroup,
-                 blocks_to_swap_out: List[Tuple[int, int]]) -> PreemptionMode:
-        # If preemption mode is not specified, we determine the mode as follows:
-        # We use recomputation by default since it incurs lower overhead than
-        # swapping. However, when the sequence group has multiple sequences
-        # (e.g., beam search), recomputation is not currently supported. In
-        # such a case, we use swapping instead.
-        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
-        # As swapped sequences are prioritized over waiting sequences,
-        # sequence groups with multiple sequences are implicitly prioritized
-        # over sequence groups with a single sequence.
-        # TODO(woosuk): Support recomputation for sequence groups with multiple
-        # sequences. This may require a more sophisticated CUDA kernel.
-        if self.user_specified_preemption_mode is None:
-            if seq_group.get_max_num_running_seqs() == 1:
-                preemption_mode = PreemptionMode.RECOMPUTE
-            else:
-                preemption_mode = PreemptionMode.SWAP
-
-        elif self.user_specified_preemption_mode == "swap":
-            preemption_mode = PreemptionMode.SWAP
-        else:
-            preemption_mode = PreemptionMode.RECOMPUTE
-
-        if self.num_cumulative_preemption % 50 == 0:
-            logger.warning(
-                "Sequence group %s is preempted by %s mode because there is "
-                "not enough KV cache space. This can affect the end-to-end "
-                "performance. Increase gpu_memory_utilization or "
-                "tensor_parallel_size to provide more KV cache memory. "
-                "total_num_cumulative_preemption=%d",
-                seq_group.request_id,
-                preemption_mode,
-                self.num_cumulative_preemption + 1,
-            )
-        self.num_cumulative_preemption += 1
-
-        if preemption_mode == PreemptionMode.RECOMPUTE:
-            self._preempt_by_recompute(seq_group)
-        elif preemption_mode == PreemptionMode.SWAP:
-            self._preempt_by_swap(seq_group, blocks_to_swap_out)
-        else:
-            raise AssertionError("Invalid preemption mode.")
-        return preemption_mode
-
-    def _preempt_by_recompute(
-        self,
-        seq_group: SequenceGroup,
-    ) -> None:
-        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
-        assert len(seqs) == 1
-        for seq in seqs:
-            seq.status = SequenceStatus.WAITING
-            self.free_seq(seq)
-            seq.reset_state_for_recompute()
-        self._free_seq_group_cross_attn_blocks(seq_group)
-
-    def _preempt_by_swap(
-        self,
-        seq_group: SequenceGroup,
-        blocks_to_swap_out: List[Tuple[int, int]],
-    ) -> None:
-        self._swap_out(seq_group, blocks_to_swap_out)
-
-    def _swap_in(
-        self,
-        seq_group: SequenceGroup,
-        blocks_to_swap_in: List[Tuple[int, int]],
-    ) -> None:
-        mapping = self.block_manager.swap_in(seq_group)
-        blocks_to_swap_in.extend(mapping)
-        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
-            seq.status = SequenceStatus.RUNNING
-
-    def _swap_out(
-        self,
-        seq_group: SequenceGroup,
-        blocks_to_swap_out: List[Tuple[int, int]],
-    ) -> None:
-        if not self.block_manager.can_swap_out(seq_group):
-            # FIXME(woosuk): Abort the sequence group instead of aborting the
-            # entire engine.
-            raise RuntimeError(
-                "Aborted due to the lack of CPU swap space. Please increase "
-                "the swap space to avoid this error.")
-        mapping = self.block_manager.swap_out(seq_group)
-        blocks_to_swap_out.extend(mapping)
-        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-            seq.status = SequenceStatus.SWAPPED
-
-    def _passed_delay(self, now: float) -> bool:
-        if self.prev_prompt:
-            self.last_prompt_latency = now - self.prev_time
-        self.prev_time, self.prev_prompt = now, False
-        # Delay scheduling prompts to let the waiting queue fill up
-        if self.scheduler_config.delay_factor > 0 and self.waiting:
-            earliest_arrival_time = min(
-                [e.metrics.arrival_time for e in self.waiting])
-            passed_delay = ((now - earliest_arrival_time)
-                            > (self.scheduler_config.delay_factor *
-                               self.last_prompt_latency) or not self.running)
-        else:
-            passed_delay = True
-        return passed_delay
-
-    def _get_num_lookahead_slots(self, is_prefill: bool,
-                                 enable_chunking: bool) -> int:
-        """The number of slots to allocate per sequence per step, beyond known
-        token ids. Speculative decoding uses these slots to store KV activations
-        of tokens which may or may not be accepted.
-        """
-        return 0
-
-    def _get_num_new_uncached_and_cached_tokens(
-        self,
-        seq_group: SequenceGroup,
-        status: SequenceStatus,
-        enable_chunking: bool,
-        budget: SchedulingBudget,
-        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
-    ) -> Tuple[int, int]:
-        """
-        Returns the number of new uncached and cached tokens to schedule for a
-        given sequence group that's in a given `status`.
-
-        The API could chunk the number of tokens to compute based on `budget`
-        if `enable_chunking` is True. If a sequence group has multiple
-        sequences (e.g., running beam search), it means it is in the decoding
-        phase, so chunking doesn't happen.
-
-        Returns (0, 0) if the new token cannot be computed due to the token
-        budget.
-
-        The cached tokens' blocks are already computed, and the attention
-        backend will reuse the cached blocks rather than recomputing them. So
-        the scheduler could schedule these cached tokens "for free".
-
-        Args:
-            seq_group: The sequence group to get the number of new tokens to
-                schedule.
-            status: The status of the sequences to get the number of new tokens
-                to schedule.
-            enable_chunking: Whether to chunk the number of tokens to compute.
-            budget: The budget to chunk the number of tokens to compute.
-            partial_prefill_metadata: information about the partial prefills
-                that are currently running
-
-        Returns:
-            A tuple of two ints. The first int is the number of new uncached
-            tokens to schedule. The second int is the number of cached tokens.
-            If no more new tokens can be scheduled, returns (0, 0).
-        """
-        num_cached_new_tokens = 0
-        num_uncached_new_tokens = 0
-
-        seqs = seq_group.get_seqs(status=status)
-        # Compute the number of new uncached and cached tokens for
-        # each sequence.
-        for seq in seqs:
-            if not seq.is_prefill():
-                # Decode sequences should always just have 1 uncached token
-                # TODO(rickyx): Actually is this still correct for multi-step?
-                num_uncached_new_tokens += 1
-                continue
-
-            num_computed_tokens_seq = seq.get_num_computed_tokens()
-            all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq
-            if not self.cache_config.enable_prefix_caching:
-                # If prefix caching is not enabled, all new tokens are uncached.
-                num_uncached_new_tokens += all_num_new_tokens_seq
-                continue
-
-            # NOTE: the cached token might currently be in a block that's in an
-            # evictor, meaning that it's not yet allocated. However, we don't
-            # exclude such tokens in the cache count because it will be
-            # guaranteed to be allocated later if the sequence can be allocated.
-            num_cached_tokens_seq = self.block_manager.get_num_cached_tokens(
-                seq)
-
-            # Sanity check.
-            if num_cached_tokens_seq < num_computed_tokens_seq:
-                # This should only happen with chunked prefill, and
-                # the seq is still in prefill. The `num_cached_tokens_seq`
-                # is the value we calculated on scheduling the first prefill.
-                # For subsequent continuous prefill steps, we cached the
-                # number of cached tokens for the sequence so the cached token
-                # count could be less than the number of computed tokens.
-                # See comments on `ComputedBlocksTracker` for more details.
-                assert (
-                    seq.is_prefill() and seq.status == SequenceStatus.RUNNING
-                    and self.scheduler_config.chunked_prefill_enabled
-                ), ("Number of cached tokens should not be less than the "
-                    "number of computed tokens for a sequence that's still "
-                    f"in prefill. But there are {num_cached_tokens_seq} cached "
-                    f"tokens and {num_computed_tokens_seq} computed tokens "
-                    f"for sequence {seq.seq_id}.")
-
-            num_cached_new_tokens_seq = max(
-                0, num_cached_tokens_seq - num_computed_tokens_seq)
-            num_uncached_new_tokens_seq = (all_num_new_tokens_seq -
-                                           num_cached_new_tokens_seq)
-
-            num_uncached_new_tokens += num_uncached_new_tokens_seq
-            num_cached_new_tokens += num_cached_new_tokens_seq
-
-        if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0:
-            # For a fully cached hit sequence, we actually need to recompute the
-            # last token. So we need at least 1 uncached token to schedule.
-            # See ModelRunner._compute_for_prefix_cache_hit for more details.
-            num_uncached_new_tokens = 1
-            num_cached_new_tokens -= 1
-
-        if enable_chunking and len(seqs) == 1:
-            # Chunk if a running request cannot fit in the given budget.
-            # If the number of seqs > 1, it means it is doing beam search
-            # in a decode phase. Do not chunk.
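To make the accounting above concrete, a standalone sketch of the uncached/cached split with invented numbers. With prefix caching on, a sequence of length 48 with 16 tokens computed and 32 tokens cached schedules 16 uncached plus 16 "free" cached tokens, and a full cache hit still pays for one uncached token:

    def split_new_tokens(seq_len: int, num_computed: int,
                         num_cached: int) -> tuple[int, int]:
        all_new = seq_len - num_computed
        cached_new = max(0, num_cached - num_computed)
        uncached_new = all_new - cached_new
        if uncached_new == 0 and cached_new > 0:
            # A fully cached sequence still recomputes its last token.
            uncached_new, cached_new = 1, cached_new - 1
        return uncached_new, cached_new

    print(split_new_tokens(48, 16, 32))  # (16, 16)
    print(split_new_tokens(32, 0, 32))   # (1, 31)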
-            num_uncached_new_tokens = self._chunk_new_tokens_to_schedule(
-                self.scheduler_config,
-                self.cache_config,
-                budget,
-                self._get_prompt_limit(seq_group),
-                num_uncached_new_tokens,
-                self.partial_prefill_budget_lookup_list,
-                partial_prefill_metadata,
-            )
-
-        return num_uncached_new_tokens, num_cached_new_tokens
-
-    @staticmethod
-    def _chunk_new_tokens_to_schedule(
-        scheduler_config: SchedulerConfig,
-        cache_config: CacheConfig,
-        budget: SchedulingBudget,
-        prompt_limit: int,
-        num_new_tokens: int,
-        partial_prefill_budget_lookup_list: List[int],
-        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
-    ) -> int:
-        """
-        Chunks the number of new tokens to schedule based on the budget when
-        chunked prefill is enabled.
-
-        Args:
-            scheduler_config: The scheduler config.
-            cache_config: The cache config.
-            budget: The budget to chunk the number of tokens to compute.
-            prompt_limit: The maximum number of tokens allowed in a prompt.
-            num_new_tokens: The number of new tokens to schedule.
-
-        Returns:
-            The number of new tokens to schedule after chunking.
-        """
-        remaining_token_budget = budget.remaining_token_budget()
-
-        # Get the number of tokens to allocate to this prefill slot
-        prefill_slot_budget = (
-            remaining_token_budget if partial_prefill_metadata is None else
-            partial_prefill_budget_lookup_list[
-                partial_prefill_metadata.schedulable_prefills])
-
-        if cache_config.enable_prefix_caching:
-            # When prefix caching is enabled and we're partially prefilling
-            # a sequence, we always allocate a number of new tokens that is
-            # divisible by the block size to avoid partial block matching.
-            block_size = cache_config.block_size
-            # Don't exceed either the total budget or slot budget.
-            # Take min of those and get the next lowest multiple of the
-            # block size:
-            remaining_token_budget = (
-                min(remaining_token_budget, prefill_slot_budget) //
-                block_size) * block_size
-            # NB: In the case where num_new_tokens < budget, we are
-            # finishing prefill for this sequence, so we do not need to
-            # allocate a full block.
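A worked example of the block-size rounding above, with invented numbers: a 500-token remaining budget, a 512-token prefill slot, and 16-token blocks round down to 496, so a partially prefilled sequence always ends on a block boundary.

    remaining_token_budget = 500
    prefill_slot_budget = 512
    block_size = 16

    rounded = (min(remaining_token_budget, prefill_slot_budget)
               // block_size) * block_size
    num_new_tokens = min(768, rounded, prefill_slot_budget)
    print(rounded, num_new_tokens)  # 496 496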
-
-        num_new_tokens = min(num_new_tokens, remaining_token_budget,
-                             prefill_slot_budget)
-
-        return num_new_tokens
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 7a4bb0d41d236..7e00260caa39a 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -41,7 +41,8 @@ from vllm.plugins import load_general_plugins
 from vllm.ray.lazy_utils import is_ray_initialized
 from vllm.reasoning import ReasoningParserManager
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
-from vllm.transformers_utils.config import get_model_path, is_interleaved
+from vllm.transformers_utils.config import (get_model_path, is_interleaved,
+                                            maybe_override_with_speculators)
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
                         GiB_bytes, get_ip, is_in_ray_actor)
@@ -409,9 +410,7 @@ class EngineArgs:
         get_field(LoadConfig, "model_loader_extra_config")
     ignore_patterns: Optional[Union[str, List[str]]] = \
         LoadConfig.ignore_patterns
-    preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
-    scheduler_delay_factor: float = SchedulerConfig.delay_factor
     enable_chunked_prefill: Optional[
         bool] = SchedulerConfig.enable_chunked_prefill
     disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
@@ -439,7 +438,6 @@ class EngineArgs:
         ObservabilityConfig.otlp_traces_endpoint
     collect_detailed_traces: Optional[list[DetailedTraceModules]] = \
         ObservabilityConfig.collect_detailed_traces
-    disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
     scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
     scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
@@ -561,14 +559,6 @@ class EngineArgs:
                                  **model_kwargs["enable_prompt_embeds"])
         model_group.add_argument("--served-model-name",
                                  **model_kwargs["served_model_name"])
-        # This one is a special case because it is the
-        # opposite of ModelConfig.use_async_output_proc
-        model_group.add_argument(
-            "--disable-async-output-proc",
-            action="store_true",
-            default=EngineArgs.disable_async_output_proc,
-            help="Disable async output processing. This may result in "
-            "lower performance.")
         model_group.add_argument("--config-format",
                                  **model_kwargs["config_format"])
         # This one is a special case because it can bool
@@ -897,10 +887,6 @@ class EngineArgs:
                                      **scheduler_kwargs["long_prefill_token_threshold"])
         scheduler_group.add_argument("--num-lookahead-slots",
                                      **scheduler_kwargs["num_lookahead_slots"])
-        scheduler_group.add_argument("--scheduler-delay-factor",
-                                     **scheduler_kwargs["delay_factor"])
-        scheduler_group.add_argument("--preemption-mode",
-                                     **scheduler_kwargs["preemption_mode"])
         # multi-step scheduling has been removed; corresponding arguments
         # are no longer supported.
         scheduler_group.add_argument("--scheduling-policy",
@@ -1029,7 +1015,6 @@ class EngineArgs:
             interleave_mm_strings=self.interleave_mm_strings,
             media_io_kwargs=self.media_io_kwargs,
             skip_mm_profiling=self.skip_mm_profiling,
-            use_async_output_proc=not self.disable_async_output_proc,
             config_format=self.config_format,
             mm_processor_kwargs=self.mm_processor_kwargs,
             mm_processor_cache_gb=self.mm_processor_cache_gb,
@@ -1098,29 +1083,8 @@ class EngineArgs:
         provided as a JSON string input via CLI arguments or directly as a
         dictionary from the engine.
         """
-
-        from vllm.transformers_utils.config import get_config
-        from vllm.transformers_utils.configs.speculators.base import (
-            SpeculatorsConfig)
-
-        if self.speculative_config is None:
-            hf_config = get_config(
-                self.hf_config_path or target_model_config.model,
-                self.trust_remote_code, self.revision, self.code_revision,
-                self.config_format)
-
-            # if loading a SpeculatorsConfig, load the speculative_config
-            # details from the config directly
-            # no user input required / expected
-            if isinstance(hf_config, SpeculatorsConfig):
-                # We create an empty one since the user did not provide one.
-                self.speculative_config = {}
-                self.speculative_config[
-                    "num_speculative_tokens"] = hf_config.num_lookahead_tokens
-                self.speculative_config["model"] = target_model_config.model
-                self.speculative_config["method"] = hf_config.method
-            else:
-                return None
+        return None
 
         # Note(Shangming): These parameters are not obtained from the cli arg
         # '--speculative-config' and must be passed in when creating the engine
@@ -1155,6 +1119,15 @@ class EngineArgs:
 
         device_config = DeviceConfig(
             device=cast(Device, current_platform.device_type))
+
+        (self.model, self.tokenizer,
+         self.speculative_config) = maybe_override_with_speculators(
+             model=self.model,
+             tokenizer=self.tokenizer,
+             revision=self.revision,
+             trust_remote_code=self.trust_remote_code,
+             vllm_speculative_config=self.speculative_config,
+         )
         model_config = self.create_model_config()
 
         # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
@@ -1395,11 +1368,9 @@ class EngineArgs:
             max_model_len=model_config.max_model_len,
             cuda_graph_sizes=self.cuda_graph_sizes,
             num_lookahead_slots=num_lookahead_slots,
-            delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
             disable_chunked_mm_input=self.disable_chunked_mm_input,
             is_multimodal_model=model_config.is_multimodal_model,
-            preemption_mode=self.preemption_mode,
             send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                              and parallel_config.use_ray),
             policy=self.scheduling_policy,
@@ -1486,42 +1457,12 @@ class EngineArgs:
 
         #############################################################
         # Unsupported Feature Flags on V1.
-        if self.load_format == "sharded_state":
-            _raise_or_fallback(
-                feature_name=f"--load_format {self.load_format}",
-                recommend_to_remove=False)
-            return False
-
         if (self.logits_processor_pattern
                 != EngineArgs.logits_processor_pattern):
             _raise_or_fallback(feature_name="--logits-processor-pattern",
                                recommend_to_remove=False)
             return False
-
-        if self.preemption_mode != SchedulerConfig.preemption_mode:
-            _raise_or_fallback(feature_name="--preemption-mode",
-                               recommend_to_remove=True)
-            return False
-
-        if (self.disable_async_output_proc
                != EngineArgs.disable_async_output_proc):
-            _raise_or_fallback(feature_name="--disable-async-output-proc",
-                               recommend_to_remove=True)
-            return False
-
-        if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
-            _raise_or_fallback(feature_name="--scheduler-delay-factor",
-                               recommend_to_remove=True)
-            return False
-
-        if self.kv_cache_dtype != "auto":
-            supported = current_platform.is_kv_cache_dtype_supported(
-                self.kv_cache_dtype, model_config)
-            if not supported:
-                _raise_or_fallback(feature_name="--kv-cache-dtype",
-                                   recommend_to_remove=False)
-                return False
-
-        # No Mamba or Encoder-Decoder so far.
         if not model_config.is_v1_compatible:
             _raise_or_fallback(feature_name=model_config.architectures,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 014bc56bc8ece..a0fe38eb320d6 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1,1835 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import time
-from collections import Counter as collectionsCounter
-from collections import deque
-from contextlib import contextmanager
-from dataclasses import dataclass
-from functools import partial
-from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
-                    Iterable, List, Literal, Mapping, NamedTuple, Optional)
-from typing import Sequence as GenericSequence
-from typing import Set, Type, Union, cast
+from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
-import torch
-import torch.nn as nn
-from typing_extensions import TypeVar
-
-import vllm.envs as envs
-from vllm.config import (LoRAConfig, ModelConfig, ObservabilityConfig,
-                         ParallelConfig, SchedulerConfig, VllmConfig)
-from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
-from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.metrics_types import StatLoggerBase, Stats
-from vllm.engine.output_processor.interfaces import (
-    SequenceGroupOutputProcessor)
-from vllm.engine.output_processor.stop_checker import StopChecker
-from vllm.entrypoints.openai.logits_processors import (
-    get_logits_processors as get_openai_logits_processors)
-from vllm.executor.executor_base import ExecutorBase
-from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
-from vllm.inputs.parse import split_enc_dec_inputs
-from vllm.inputs.preprocess import InputPreprocessor
-from vllm.logger import init_logger
-from vllm.logits_process import get_bad_words_logits_processors
-from vllm.lora.request import LoRARequest
-from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.multimodal.cache import processor_only_cache_from_config
-from vllm.multimodal.processing import EncDecMultiModalProcessor
-from vllm.outputs import (PoolingRequestOutput, RequestOutput,
-                          RequestOutputFactory)
-from vllm.reasoning import ReasoningParser, ReasoningParserManager
-from vllm.sampling_params import RequestOutputKind, SamplingParams
-from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
-                           Sequence, SequenceGroup, SequenceGroupBase,
-                           SequenceGroupMetadata, SequenceGroupOutput,
-                           SequenceStatus)
-from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
-                          init_tracer)
-from vllm.transformers_utils.detokenizer import Detokenizer
-from vllm.transformers_utils.tokenizer import (AnyTokenizer,
-                                               init_tokenizer_from_configs)
-from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
-                                  usage_message)
-from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
-from vllm.version import __version__ as VLLM_VERSION
-from vllm.worker.model_runner_base import InputProcessingError
-from vllm.worker.worker_base import WorkerBase
-
-logger = init_logger(__name__)
-_LOCAL_LOGGING_INTERVAL_SEC = 5
-
-_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
-_R = TypeVar("_R", default=Any)
-
-
-@dataclass
-class SchedulerOutputState:
-    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
-    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
-    scheduler_outputs: Optional[SchedulerOutputs] = None
-    allow_async_output_proc: bool = False
-    last_output: Optional[SamplerOutput] = None
-
-
-class OutputData(NamedTuple):
-    outputs: List[SamplerOutput]
-    seq_group_metadata_list: List[SequenceGroupMetadata]
-    scheduler_outputs: SchedulerOutputs
-    is_async: bool
-    is_last_step: bool
-    # Indicates if this output is from the first step of the
-    # multi-step. When multi-step is disabled, this is always
-    # set to True.
-    # is_first_step_output is invalid when `outputs` has
-    # outputs from multiple steps.
-    is_first_step_output: Optional[bool]
-    skip: List[int]
-
-
-class SchedulerContext:
-
-    def __init__(self) -> None:
-        self.output_queue: Deque[OutputData] = deque()
-        self.request_outputs: List[RequestOutput] = []
-        self.seq_group_metadata_list: Optional[
-            List[SequenceGroupMetadata]] = None
-        self.scheduler_outputs: Optional[SchedulerOutputs] = None
-
-    def append_output(self, outputs: List[SamplerOutput],
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      scheduler_outputs: SchedulerOutputs, is_async: bool,
-                      is_last_step: bool,
-                      is_first_step_output: Optional[bool]):
-        self.output_queue.append(
-            OutputData(outputs=outputs,
-                       seq_group_metadata_list=seq_group_metadata_list,
-                       scheduler_outputs=scheduler_outputs,
-                       is_async=is_async,
-                       is_last_step=is_last_step,
-                       is_first_step_output=is_first_step_output,
-                       skip=[]))
-
-
-class LLMEngine:
-    """An LLM engine that receives requests and generates texts.
-
-    This is the main class for the vLLM engine. It receives requests
-    from clients and generates texts from the LLM. It includes a tokenizer, a
-    language model (possibly distributed across multiple GPUs), and GPU memory
-    space allocated for intermediate states (aka KV cache). This class utilizes
-    iteration-level scheduling and efficient memory management to maximize the
-    serving throughput.
-
-    The [`LLM`][vllm.LLM] class wraps this class for offline batched inference
-    and the [`AsyncLLMEngine`][vllm.engine.async_llm_engine.AsyncLLMEngine]
-    class wraps this class for online serving.
-
-    The config arguments are derived from [`EngineArgs`][vllm.EngineArgs].
-
-    Args:
-        vllm_config: The configuration for initializing and running vLLM.
-        executor_class: The model executor class for managing distributed
-            execution.
-        log_stats: Whether to log statistics.
-        usage_context: Specified entry point, used for usage info collection.
-    """
-
-    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
-    """A flag to toggle whether to validate the type of request output."""
-
-    @classmethod
-    @contextmanager
-    def enable_output_validation(cls):
-        cls.DO_VALIDATE_OUTPUT = True
-
-        yield
-
-        cls.DO_VALIDATE_OUTPUT = False
-
-    @classmethod
-    def validate_output(
-        cls,
-        output: object,
-        output_type: Type[_O],
-    ) -> _O:
-        do_validate = cls.DO_VALIDATE_OUTPUT
-
-        if ((TYPE_CHECKING or do_validate)
-                and not isinstance(output, output_type)):
-            raise TypeError(f"Expected output of type {output_type}, "
-                            f"but found type {type(output)}")
-
-        return cast(_O, output)
-
-    @classmethod
-    def validate_outputs(
-        cls,
-        outputs: GenericSequence[object],
-        output_type: Type[_O],
-    ) -> List[_O]:
-        do_validate = cls.DO_VALIDATE_OUTPUT
-
-        outputs_: List[_O]
-        if TYPE_CHECKING or do_validate:
-            outputs_ = []
-            for output in outputs:
-                if not isinstance(output, output_type):
-                    raise TypeError(f"Expected output of type {output_type}, "
-                                    f"but found type {type(output)}")
-
-                outputs_.append(output)
-        else:
-            outputs_ = outputs
-
-        return outputs_
-
-    tokenizer: Optional[AnyTokenizer]
-
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        executor_class: Type[ExecutorBase],
-        log_stats: bool,
-        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
-        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
-        use_cached_outputs: bool = False,
-    ) -> None:
-        if envs.VLLM_USE_V1:
-            raise ValueError(
-                "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
-                "This should not happen. As a workaround, try using "
-                "LLMEngine.from_vllm_config(...) or explicitly set "
-                "VLLM_USE_V1=0 or 1 and report this issue on Github.")
-
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
-        self.cache_config = vllm_config.cache_config
-        self.lora_config = vllm_config.lora_config
-        self.parallel_config = vllm_config.parallel_config
-        self.scheduler_config = vllm_config.scheduler_config
-        self.device_config = vllm_config.device_config
-        self.speculative_config = vllm_config.speculative_config  # noqa
-        self.load_config = vllm_config.load_config
-        self.structured_outputs_config = vllm_config.structured_outputs_config
-        self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
-        )
-
-        logger.info(
-            "Initializing a V0 LLM engine (v%s) with config: %s, "
-            "use_cached_outputs=%s, ",
-            VLLM_VERSION,
-            vllm_config,
-            use_cached_outputs,
-        )
-
-        self.log_stats = log_stats
-        self.use_cached_outputs = use_cached_outputs
-
-        if self.model_config.skip_tokenizer_init:
-            self.tokenizer = None
-            self.detokenizer = None
-        else:
-            self.tokenizer = self._init_tokenizer()
-            self.detokenizer = Detokenizer(self.tokenizer)
-
-        self.seq_counter = Counter()
-        self.generation_config_fields = (
-            self.model_config.try_get_generation_config())
-
-        self.input_preprocessor = InputPreprocessor(
-            self.model_config,
-            self.tokenizer,
-            mm_registry,
-            mm_processor_cache=processor_only_cache_from_config(
-                self.model_config, mm_registry),
-        )
-
-        self.model_executor = executor_class(vllm_config=vllm_config)
-
-        self._initialize_kv_caches()
-
-        # If usage stat is enabled, collect relevant info.
- if is_usage_stats_enabled(): - from vllm.model_executor.model_loader import ( - get_architecture_class_name) - usage_message.report_usage( - get_architecture_class_name(self.model_config), - usage_context, - extra_kvs={ - # Common configuration - "dtype": - str(self.model_config.dtype), - "tensor_parallel_size": - self.parallel_config.tensor_parallel_size, - "block_size": - self.cache_config.block_size, - "gpu_memory_utilization": - self.cache_config.gpu_memory_utilization, - "kv_cache_memory_bytes": - self.cache_config.kv_cache_memory_bytes, - # Quantization - "quantization": - self.model_config.quantization, - "kv_cache_dtype": - str(self.cache_config.cache_dtype), - - # Feature flags - "enable_lora": - bool(self.lora_config), - "enable_prefix_caching": - self.cache_config.enable_prefix_caching, - "enforce_eager": - self.model_config.enforce_eager, - "disable_custom_all_reduce": - self.parallel_config.disable_custom_all_reduce, - }) - - self.cached_scheduler_outputs = [ - SchedulerOutputState() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - self.scheduler_contexts = [ - SchedulerContext() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - if self.model_config.use_async_output_proc: - process_model_outputs = weak_bind(self._process_model_outputs) - - self.async_callbacks = [ - partial(process_model_outputs, - ctx=self.scheduler_contexts[v_id]) - for v_id in range(self.parallel_config.pipeline_parallel_size) - ] - else: - self.async_callbacks = [] - - # Currently used by AsyncLLMEngine to ensure quick append - # of request outputs to asyncio queues - self.process_request_outputs_callback: Optional[Callable] = None - - # Create the scheduler. - # NOTE: the cache_config here have been updated with the numbers of - # GPU and CPU blocks, which are profiled in the distributed executor. - if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str): - Scheduler = resolve_obj_by_qualname( - self.vllm_config.scheduler_config.scheduler_cls) - else: - Scheduler = self.vllm_config.scheduler_config.scheduler_cls - self.scheduler = [ - Scheduler( - self.scheduler_config, self.cache_config, self.lora_config, - self.parallel_config.pipeline_parallel_size, - self.async_callbacks[v_id] - if self.model_config.use_async_output_proc else None) - for v_id in range(self.parallel_config.pipeline_parallel_size) - ] - - # Metric Logging. - if self.log_stats: - if stat_loggers is not None: - self.stat_loggers = stat_loggers - else: - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. - # See https://prometheus.github.io/client_python/multiprocess/ - from vllm.engine.metrics import (LoggingStatLogger, - PrometheusStatLogger) - - self.stat_loggers = { - "logging": - LoggingStatLogger( - local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - vllm_config=vllm_config), - "prometheus": - PrometheusStatLogger( - local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict( - model_name=self.model_config.served_model_name), - vllm_config=vllm_config), - } - self.stat_loggers["prometheus"].info("cache_config", - self.cache_config) - - self.tracer = None - if self.observability_config.otlp_traces_endpoint: - self.tracer = init_tracer( - "vllm.llm_engine", - self.observability_config.otlp_traces_endpoint) - - # Initialize reasoning parser if reasoning backend is set. 
- if self.structured_outputs_config.reasoning_parser and self.tokenizer: - reasoner_class = ReasoningParserManager.get_reasoning_parser( - self.structured_outputs_config.reasoning_parser) - self.reasoner: ReasoningParser = reasoner_class( - self.tokenizer.get_lora_tokenizer()) - - # Create sequence output processor, e.g. for beam search or - # speculative decoding. - self.output_processor = ( - SequenceGroupOutputProcessor.create_output_processor( - self.scheduler_config, - self.detokenizer, - self.scheduler, - self.seq_counter, - stop_checker=StopChecker( - self.scheduler_config.max_model_len, - self.reasoner - if self.structured_outputs_config.reasoning_parser - and self.tokenizer else None, - ), - )) - - self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} - - # Flag to set when an input fails to process and the engine should run - # the next step without re-scheduling. - self._skip_scheduling_next_step = False - - # Don't keep the dummy data in memory - self.reset_mm_cache() - - def _initialize_kv_caches(self) -> None: - """Initialize the KV cache in the worker(s). - - The workers will determine the number of blocks in both the GPU cache - and the swap CPU cache. - """ - start = time.time() - num_gpu_blocks, num_cpu_blocks = ( - self.model_executor.determine_num_available_blocks()) - - if self.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override - logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_gpu_blocks, - num_gpu_blocks_override) - num_gpu_blocks = num_gpu_blocks_override - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) - elapsed = time.time() - start - logger.info(("init engine (profile, create kv cache, " - "warmup model) took %.2f seconds"), elapsed) - - @classmethod - def _get_executor_cls(cls, - engine_config: VllmConfig) -> Type[ExecutorBase]: - # distributed_executor_backend must be set in VllmConfig.__post_init__ - distributed_executor_backend = ( - engine_config.parallel_config.distributed_executor_backend) - # Initialize the cluster and specify the executor class. - if isinstance(distributed_executor_backend, type): - if not issubclass(distributed_executor_backend, ExecutorBase): - raise TypeError( - "distributed_executor_backend must be a subclass of " - f"ExecutorBase. Got {distributed_executor_backend}.") - executor_class = distributed_executor_backend - elif distributed_executor_backend == "ray": - from vllm.executor.ray_distributed_executor import ( - RayDistributedExecutor) - executor_class = RayDistributedExecutor - elif distributed_executor_backend == "mp": - from vllm.executor.mp_distributed_executor import ( - MultiprocessingDistributedExecutor) - assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( - "multiprocessing distributed executor backend does not " - "support VLLM_USE_RAY_SPMD_WORKER=1") - executor_class = MultiprocessingDistributedExecutor - elif distributed_executor_backend == "uni": - # JAX-style, single-process, multi-device executor. 
- from vllm.executor.uniproc_executor import UniProcExecutor - executor_class = UniProcExecutor - elif distributed_executor_backend == "external_launcher": - # executor with external launcher - from vllm.executor.uniproc_executor import ( # noqa - ExecutorWithExternalLauncher) - executor_class = ExecutorWithExternalLauncher - else: - raise ValueError("unrecognized distributed_executor_backend: " - f"{distributed_executor_backend}") - return executor_class - - @classmethod - def from_vllm_config( - cls, - vllm_config: VllmConfig, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - disable_log_stats: bool = False, - ) -> "LLMEngine": - return cls( - vllm_config=vllm_config, - executor_class=cls._get_executor_cls(vllm_config), - log_stats=(not disable_log_stats), - usage_context=usage_context, - stat_loggers=stat_loggers, - ) - - @classmethod - def from_engine_args( - cls, - engine_args: EngineArgs, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - ) -> "LLMEngine": - """Creates an LLM engine from the engine arguments.""" - # Create the engine configs. - vllm_config = engine_args.create_engine_config(usage_context) - - engine_cls = cls - if envs.VLLM_USE_V1: - from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine - engine_cls = V1LLMEngine - - return engine_cls.from_vllm_config( - vllm_config=vllm_config, - usage_context=usage_context, - stat_loggers=stat_loggers, - disable_log_stats=engine_args.disable_log_stats, - ) - - def __reduce__(self): - # This is to ensure that the LLMEngine is not referenced in - # the closure used to initialize Ray worker actors - raise RuntimeError("LLMEngine should not be pickled!") - - def __del__(self): - # Shutdown model executor when engine is garbage collected - # Use getattr since __init__ can fail before the field is set - if model_executor := getattr(self, "model_executor", None): - model_executor.shutdown() - - def get_tokenizer(self) -> AnyTokenizer: - if self.tokenizer is None: - raise ValueError("Unable to get tokenizer because " - "skip_tokenizer_init is True") - - return self.tokenizer - - def _init_tokenizer(self) -> AnyTokenizer: - return init_tokenizer_from_configs(model_config=self.model_config) - - def _verify_args(self) -> None: - self.model_config.verify_with_parallel_config(self.parallel_config) - self.cache_config.verify_with_parallel_config(self.parallel_config) - if self.lora_config: - self.lora_config.verify_with_model_config(self.model_config) - self.lora_config.verify_with_scheduler_config( - self.scheduler_config) - - def _add_processed_request( - self, - request_id: str, - processed_inputs: ProcessorInputs, - params: SamplingParams, - arrival_time: float, - lora_request: Optional[LoRARequest], - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> Optional[SequenceGroup]: - """Add a processed request to the engine's request pool. - return the created sequence group. - """ - if isinstance(params, SamplingParams) and params.n > 1: - ParallelSampleSequenceGroup.add_request( - request_id, - self, - params, - processed_inputs=processed_inputs, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - ) - return None - - self._validate_model_inputs(processed_inputs) - # Create the sequences. 
- block_size = self.cache_config.block_size - seq_id = next(self.seq_counter) - eos_token_id = self.input_preprocessor.get_eos_token_id() - - encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) - - seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, - lora_request) - - encoder_seq = (None if encoder_inputs is None else Sequence( - seq_id, encoder_inputs, block_size, eos_token_id, lora_request)) - - # Create a SequenceGroup based on SamplingParams - if isinstance(params, SamplingParams): - seq_group = self._create_sequence_group_with_sampling( - request_id, - seq, - params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - encoder_seq=encoder_seq, - priority=priority) - else: - raise ValueError("SamplingParams must be provided.") - - # Add the sequence group to the scheduler with least unfinished seqs. - costs = [ - scheduler.get_num_unfinished_seq_groups() - for scheduler in self.scheduler - ] - min_cost_scheduler = self.scheduler[costs.index(min(costs))] - min_cost_scheduler.add_seq_group(seq_group) - - return seq_group - - def stop_remote_worker_execution_loop(self) -> None: - self.model_executor.stop_remote_worker_execution_loop() - - def add_request( - self, - request_id: str, - prompt: PromptType, - params: SamplingParams, - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> None: - """Add a request to the engine's request pool. - - The request is added to the request pool and will be processed by the - scheduler as `engine.step()` is called. The exact scheduling policy is - determined by the scheduler. - - Args: - request_id: The unique ID of the request. - prompt: The prompt to the LLM. See - [PromptType][vllm.inputs.PromptType] - for more details about the format of each input. - params: Parameters for sampling. - [SamplingParams][vllm.SamplingParams] for text generation. - arrival_time: The arrival time of the request. If None, we use - the current monotonic time. - lora_request: The LoRA request to add. - trace_headers: OpenTelemetry trace headers. - priority: The priority of the request. - Only applicable with priority scheduling. - - Details: - - Set arrival_time to the current time if it is None. - - Set prompt_token_ids to the encoded prompt if it is None. - - Create `n` number of [Sequence][vllm.sequence.Sequence] objects. - - Create a [SequenceGroup][vllm.sequence.SequenceGroup] object - from the list of [Sequence][vllm.sequence.Sequence]. - - Add the [SequenceGroup][vllm.sequence.SequenceGroup] object to the - scheduler. - - Example: - >>> # initialize engine - >>> engine = LLMEngine.from_engine_args(engine_args) - >>> # set request arguments - >>> example_prompt = "Who is the president of the United States?" - >>> sampling_params = SamplingParams(temperature=0.0) - >>> request_id = 0 - >>> - >>> # add the request to the engine - >>> engine.add_request( - >>> str(request_id), - >>> example_prompt, - >>> SamplingParams(temperature=0.0)) - >>> # continue the request processing - >>> ... 
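The doctest above disappears with the V0 engine, but the entry points it exercises are kept: at the end of this file's diff (below), the module simply re-exports the V1 engine under the old `LLMEngine` name. A short sketch of the same loop against that alias, assuming the V1 engine preserves these methods (model name and sampling values are illustrative):

```python
# Same request/step loop as the deleted docstring example, now served by
# the V1 engine behind the unchanged `LLMEngine` name.
from vllm import EngineArgs, SamplingParams
from vllm.engine.llm_engine import LLMEngine  # alias for the V1 engine

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request("0", "Who is the president of the United States?",
                   SamplingParams(temperature=0.0))
while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished:
            print(request_output.outputs[0].text)
```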
- """ - if not isinstance(request_id, str): - raise TypeError( - f"request_id must be a string, got {type(request_id)}") - - if lora_request is not None and not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") - - if priority != 0 and not self.scheduler_config.policy == "priority": - raise ValueError(f"Got priority {priority} but " - "Priority scheduling is not enabled.") - - if isinstance(params, SamplingParams) \ - and params.logits_processors: - raise ValueError( - "Logits processors are not supported in multi-step decoding") - - if arrival_time is None: - arrival_time = time.time() - - if (isinstance(prompt, dict) - and prompt.get("prompt_embeds", None) is not None): - if not prompt.get("prompt_token_ids", None): - seq_len = prompt["prompt_embeds"].shape[0] - prompt["prompt_token_ids"] = [0] * seq_len - if params.prompt_logprobs is not None: - raise ValueError( - "prompt_logprobs is not compatible with prompt embeds.") - - processed_inputs = self.input_preprocessor.preprocess( - prompt, - tokenization_kwargs=tokenization_kwargs, - ) - - self._add_processed_request( - request_id=request_id, - processed_inputs=processed_inputs, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - ) - - def _create_sequence_group_with_sampling( - self, - request_id: str, - seq: Sequence, - sampling_params: SamplingParams, - arrival_time: float, - lora_request: Optional[LoRARequest], - trace_headers: Optional[Mapping[str, str]] = None, - encoder_seq: Optional[Sequence] = None, - priority: int = 0, - ) -> SequenceGroup: - """Creates a SequenceGroup with SamplingParams.""" - max_logprobs = self.get_model_config().max_logprobs - if (sampling_params.logprobs - and sampling_params.logprobs > max_logprobs) or ( - sampling_params.prompt_logprobs - and sampling_params.prompt_logprobs > max_logprobs): - raise ValueError(f"Cannot request more than " - f"{max_logprobs} logprobs.") - - sampling_params = self._build_logits_processors( - sampling_params, lora_request) - - # Defensive copy of SamplingParams, which are used by the sampler, - # this doesn't deep-copy LogitsProcessor objects - sampling_params = sampling_params.clone() - - sampling_params.update_from_generation_config( - self.generation_config_fields, seq.eos_token_id) - - # Create the sequence group. - draft_size = 1 - if self.vllm_config.speculative_config is not None: - draft_size = \ - self.vllm_config.speculative_config.num_speculative_tokens + 1 - seq_group = SequenceGroup(request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - sampling_params=sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - encoder_seq=encoder_seq, - priority=priority, - draft_size=draft_size) - - return seq_group - - def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: - """Aborts a request(s) with the given ID. - - Args: - request_id: The ID(s) of the request to abort. - - Details: - - Refer to [vllm.core.scheduler.Scheduler.abort_seq_group][]. 
- - Example: - >>> # initialize engine and add a request with request_id - >>> request_id = str(0) - >>> # abort the request - >>> engine.abort_request(request_id) - """ - for scheduler in self.scheduler: - scheduler.abort_seq_group( - request_id, seq_id_to_seq_group=self.seq_id_to_seq_group) - - def get_vllm_config(self) -> VllmConfig: - """Gets the vllm configuration.""" - return self.vllm_config - - def get_model_config(self) -> ModelConfig: - """Gets the model configuration.""" - return self.model_config - - def get_parallel_config(self) -> ParallelConfig: - """Gets the parallel configuration.""" - return self.parallel_config - - def get_scheduler_config(self) -> SchedulerConfig: - """Gets the scheduler configuration.""" - return self.scheduler_config - - def get_lora_config(self) -> LoRAConfig: - """Gets the LoRA configuration.""" - return self.lora_config - - def get_num_unfinished_requests(self) -> int: - """Gets the number of unfinished requests.""" - return sum(scheduler.get_num_unfinished_seq_groups() - for scheduler in self.scheduler) - - def has_unfinished_requests(self) -> bool: - """Returns True if there are unfinished requests.""" - return any(scheduler.has_unfinished_seqs() - for scheduler in self.scheduler) - - def has_unfinished_requests_for_virtual_engine( - self, virtual_engine: int) -> bool: - """ - Returns True if there are unfinished requests for the virtual engine. - """ - return self.scheduler[virtual_engine].has_unfinished_seqs() - - def reset_mm_cache(self) -> bool: - """Reset the multi-modal cache.""" - self.input_preprocessor.clear_cache() - return True - - def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: - """Reset prefix cache for all devices.""" - - success = True - for scheduler in self.scheduler: - success = success and scheduler.reset_prefix_cache(device) - return success - - def _process_model_outputs(self, - ctx: SchedulerContext, - request_id: Optional[str] = None) -> None: - """Apply the model output to the sequences in the scheduled seq groups - and return responses. 
- - ctx: The virtual engine context to work on - request_id: If provided, then only this request is going to be processed - """ - - now = time.time() - - if len(ctx.output_queue) == 0: - return None - - # Get pending async postprocessor - if request_id: - # When we process only one request, no pop is required - # (since later we will process all of the rest) - (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, is_first_step_output, skip) = ctx.output_queue[0] - else: - (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, is_first_step_output, - skip) = ctx.output_queue.popleft() - - # Sanity check - assert len(seq_group_metadata_list) == len( - scheduler_outputs.scheduled_seq_groups) - - has_multiple_outputs: bool = len(outputs) > 1 - outputs_by_sequence_group: List[List[SequenceGroupOutput]] - assert not has_multiple_outputs - outputs_by_sequence_group = outputs - - # Determine the requests we need to operate on - if request_id: - indices = [] - for i, seq_group_meta in enumerate(seq_group_metadata_list): - if seq_group_meta.request_id == request_id: - assert i not in skip # Cannot be called twice - indices.append(i) - break - - # If the request_id was not found, then it means that - # this is a new request that has no pending async - # postprocessor - if not indices: - return - else: - indices = range(len(seq_group_metadata_list)) # type: ignore - - finished_before: List[int] = [] - finished_now: List[int] = [] - for i in indices: - if i in skip: - continue - - seq_group_meta = seq_group_metadata_list[i] - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - - seq_group: SequenceGroup = scheduled_seq_group.seq_group - - if seq_group.is_finished(): - finished_before.append(i) - continue - - output: List[SequenceGroupOutput] - if has_multiple_outputs: - output = outputs_by_sequence_group[i] - else: - output = [outputs_by_sequence_group[0][i]] - - if not is_async: - seq_group.update_num_computed_tokens( - seq_group_meta.token_chunk_size or 0) - - if outputs: - for o in outputs: - if (isinstance(o, SamplerOutput) - and seq_group.metrics is not None): - if seq_group.metrics.model_forward_time is not None: - seq_group.metrics.model_forward_time += ( - o.model_forward_time or 0) - else: - seq_group.metrics.model_forward_time = ( - o.model_forward_time) - if seq_group.metrics.model_execute_time is not None: - seq_group.metrics.model_execute_time += ( - o.model_execute_time or 0) - else: - seq_group.metrics.model_execute_time = ( - o.model_execute_time) - - self.output_processor.process_prompt_logprob(seq_group, output) - if seq_group_meta.do_sample: - self.output_processor.process_outputs(seq_group, output, - is_async) - - if seq_group.is_finished(): - finished_now.append(i) - - # Generate outputs for the requests that finished this iteration - for i in finished_now: - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - - seq_group = scheduled_seq_group.seq_group - seq_group.maybe_set_first_token_time(now) - if not seq_group.is_prefill(): - seq_group.set_last_token_time(now) - request_output = RequestOutputFactory.create( - seq_group, - self.seq_id_to_seq_group, - use_cache=self.use_cached_outputs) - if request_output: - ctx.request_outputs.append(request_output) - - # When we process a single request, we skip it for the next time, - # and invoke the request output callback (if there was final output) - if request_id: - assert len(indices) == 1 - skip.append(indices[0]) - - if (finished_now - and 
self.process_request_outputs_callback is not None): - self.process_request_outputs_callback(ctx.request_outputs) - ctx.request_outputs.clear() - return - - # Free currently finished requests - if finished_now: - for scheduler in self.scheduler: - scheduler.free_finished_seq_groups() - - # Create the outputs - for i in indices: - if i in skip or i in finished_before or i in finished_now: - continue # Avoids double processing - - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - - seq_group = scheduled_seq_group.seq_group - seq_group.maybe_set_first_token_time(now) - if not seq_group.is_prefill(): - seq_group.set_last_token_time(now) - request_output = RequestOutputFactory.create( - seq_group, - self.seq_id_to_seq_group, - use_cache=self.use_cached_outputs) - if request_output: - ctx.request_outputs.append(request_output) - - # Create outputs only after processing the scheduler's results - - for seq_group in scheduler_outputs.ignored_seq_groups: - params = seq_group.sampling_params - if params is not None and params.output_kind == ( - RequestOutputKind.DELTA) and not seq_group.is_finished(): - continue - - request_output = RequestOutputFactory.create( - seq_group, - self.seq_id_to_seq_group, - use_cache=self.use_cached_outputs, - ) - if request_output: - ctx.request_outputs.append(request_output) - - # Immediately process request outputs here (if callback is given) - if (ctx.request_outputs - and self.process_request_outputs_callback is not None): - self.process_request_outputs_callback(ctx.request_outputs) - ctx.request_outputs.clear() - - # For async case, we need to record the stats here. - # For non-async case, the stats are done in the - # LLMEngine/AsyncLLMEngine directly - if is_async: - # Log stats. - self.do_log_stats(scheduler_outputs, outputs, finished_before, - skip) - - # Tracing - self.do_tracing(scheduler_outputs, finished_before) - - return None - - def _advance_to_next_step( - self, output: SamplerOutput, - seq_group_metadata_list: List[SequenceGroupMetadata], - scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: - """Given model output from a single run, append the tokens to the - sequences. This is normally done inside output processor, but it is - required if the worker is to perform async forward pass to next step. - """ - for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \ - zip(seq_group_metadata_list, output, scheduled_seq_groups): - seq_group = scheduled_seq_group.seq_group - - if seq_group.is_finished(): - continue - - token_chunk_size = (seq_group_metadata.token_chunk_size - if seq_group_metadata.token_chunk_size - is not None else 0) - seq_group.update_num_computed_tokens(token_chunk_size) - - if seq_group_metadata.do_sample: - assert len(sequence_group_outputs.samples) == 1, ( - "Async output processor expects a single sample" - " (i.e sampling_params.n == 1)") - sample = sequence_group_outputs.samples[0] - - assert len(seq_group.seqs) == 1 - seq = seq_group.seqs[0] - - seq.append_token_id(sample.output_token, sample.logprobs, - sample.output_embed) - - def step(self) -> List[RequestOutput]: - """Performs one decoding iteration and returns newly generated results. - -
- <figure markdown="span">
- ![Overview of the step function](https://i.imgur.com/sv2HssD.png)
- <figcaption>Overview of the step function</figcaption>
- </figure>
-
- - Details: - - Step 1: Schedules the sequences to be executed in the next - iteration and the token blocks to be swapped in/out/copy. - - - Depending on the scheduling policy, - sequences may be `preempted/reordered`. - - A Sequence Group (SG) refer to a group of sequences - that are generated from the same prompt. - - - Step 2: Calls the distributed executor to execute the model. - - Step 3: Processes the model output. This mainly includes: - - - Decodes the relevant outputs. - - Updates the scheduled sequence groups with model outputs - based on its `sampling parameters` (`use_beam_search` or not). - - Frees the finished sequence groups. - - - Finally, it creates and returns the newly generated results. - - Example: - ``` - # Please see the example/ folder for more detailed examples. - - # initialize engine and request arguments - engine = LLMEngine.from_engine_args(engine_args) - example_inputs = [(0, "What is LLM?", - SamplingParams(temperature=0.0))] - - # Start the engine with an event loop - while True: - if example_inputs: - req_id, prompt, sampling_params = example_inputs.pop(0) - engine.add_request(str(req_id),prompt,sampling_params) - - # continue the request processing - request_outputs = engine.step() - for request_output in request_outputs: - if request_output.finished: - # return or show the request output - - if not (engine.has_unfinished_requests() or example_inputs): - break - ``` - """ - if self.parallel_config.pipeline_parallel_size > 1: - raise NotImplementedError( - "Pipeline parallelism is only supported through AsyncLLMEngine " - "as performance will be severely degraded otherwise.") - - # For llm_engine, there is no pipeline parallel support, so the engine - # used is always 0. - virtual_engine = 0 - - # These are cached outputs from previous iterations. None if on first - # iteration - cached_outputs = self.cached_scheduler_outputs[virtual_engine] - seq_group_metadata_list = cached_outputs.seq_group_metadata_list - scheduler_outputs = cached_outputs.scheduler_outputs - allow_async_output_proc = cached_outputs.allow_async_output_proc - - ctx = self.scheduler_contexts[virtual_engine] - - # Clear outputs for each new scheduler iteration - ctx.request_outputs.clear() - - # Skip the scheduler if there are any remaining steps in the seq groups. - # This ensures that the scheduler is only called again when the current - # batch has completed. - # The scheduler is also skipped if a single request caused the last - # engine step to fail, and the previous schedule needs to be rerun. - if not self._has_remaining_steps( - seq_group_metadata_list - ) and not self._skip_scheduling_next_step: - # Schedule iteration - (seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc - ) = self.scheduler[virtual_engine].schedule() - - ctx.seq_group_metadata_list = seq_group_metadata_list - ctx.scheduler_outputs = scheduler_outputs - - finished_requests_ids = self.scheduler[ - virtual_engine].get_and_reset_finished_requests_ids() - # When n>1, elements in self.seq_id_to_seq_group should be deleted - # here, otherwise memory leaks. 
- for finished_request_id in finished_requests_ids: - if finished_request_id in self.seq_id_to_seq_group: - del self.seq_id_to_seq_group[finished_request_id] - - # Maybe switch from async mode to sync mode - if not allow_async_output_proc and len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - - else: - finished_requests_ids = list() - - assert seq_group_metadata_list is not None - assert scheduler_outputs is not None - - if not scheduler_outputs.is_empty(): - - # Check if we have a cached last_output from the previous iteration. - # For supporting PP this is probably the best way to pass the - # sampled_token_ids, as a separate broadcast over all the PP stages - # will cause one virtual engine's microbatch to block the pipeline. - last_sampled_token_ids = \ - self._get_last_sampled_token_ids(virtual_engine) - - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, - blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, - blocks_to_copy=scheduler_outputs.blocks_to_copy, - num_lookahead_slots=scheduler_outputs.num_lookahead_slots, - running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=finished_requests_ids, - # We use ExecuteModelRequest to pass the last sampled_token_ids - # to each of the non-last PP stages for in-place prepare_input. - last_sampled_token_ids=last_sampled_token_ids) - - if allow_async_output_proc: - execute_model_req.async_callback = self.async_callbacks[ - virtual_engine] - - try: - outputs = self.model_executor.execute_model( - execute_model_req=execute_model_req) - self._skip_scheduling_next_step = False - except InputProcessingError as e: - # The input for this request cannot be processed, so we must - # abort it. If there are remaining requests in the batch that - # have been scheduled, they will be retried on the next step. - invalid_request_id = e.request_id - self._abort_and_cache_schedule( - request_id=invalid_request_id, - virtual_engine=virtual_engine, - seq_group_metadata_list=seq_group_metadata_list, - scheduler_outputs=scheduler_outputs, - allow_async_output_proc=allow_async_output_proc) - # Raise so the caller is notified that this request failed - raise - - else: - # Nothing scheduled => If there is pending async postprocessor, - # then finish it here. - if len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - # No outputs in this case - outputs = [] - - if not self._has_remaining_steps(seq_group_metadata_list): - # is_first_step_output is True only when the num_steps of all - # the sequences are 1. - is_first_step_output: bool = False if not seq_group_metadata_list \ - else seq_group_metadata_list[0].state.num_steps == 1 - - # Add results to the output_queue - ctx.append_output(outputs=outputs, - seq_group_metadata_list=seq_group_metadata_list, - scheduler_outputs=scheduler_outputs, - is_async=allow_async_output_proc, - is_last_step=True, - is_first_step_output=is_first_step_output) - - if outputs and allow_async_output_proc: - assert len(outputs) == 1, ( - "Async postprocessor expects only a single output set") - - self._advance_to_next_step( - outputs[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) - - # Check if need to run the usual non-async path - if not allow_async_output_proc: - self._process_model_outputs(ctx=ctx) - - # Log stats. 
- self.do_log_stats(scheduler_outputs, outputs) - - # Tracing - self.do_tracing(scheduler_outputs) - else: - # Multi-step case - return ctx.request_outputs - - if not self.has_unfinished_requests(): - # Drain async postprocessor (if exists) - if len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - assert len(ctx.output_queue) == 0 - - # Stop the execute model loop in parallel workers until there are - # more requests to process. This avoids waiting indefinitely in - # torch.distributed ops which may otherwise time out, and unblocks - # the RPC thread in the workers so that they can process any other - # queued control plane messages, such as add/remove lora adapters. - logger.debug("Stopping remote worker execution loop.") - self.model_executor.stop_remote_worker_execution_loop() - - return ctx.request_outputs - - def _abort_and_cache_schedule( - self, request_id: str, virtual_engine: int, - seq_group_metadata_list: List[SequenceGroupMetadata], - scheduler_outputs: SchedulerOutputs, - allow_async_output_proc: bool) -> None: - """Aborts a single request, and caches the scheduler outputs minus that - request. This allows the next step to continue processing the remaining - requests without having to re-run the scheduler.""" - - # Abort the request and remove its sequence group from the current - # schedule - self.abort_request(request_id) - for i, metadata in enumerate(seq_group_metadata_list): - if metadata.request_id == request_id: - del seq_group_metadata_list[i] - break - for i, group in enumerate(scheduler_outputs.scheduled_seq_groups): - if group.seq_group.request_id == request_id: - del scheduler_outputs.scheduled_seq_groups[i] - break - - # If there are still other sequence groups left in the schedule, cache - # them and flag the engine to reuse the schedule. - if len(seq_group_metadata_list) > 0: - self._skip_scheduling_next_step = True - # Reuse multi-step caching logic - self._cache_scheduler_outputs_for_multi_step( - virtual_engine=virtual_engine, - scheduler_outputs=scheduler_outputs, - seq_group_metadata_list=seq_group_metadata_list, - allow_async_output_proc=allow_async_output_proc) - - def _has_remaining_steps( - self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] - ) -> bool: - return False - - def _cache_scheduler_outputs_for_multi_step( - self, virtual_engine: int, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - scheduler_outputs: SchedulerOutputs, - allow_async_output_proc: bool) -> None: - co = self.cached_scheduler_outputs[virtual_engine] - - co.seq_group_metadata_list = seq_group_metadata_list - co.scheduler_outputs = scheduler_outputs - co.allow_async_output_proc = allow_async_output_proc - co.last_output = None - - def _update_cached_scheduler_output( - self, virtual_engine: int, - output: List[Optional[SamplerOutput]]) -> None: - if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0 - and output[0] is not None): - last_output = output[-1] - assert last_output is not None - assert last_output.sampled_token_ids_cpu is not None - assert last_output.sampled_token_ids is None - assert last_output.sampled_token_probs is None - self.cached_scheduler_outputs[ - virtual_engine].last_output = last_output - - def _get_last_sampled_token_ids( - self, virtual_engine: int) -> Optional[torch.Tensor]: - return None - - def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: - if not self.log_stats: - raise RuntimeError( - "Stat logging is disabled. 
Set `disable_log_stats=False` " - "argument to enable.") - if logger_name in self.stat_loggers: - raise KeyError(f"Logger with name {logger_name} already exists.") - self.stat_loggers[logger_name] = logger - - def remove_logger(self, logger_name: str) -> None: - if not self.log_stats: - raise RuntimeError( - "Stat logging is disabled. Set `disable_log_stats=False` " - "argument to enable.") - if logger_name not in self.stat_loggers: - raise KeyError(f"Logger with name {logger_name} does not exist.") - del self.stat_loggers[logger_name] - - def do_log_stats(self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None, - skip: Optional[List[int]] = None) -> None: - """Forced log when no requests active.""" - if self.log_stats: - stats = self._get_stats(scheduler_outputs, model_output, - finished_before, skip) - for logger in self.stat_loggers.values(): - logger.log(stats) - - def _get_stats(self, - scheduler_outputs: Optional[SchedulerOutputs], - model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None, - skip: Optional[List[int]] = None) -> Stats: - """Get Stats to be Logged to Prometheus. - - Args: - scheduler_outputs: Optional, used to populate metrics related to - the scheduled batch, - model_output: Optional, used to emit speculative decoding metrics - which are created by the workers. - finished_before: Optional, indices of sequences that were finished - before. These sequences will be ignored. - skip: Optional, indices of sequences that were preempted. These - sequences will be ignored. - """ - now = time.time() - - # System State - # Scheduler State - num_running_sys = sum( - len(scheduler.running) for scheduler in self.scheduler) - num_swapped_sys = sum( - len(scheduler.swapped) for scheduler in self.scheduler) - num_waiting_sys = sum( - len(scheduler.waiting) for scheduler in self.scheduler) - - # KV Cache Usage in % - num_total_gpu = self.cache_config.num_gpu_blocks - gpu_cache_usage_sys = 0. - if num_total_gpu: # Guard against both None and 0 - num_free_gpu = sum( - scheduler.block_manager.get_num_free_gpu_blocks() - for scheduler in self.scheduler) - gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) - - num_total_cpu = self.cache_config.num_cpu_blocks - cpu_cache_usage_sys = 0. - if num_total_cpu: # Guard against both None and 0 - num_free_cpu = sum( - scheduler.block_manager.get_num_free_cpu_blocks() - for scheduler in self.scheduler) - cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) - - # Prefix Cache Hit Rate. Note that we always use - # the cache hit rate of the first virtual engine. - cpu_prefix_cache_hit_rate = self.scheduler[ - 0].get_prefix_cache_hit_rate(Device.CPU) - gpu_prefix_cache_hit_rate = self.scheduler[ - 0].get_prefix_cache_hit_rate(Device.GPU) - - # Exchange the uasge and cache hit stats between gpu and cpu when - # running on cpu because the cpu_worker.py intentionally reports the - # number of cpu blocks as gpu blocks in favor of cache management. 
- if self.device_config.device_type == "cpu": - num_total_gpu, num_total_cpu = num_total_cpu, num_total_gpu - gpu_cache_usage_sys, cpu_cache_usage_sys = ( - cpu_cache_usage_sys, - gpu_cache_usage_sys, - ) - gpu_prefix_cache_hit_rate, cpu_prefix_cache_hit_rate = ( - cpu_prefix_cache_hit_rate, - gpu_prefix_cache_hit_rate, - ) - - # Iteration stats - num_prompt_tokens_iter = 0 - num_generation_tokens_iter = 0 - num_tokens_iter = 0 - time_to_first_tokens_iter: List[float] = [] - inter_token_latencies_iter: List[float] = [] - num_preemption_iter = (0 if scheduler_outputs is None else - scheduler_outputs.preempted) - - # Request stats - # Latency - time_e2e_requests: List[float] = [] - time_queue_requests: List[float] = [] - time_inference_requests: List[float] = [] - time_prefill_requests: List[float] = [] - time_decode_requests: List[float] = [] - # Metadata - num_prompt_tokens_requests: List[int] = [] - num_generation_tokens_requests: List[int] = [] - n_requests: List[int] = [] - max_num_generation_tokens_requests: List[int] = [] - max_tokens_requests: List[int] = [] - finished_reason_requests: List[str] = [] - - # LoRA requests - running_lora_adapters = dict( - collectionsCounter([ - running_request.lora_request.lora_name - for scheduler in self.scheduler - for running_request in scheduler.running - if running_request.lora_request - ])) - waiting_lora_adapters = dict( - collectionsCounter([ - waiting_request.lora_request.lora_name - for scheduler in self.scheduler - for waiting_request in scheduler.waiting - if waiting_request.lora_request - ])) - max_lora_stat = "0" - if self.lora_config: - max_lora_stat = str(self.lora_config.max_loras) - - # NOTE: This loop assumes prefill seq_groups are before - # decode seq_groups in scheduled_seq_groups. - if scheduler_outputs is not None: - # For async postprocessor, already finished sequences need to be - # not counted (to avoid double counting) - actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore - - num_generation_tokens_from_prefill_groups = 0 - # NOTE: if scheduler_outputs.num_prefill_groups > 0 and - # the len of scheduler_outputs.scheduled_seq_groups is != - # scheduler_outputs.num_prefill_groups, this means that - # chunked prefills have been detected. - - for idx, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups): - # Skip double logging when using async output proc - if finished_before and idx in finished_before: - actual_num_batched_tokens -= 1 - continue - - # Currently, skip == preempted sequences, so we need to skip - # their log stats - if skip and idx in skip: - continue - - group_was_prefill = idx < scheduler_outputs.num_prefill_groups - seq_group = scheduled_seq_group.seq_group - - # NOTE: a seq_group that completed all of its prefill tokens - # in the last iteration will have seq_group.is_prefill() = False - # with group_was_prefill = True - if group_was_prefill: - # Number of prompt tokens. - num_prompt_tokens_iter += ( - scheduled_seq_group.token_chunk_size) - - # If the seq_group just finished the prefill state - # get TTFT. - if not seq_group.is_prefill(): - latency = seq_group.get_last_token_latency() - time_to_first_tokens_iter.append(latency) - - # One generation token per finished prefill. 
- num_generation_tokens_from_prefill_groups += ( - seq_group.num_seqs()) - else: - # ITLs - latency = seq_group.get_last_token_latency() - inter_token_latencies_iter.append(latency) - if seq_group.state.current_step == 0: - # For async_output_proc, the do_log_stats() - # is called following init_multi_step(), which - # sets the current_step to zero. - actual_num_batched_tokens +=\ - seq_group.state.num_steps - 1 - else: - actual_num_batched_tokens +=\ - seq_group.state.current_step - 1 - - # Because of chunked prefill, we can have a single sequence - # group that does multiple prompt_runs. To prevent logging - # the same metadata more than once per request, we standardize - # on logging request level information for finished requests, - # which can only happen once. - if seq_group.is_finished(): - # Latency timings - time_e2e_requests.append(now - - seq_group.metrics.arrival_time) - if (seq_group.metrics.first_scheduled_time is not None and - seq_group.metrics.first_token_time is not None): - time_queue_requests.append( - seq_group.metrics.first_scheduled_time - - seq_group.metrics.arrival_time) - time_prefill_requests.append( - seq_group.metrics.first_token_time - - seq_group.metrics.first_scheduled_time) - time_decode_requests.append( - now - seq_group.metrics.first_token_time) - time_inference_requests.append( - now - seq_group.metrics.first_scheduled_time) - # Metadata - num_prompt_tokens_requests.append( - len(seq_group.prompt_token_ids)) - num_generation_tokens_requests.extend([ - seq.get_output_len() - for seq in seq_group.get_finished_seqs() - ]) - max_num_generation_tokens_requests.append( - max(seq.get_output_len() - for seq in seq_group.get_seqs())) - if seq_group.sampling_params is not None: - n_requests.append(seq_group.sampling_params.n) - max_tokens_requests.append( - seq_group.sampling_params.max_tokens) - finished_reason_requests.extend([ - SequenceStatus.get_finished_reason(seq.status) - for seq in seq_group.get_finished_seqs() - ]) - - # Number of generation tokens. - # num_batched_tokens equals the number of prompt_tokens plus the - # number of decode_tokens in a single iteration. So, - # num_generation_tokens = num_batched_tokens - num_prompt_tokens - # + num_generation_tokens_from_prefill_groups (since we generate - # one token on prefills on iters where the prefill finishes). 
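The token accounting described in the comment above is easier to check with concrete numbers; the figures below are illustrative only, not taken from the PR:

```python
# Illustrative iteration: 90 prompt tokens are prefilled, 8 sequences decode
# one token each, and 2 prefill groups finish (each emitting its first
# generation token, which num_batched_tokens does not count).
num_batched_tokens = 90 + 8
num_prompt_tokens_iter = 90
num_generation_tokens_from_prefill_groups = 2

num_generation_tokens_iter = (num_batched_tokens - num_prompt_tokens_iter +
                              num_generation_tokens_from_prefill_groups)
assert num_generation_tokens_iter == 10  # 8 decode + 2 finished prefills
```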
- num_generation_tokens_iter = ( - actual_num_batched_tokens - num_prompt_tokens_iter + - num_generation_tokens_from_prefill_groups) - num_tokens_iter = (num_generation_tokens_iter + - num_prompt_tokens_iter) - - return Stats( - now=now, - # System stats - # Scheduler State - num_running_sys=num_running_sys, - num_swapped_sys=num_swapped_sys, - num_waiting_sys=num_waiting_sys, - # KV Cache Usage in % - gpu_cache_usage_sys=gpu_cache_usage_sys, - cpu_cache_usage_sys=cpu_cache_usage_sys, - # Prefix Cache Hit Rate - cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, - gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, - - # Iteration stats - num_prompt_tokens_iter=num_prompt_tokens_iter, - num_generation_tokens_iter=num_generation_tokens_iter, - num_tokens_iter=num_tokens_iter, - time_to_first_tokens_iter=time_to_first_tokens_iter, - inter_token_latencies_iter=inter_token_latencies_iter, - num_preemption_iter=num_preemption_iter, - - # Request stats - # Latency - time_e2e_requests=time_e2e_requests, - time_queue_requests=time_queue_requests, - time_inference_requests=time_inference_requests, - time_prefill_requests=time_prefill_requests, - time_decode_requests=time_decode_requests, - # Metadata - num_prompt_tokens_requests=num_prompt_tokens_requests, - num_generation_tokens_requests=num_generation_tokens_requests, - max_num_generation_tokens_requests= - max_num_generation_tokens_requests, - n_requests=n_requests, - max_tokens_requests=max_tokens_requests, - finished_reason_requests=finished_reason_requests, - max_lora=str(max_lora_stat), - waiting_lora_adapters=list(waiting_lora_adapters.keys()), - running_lora_adapters=list(running_lora_adapters.keys())) - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.model_executor.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.model_executor.remove_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.model_executor.list_loras() - - def pin_lora(self, lora_id: int) -> bool: - return self.model_executor.pin_lora(lora_id) - - def start_profile(self) -> None: - self.model_executor.start_profile() - - def stop_profile(self) -> None: - self.model_executor.stop_profile() - - def sleep(self, level: int = 1) -> None: - assert self.vllm_config.model_config.enable_sleep_mode, ( - "Sleep mode is not enabled in the model config") - self.model_executor.sleep(level=level) - - def wake_up(self, tags: Optional[list[str]] = None) -> None: - assert self.vllm_config.model_config.enable_sleep_mode, ( - "Sleep mode is not enabled in the model config") - self.model_executor.wake_up(tags) - - def is_sleeping(self) -> bool: - return self.model_executor.is_sleeping - - def check_health(self) -> None: - self.model_executor.check_health() - - def is_tracing_enabled(self) -> bool: - return self.tracer is not None - - def do_tracing(self, - scheduler_outputs: SchedulerOutputs, - finished_before: Optional[List[int]] = None) -> None: - if self.tracer is None: - return - - for idx, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups): - # Skip double tracing when using async output proc - if finished_before and idx in finished_before: - continue - - seq_group = scheduled_seq_group.seq_group - if seq_group.is_finished(): - self.create_trace_span(seq_group) - - def create_trace_span(self, seq_group: SequenceGroup) -> None: - if self.tracer is None or seq_group.sampling_params is None: - return - arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) - - trace_context = 
extract_trace_context(seq_group.trace_headers) - - with self.tracer.start_as_current_span( - "llm_request", - kind=SpanKind.SERVER, - context=trace_context, - start_time=arrival_time_nano_seconds) as seq_span: - metrics = seq_group.metrics - - # Handle potential None values for cancelled/aborted requests - ttft = (metrics.first_token_time - metrics.arrival_time - if metrics.first_token_time is not None else None) - - e2e_time = (metrics.finished_time - metrics.arrival_time - if metrics.finished_time is not None else None) - - seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL, - self.model_config.model) - seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, - seq_group.request_id) - seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, - seq_group.sampling_params.temperature) - seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, - seq_group.sampling_params.top_p) - seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, - seq_group.sampling_params.max_tokens) - seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, - seq_group.sampling_params.n) - seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES, - seq_group.num_seqs()) - seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, - len(seq_group.prompt_token_ids)) - seq_span.set_attribute( - SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, - sum([ - seq.get_output_len() - for seq in seq_group.get_finished_seqs() - ])) - - # Only set timing attributes if the values are available - if metrics.time_in_queue is not None: - seq_span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, - metrics.time_in_queue) - if ttft is not None: - seq_span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft) - if e2e_time is not None: - seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, - e2e_time) - if metrics.scheduler_time is not None: - seq_span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER, - metrics.scheduler_time) - if metrics.model_forward_time is not None: - seq_span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD, - metrics.model_forward_time / 1000.0) - if metrics.model_execute_time is not None: - seq_span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE, - metrics.model_execute_time) - - def _validate_model_inputs(self, inputs: ProcessorInputs): - encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) - - if encoder_inputs is not None: - self._validate_model_input(encoder_inputs, prompt_type="encoder") - - self._validate_model_input(decoder_inputs, prompt_type="decoder") - - def _validate_model_input( - self, - prompt_inputs: SingletonInputs, - *, - prompt_type: Literal["encoder", "decoder"], - ): - model_config = self.model_config - tokenizer = self.tokenizer - - prompt_ids = prompt_inputs.get("prompt_token_ids", []) - if not prompt_ids: - if prompt_type == "encoder" and model_config.is_multimodal_model: - pass # Mllama may have empty encoder inputs for text-only data - elif prompt_inputs["type"] == "embeds": - pass - else: - raise ValueError(f"The {prompt_type} prompt cannot be empty") - - if tokenizer is not None: - max_input_id = max(prompt_ids, default=0) - if max_input_id > tokenizer.max_token_id: - raise ValueError( - f"Token id {max_input_id} is out of vocabulary") - - max_prompt_len = self.model_config.max_model_len - if len(prompt_ids) > max_prompt_len: - if prompt_type == "encoder" and model_config.is_multimodal_model: - mm_registry = self.input_preprocessor.mm_registry 
- mm_processor = mm_registry.create_processor( - model_config, - tokenizer=tokenizer or object(), # Dummy if no tokenizer - ) - assert isinstance(mm_processor, EncDecMultiModalProcessor) - - if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper - - if model_config.is_multimodal_model: - suggestion = ( - "Make sure that `max_model_len` is no smaller than the " - "number of text tokens plus multimodal tokens. For image " - "inputs, the number of image tokens depends on the number " - "of images, and possibly their aspect ratios as well.") - else: - suggestion = ( - "Make sure that `max_model_len` is no smaller than the " - "number of text tokens.") - - raise ValueError( - f"The {prompt_type} prompt (length {len(prompt_ids)}) is " - f"longer than the maximum model length of {max_prompt_len}. " - f"{suggestion}") - - # TODO: Find out how many placeholder tokens are there so we can - # check that chunked prefill does not truncate them - # max_batch_len = self.scheduler_config.max_num_batched_tokens - - def _build_logits_processors( - self, sampling_params: SamplingParams, - lora_request: Optional[LoRARequest]) -> SamplingParams: - """Constructs logits processors based on the logits_bias, and - allowed_token_ids fields in sampling_params. Deletes those fields and - adds the constructed logits processors to the logits_processors field. - Returns the modified sampling params.""" - - logits_processors = [] - - if (sampling_params.logit_bias or sampling_params.allowed_token_ids): - tokenizer = self.get_tokenizer() - - processors = get_openai_logits_processors( - logit_bias=sampling_params.logit_bias, - allowed_token_ids=sampling_params.allowed_token_ids, - tokenizer=tokenizer) - logits_processors.extend(processors) - - # Unset so these don't get passed down to the model - sampling_params.logit_bias = None - sampling_params.allowed_token_ids = None - - if len(sampling_params.bad_words) > 0: - tokenizer = self.get_tokenizer() - processors = get_bad_words_logits_processors( - bad_words=sampling_params.bad_words, tokenizer=tokenizer) - logits_processors.extend(processors) - - if logits_processors: - if sampling_params.logits_processors is None: - sampling_params.logits_processors = logits_processors - else: - sampling_params.logits_processors.extend(logits_processors) - - return sampling_params - - def collective_rpc(self, - method: Union[str, Callable[[WorkerBase], _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: - return self.model_executor.collective_rpc(method, timeout, args, - kwargs) - - def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: - return self.collective_rpc("apply_model", args=(func, )) - - -if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: - from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine - LLMEngine = V1LLMEngine # type: ignore +LLMEngine = V1LLMEngine # type: ignore diff --git a/vllm/engine/output_processor/__init__.py b/vllm/engine/output_processor/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py deleted file mode 100644 index 587a9221e32c8..0000000000000 --- a/vllm/engine/output_processor/interfaces.py +++ /dev/null @@ -1,59 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from abc import ABC, abstractmethod -from typing import List - -from vllm.config 
import SchedulerConfig -from vllm.core.scheduler import Scheduler -from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.sequence import SequenceGroup, SequenceGroupOutput -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.utils import Counter - - -class SequenceGroupOutputProcessor(ABC): - """Interface for logic that processes new token ids in sequence groups, - managing detokenization, stop checking, and freeing/forking sequences with - the scheduler. - - This is highly coupled with the LLMEngine and should be seen as an extension - of it. The logic is separated to simplify the LLMEngine class and allow - separate implementations for single-step decoding (which supports beam - search sequence forking) and multi-step decoding (which does not support - beam search, but does support speculative decoding). - """ - - @staticmethod - def create_output_processor( - scheduler_config: SchedulerConfig, - detokenizer: Detokenizer, - scheduler: List[Scheduler], - seq_counter: Counter, - stop_checker: "StopChecker", - ): - """Create an output processor. - - Multi-step scheduling is no longer supported. Always return a - single-step output processor. - """ - from vllm.engine.output_processor.single_step import ( - SingleStepOutputProcessor) - return SingleStepOutputProcessor(scheduler_config, detokenizer, - scheduler, seq_counter, stop_checker) - - @abstractmethod - def process_outputs(self, sequence_group: SequenceGroup, - outputs: List[SequenceGroupOutput], - is_async: bool) -> None: - """Process new token ids for the sequence group. Handles logic such as - detokenization, stop checking, and freeing/forking sequences in the - scheduler. - """ - pass - - @abstractmethod - def process_prompt_logprob(self, seq_group: SequenceGroup, - outputs: List[SequenceGroupOutput]) -> None: - """Update prompt logprobs received from outputs to seq_group.""" - pass diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py deleted file mode 100644 index dbf6a371d050a..0000000000000 --- a/vllm/engine/output_processor/single_step.py +++ /dev/null @@ -1,145 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List - -from vllm.config import SchedulerConfig -from vllm.core.scheduler import Scheduler -from vllm.engine.output_processor.interfaces import ( - SequenceGroupOutputProcessor) -from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.logger import init_logger -from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup, - SequenceGroupOutput) -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.utils import Counter - -logger = init_logger(__name__) - - -def single_step_process_prompt_logprob( - sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup, - output: CompletionSequenceGroupOutput) -> None: - """Process prompt logprobs associated with the - [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] for a given step. - - Do nothing if the output has no prompt logprobs. - - Account for the fact that transformers do not compute first-token logprobs. 
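A compact standalone illustration of that N-1 behavior (hypothetical token ids and values, not vLLM code): a prompt of N tokens yields only N-1 prompt logprobs, so a leading None keeps the list aligned with prompt positions.

    # Illustrative sketch; the ids and logprob values are made up.
    prompt_token_ids = [101, 2009, 2003, 102]   # N = 4 prompt tokens
    prompt_logprobs = [-1.2, -0.4, -3.1]        # N - 1 values from the model

    aligned = [None] + prompt_logprobs          # aligned[i] belongs to token i
    assert len(aligned) == len(prompt_token_ids)
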
- - Args: - sg_output_proc: - [`SequenceGroupOutputProcessor`][vllm.engine.output_processor.interfaces.SequenceGroupOutputProcessor] - instance - seq_group: the output is associated with this - [`SequenceGroup`][vllm.sequence.SequenceGroup] - output: the [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] - for a single scheduler step - """ - prompt_logprobs = output.prompt_logprobs - - # If this is the first (or only) "chunk" of the prefill, we need - # to prepend None to the list of prompt logprobs. The reason for this - # is that for N prompt tokens, the Sampler will generate N-1 total - # prompt logprobs during prefill since the token at idx 0 will not - # have a logprob associated with it. - if prompt_logprobs is not None: - if not seq_group.prompt_logprobs: - prompt_logprobs = [None] + prompt_logprobs - seq_group.prompt_logprobs = [] - - assert hasattr(sg_output_proc, 'detokenizer') - if (seq_group.sampling_params.detokenize - and sg_output_proc.detokenizer): - sg_output_proc.detokenizer.decode_prompt_logprobs_inplace( - seq_group, - prompt_logprobs, - position_offset=len(seq_group.prompt_logprobs)) - - seq_group.prompt_logprobs.extend(prompt_logprobs) - - -class SingleStepOutputProcessor(SequenceGroupOutputProcessor): - """SequenceGroupOutputProcessor which handles "output processing" logic, - which happens after the model returns generated token ids and before - scheduling of the next batch. Output processing logic includes - detokenization, and determining if a sequence is finished (e.g. via max len - or eos token). - - The SingleStepOutputProcessor is specialized to the case where the model - emits at most a single token per invocation, which precludes configurations - such as speculative decoding or multi-step decoding. This enables beam - search sampling, which requires forking/finishing/freeing sequences in a way - that is currently difficult to schedule multiple steps ahead of time. - """ - - def __init__(self, scheduler_config: SchedulerConfig, - detokenizer: Detokenizer, scheduler: List[Scheduler], - seq_counter: Counter, stop_checker: StopChecker): - self.scheduler_config = scheduler_config - self.detokenizer = detokenizer - self.scheduler = scheduler - self.seq_counter = seq_counter - self.stop_checker = stop_checker - - def process_outputs(self, sequence_group: SequenceGroup, - outputs: List[SequenceGroupOutput], - is_async: bool) -> None: - """Append all new tokens to sequences in the sequence group. Fork any - surviving beam candidates; free any unsurviving ones. - - Invokes detokenizer to detokenize new tokens, and also marks sequences - as finished if they meet stop conditions. - - is_async - Indicates whether this postprocessor runs in - parallel with the GPU forward pass and is processing - tokens from the previous step. If this is true, then - no tokens need to be appended since it is already done - externally (before the next schedule() call) - """ - assert (len(outputs) == 1 - ), f"{type(self)} does not support multiple outputs per step" - return self._process_sequence_group_outputs(sequence_group, outputs[0], - is_async) - - def process_prompt_logprob(self, seq_group: SequenceGroup, - outputs: List[SequenceGroupOutput]) -> None: - """Process prompt logprobs associated with one step of a single-step- - scheduled computation. 
- - Args: - seq_group: the output is associated with this - [`SequenceGroup`][vllm.sequence.SequenceGroup] - outputs: the - [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] - for a single scheduler step - """ - assert len(outputs) == 1, "Single step should only have 1 output." - output = outputs[0] - assert isinstance(output, CompletionSequenceGroupOutput) - single_step_process_prompt_logprob(self, seq_group, output) - - def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutput, - is_async: bool) -> None: - sampling_params = seq_group.sampling_params - - sample = outputs.samples[0] - seq = seq_group.first_seq - if not is_async: - seq.append_token_id(sample.output_token, sample.logprobs, - sample.output_embed) - if sampling_params.detokenize and self.detokenizer: - new_char_count = self.detokenizer.decode_sequence_inplace( - seq, sampling_params) - else: - new_char_count = 0 - self.stop_checker.maybe_stop_sequence( - seq, - new_char_count, - sampling_params, - lora_req=seq_group.lora_request, - ) - if seq.is_finished(): - for scheduler in self.scheduler: - scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py deleted file mode 100644 index 0916f1c918c85..0000000000000 --- a/vllm/engine/output_processor/stop_checker.py +++ /dev/null @@ -1,139 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional, Tuple - -from vllm.lora.request import LoRARequest -from vllm.reasoning import ReasoningParser -from vllm.sampling_params import SamplingParams -from vllm.sequence import Sequence, SequenceStatus - - -class StopChecker: - """LLMEngine helper class which separates out the logic involving stop - checking. This checks things such as: whether the eos token was emitted, - whether the max_tokens has been consumed, whether a stop string has been - emitted, or if we have exceeded the max model len. - """ - - def __init__( - self, - max_model_len: int, - reasoner: Optional[ReasoningParser] = None, - ): - # Do not use it directly, but use `self._get_max_model_len`. - self._max_model_len = max_model_len - self.reasoner = reasoner - - def _get_max_model_len(self, lora_req: Optional[LoRARequest]): - if lora_req and lora_req.long_lora_max_len: - return lora_req.long_lora_max_len - else: - return self._max_model_len - - def maybe_stop_sequence( - self, - seq: Sequence, - new_char_count: int, - sampling_params: SamplingParams, - lora_req: Optional[LoRARequest] = None, - ) -> None: - """Stop the finished sequences. - - new_char_count is the number of chars added to the - sequence's output text for the newly generated token - """ - - # Check if the minimum number of tokens has been generated yet; - # skip the stop string/token checks if not - if seq.get_output_len() < sampling_params.min_tokens: - return - - # Check if the sequence has generated the EOS token. 
- if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == seq.eos_token_id): - # Remove the last EOS token unless explicitly specified - # This prevents unintended exposure of the EOS token - if new_char_count and ( - not sampling_params.include_stop_str_in_output): - seq.output_text = seq.output_text[:-new_char_count] - seq.status = SequenceStatus.FINISHED_STOPPED - return - - # Skip stop string/token checks if in reasoning content generation - if self.reasoner is not None and \ - not self.reasoner.is_reasoning_end(seq.get_token_ids()): - return - - # Check if a stop token was encountered. - # This assumes a single token produced per step. - last_token_id = seq.get_last_token_id() - if last_token_id in (sampling_params.stop_token_ids or ()): - if new_char_count and ( - not sampling_params.include_stop_str_in_output): - # Remove last token - seq.output_text = seq.output_text[:-new_char_count] - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = last_token_id - return - - # Check if any stop strings are matched. - stop = self.check_stop_strings( - seq.output_text, new_char_count, sampling_params.stop, - sampling_params.include_stop_str_in_output) - if stop is not None: - stop_str, truncate_to = stop - if truncate_to != -1: - seq.output_text = seq.output_text[:truncate_to] - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = stop_str - return - - # Check if the sequence has reached max_model_len. - if seq.get_len() >= self._get_max_model_len(lora_req): - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the sequence has reached max_tokens. - if seq.get_output_len() == sampling_params.max_tokens: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - @staticmethod - def check_stop_strings( - output_text: str, - new_char_count: int, - stop: List[str], - include_in_output: bool, - ) -> Optional[Tuple[str, int]]: - """Check if any stop strings are matched and truncate sequence - output text accordingly. - - Returns tuple (stop_string, offset) if matched or else None. - - Where stop_string is the matched stop string and offset is the - length to which output_text should be truncated, or -1 for no - truncation. - """ - if not new_char_count or not stop: - return None - - for stop_str in stop: - stop_string_len = len(stop_str) - # Avoid searching already-searched text. - stop_index = output_text.find(stop_str, - 1 - new_char_count - stop_string_len) - if stop_index == -1: - continue - - if include_in_output: - # Truncate to end of stop string. - stop_index += stop_string_len - if stop_index >= len(output_text): - # No truncation required. - return stop_str, -1 - - # Truncate the output text to either the beginning - # or end of the stop string. 
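            # e.g. output_text="Hi<stop>!" with stop_str="<stop>":
            # include_in_output=True  -> truncate_to=8, keeping "Hi<stop>"
            # include_in_output=False -> truncate_to=2, keeping "Hi"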
- return stop_str, stop_index - return None diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index c345f17e6614f..e828ac04364ff 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -7,13 +7,11 @@ from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import ModelConfig, VllmConfig -from vllm.core.scheduler import SchedulerOutputs from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.plugins.io_processors.interface import IOProcessor from vllm.pooling_params import PoolingParams @@ -266,11 +264,7 @@ class EngineClient(ABC): ... @abstractmethod - async def do_log_stats( - self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[list[SamplerOutput]] = None, - ) -> None: + async def do_log_stats(self) -> None: ... @abstractmethod diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c2c0ad74ef431..df49119d86420 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -421,6 +421,51 @@ def resolve_mistral_chat_template( return None +_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], Optional[str]]() +""" +Used in `_try_get_processor_chat_template` to avoid calling +`cached_get_processor` again if the processor fails to be loaded. + +This is needed because `lru_cache` does not cache when an exception happens. 
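The behavior this docstring describes is easy to verify in isolation. A minimal sketch (illustrative names, not vLLM code): `lru_cache` re-runs a failing loader on every call, while a plain dict with a sentinel remembers the failure as well.

    from functools import lru_cache

    calls = 0

    @lru_cache
    def load_processor(name: str) -> str:
        global calls
        calls += 1
        raise RuntimeError("load failed")

    for _ in range(3):
        try:
            load_processor("some-model")
        except RuntimeError:
            pass
    print(calls)  # 3: the exception was never cached

    # A plain dict caches failures too, using None as the sentinel:
    _cache: dict[str, object] = {}

    def load_once(name: str, loader) -> object:
        if name in _cache:       # hit on success *and* on a prior failure
            return _cache[name]
        try:
            result = loader(name)
        except Exception:
            result = None        # remember the failure
        _cache[name] = result
        return result
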
+""" + + +def _try_get_processor_chat_template( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + model_config: ModelConfig, +) -> Optional[str]: + cache_key = (tokenizer.name_or_path, model_config.trust_remote_code) + if cache_key in _PROCESSOR_CHAT_TEMPLATES: + return _PROCESSOR_CHAT_TEMPLATES[cache_key] + + try: + processor = cached_get_processor( + tokenizer.name_or_path, + processor_cls=( + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, + ), + trust_remote_code=model_config.trust_remote_code, + ) + if ( + isinstance(processor, ProcessorMixin) + and hasattr(processor, "chat_template") + and (chat_template := processor.chat_template) is not None + ): + _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template + return chat_template + except Exception: + logger.debug( + "Failed to load AutoProcessor chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) + + _PROCESSOR_CHAT_TEMPLATES[cache_key] = None + return None + + def resolve_hf_chat_template( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], chat_template: Optional[str], @@ -434,28 +479,10 @@ def resolve_hf_chat_template( # 2nd priority: AutoProcessor chat template, unless tool calling is enabled if tools is None: - try: - processor = cached_get_processor( - tokenizer.name_or_path, - processor_cls=( - PreTrainedTokenizer, - PreTrainedTokenizerFast, - ProcessorMixin, - ), - trust_remote_code=model_config.trust_remote_code, - ) - if ( - isinstance(processor, ProcessorMixin) - and hasattr(processor, "chat_template") - and processor.chat_template is not None - ): - return processor.chat_template - except Exception: - logger.debug( - "Failed to load AutoProcessor chat template for %s", - tokenizer.name_or_path, - exc_info=True, - ) # noqa: E501 + chat_template = _try_get_processor_chat_template(tokenizer, + model_config) + if chat_template is not None: + return chat_template # 3rd priority: AutoTokenizer chat template try: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f2282c40f7073..092d3f276d1c5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -11,7 +11,6 @@ from pydantic import ValidationError from tqdm.auto import tqdm from typing_extensions import TypeVar -import vllm.envs as envs from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, create_sort_beams_key_function) @@ -19,7 +18,6 @@ from vllm.config import (CompilationConfig, ModelDType, StructuredOutputsConfig, TokenizerMode, is_init_field) from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides, PoolerConfig, RunnerOption) -from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, ChatTemplateContentFormatOption, apply_hf_chat_template, @@ -54,6 +52,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, Device, as_iter, is_list_of +from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.sample.logits_processor import LogitsProcessor if TYPE_CHECKING: @@ -138,8 +137,6 @@ class LLM: back to the eager mode. disable_custom_all_reduce: See [ParallelConfig][vllm.config.ParallelConfig]. - disable_async_output_proc: Disable async output processing. - This may result in lower performance. hf_token: The token to use as HTTP bearer authorization for remote files . 
If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). @@ -189,7 +186,6 @@ class LLM: enforce_eager: bool = False, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, - disable_async_output_proc: bool = False, hf_token: Optional[Union[bool, str]] = None, hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None, @@ -287,7 +283,6 @@ class LLM: enforce_eager=enforce_eager, max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, - disable_async_output_proc=disable_async_output_proc, hf_token=hf_token, hf_overrides=hf_overrides, mm_processor_kwargs=mm_processor_kwargs, @@ -309,11 +304,7 @@ class LLM: self.request_counter = Counter() self.default_sampling_params: Union[dict[str, Any], None] = None - if envs.VLLM_USE_V1: - supported_tasks = self.llm_engine \ - .get_supported_tasks() # type: ignore - else: - supported_tasks = self.llm_engine.model_config.supported_tasks + supported_tasks = self.llm_engine.get_supported_tasks() # type: ignore logger.info("Supported_tasks: %s", supported_tasks) @@ -1473,8 +1464,6 @@ class LLM: Note: This method is only available with the V1 LLM engine. """ - from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine - assert isinstance(self.llm_engine, V1LLMEngine) return self.llm_engine.get_metrics() def _validate_and_add_requests( diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index b75b94ad0acc2..fd4b992c3821b 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -15,10 +15,10 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.tasks import SupportedTask from vllm.utils import make_async +from vllm.v1.outputs import SamplerOutput from vllm.worker.worker_base import WorkerBase logger = init_logger(__name__) diff --git a/vllm/executor/mp_distributed_executor.py b/vllm/executor/mp_distributed_executor.py deleted file mode 100644 index 136dca54e6e52..0000000000000 --- a/vllm/executor/mp_distributed_executor.py +++ /dev/null @@ -1,244 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import os -from typing import Any, Callable, List, Optional, Union - -import cloudpickle - -from vllm.executor.executor_base import DistributedExecutorBase -from vllm.executor.multiproc_worker_utils import ( - ProcessWorkerWrapper, ResultHandler, WorkerMonitor, - set_multiprocessing_worker_envs) -from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest -from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, - get_distributed_init_method, get_ip, get_open_port, - make_async, run_method, update_environment_variables) -from vllm.worker.worker_base import WorkerWrapperBase - -logger = init_logger(__name__) - - -class MultiprocessingDistributedExecutor(DistributedExecutorBase): - """Python multiprocessing-based distributed executor""" - - uses_ray: bool = False - - def _check_cuda(self) -> None: - """Check that the number of GPUs is sufficient for the parallel - configuration. 
Separate from _init_executor to reduce the number of
-        indented blocks.
-        """
-        parallel_config = self.parallel_config
-        world_size = parallel_config.world_size
-        tensor_parallel_size = parallel_config.tensor_parallel_size
-
-        cuda_device_count = cuda_device_count_stateless()
-        # Use the TP-specific message for the more common TP-only case.
-        if tensor_parallel_size > cuda_device_count:
-            raise RuntimeError(
-                f"please set tensor_parallel_size ({tensor_parallel_size}) "
-                f"to less than the max local gpu count ({cuda_device_count})")
-
-        if world_size > cuda_device_count:
-            raise RuntimeError(
-                f"please ensure that world_size ({world_size}) "
-                f"is less than the max local gpu count ({cuda_device_count})")
-
-        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
-        if "CUDA_VISIBLE_DEVICES" not in os.environ:
-            update_environment_variables({
-                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
-            })
-
-    def _init_executor(self) -> None:
-
-        from vllm.platforms import current_platform
-        if current_platform.is_cuda_alike():
-            self._check_cuda()
-
-        # Create the parallel GPU workers.
-        world_size = self.parallel_config.world_size
-        tensor_parallel_size = self.parallel_config.tensor_parallel_size
-
-        # Set multiprocessing envs that are common to V0 and V1
-        set_multiprocessing_worker_envs(self.parallel_config)
-
-        # Multiprocessing-based executor does not support multi-node setting.
-        # Since it only works for single node, we can use the loopback address
-        # 127.0.0.1 for communication.
-        distributed_init_method = get_distributed_init_method(
-            "127.0.0.1", get_open_port())
-
-        self.workers: List[ProcessWorkerWrapper] = []
-        # This is the list of workers that are rank 0 of each TP group EXCEPT
-        # global rank 0. These are the workers that will broadcast to the
-        # rest of the workers.
-        self.tp_driver_workers: List[ProcessWorkerWrapper] = []
-        # This is the list of workers that are not drivers and not the first
-        # worker in a TP group. These are the workers that will be
-        # broadcasted to.
-        self.non_driver_workers: List[ProcessWorkerWrapper] = []
-
-        if world_size == 1:
-            self.worker_monitor = None
-        else:
-            result_handler = ResultHandler()
-            for rank in range(1, world_size):
-                worker = ProcessWorkerWrapper(result_handler,
-                                              WorkerWrapperBase,
-                                              self.vllm_config, rank)
-                self.workers.append(worker)
-                if rank % tensor_parallel_size == 0:
-                    self.tp_driver_workers.append(worker)
-                else:
-                    self.non_driver_workers.append(worker)
-
-            self.worker_monitor = WorkerMonitor(self.workers, result_handler)
-            result_handler.start()
-            self.worker_monitor.start()
-
-        # Set up signal handlers to shut down the executor cleanly
-        # sometimes gc does not work well
-
-        self.driver_worker = WorkerWrapperBase(self.vllm_config, 0)
-
-        all_kwargs = []
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-        for i in range(world_size):
-            local_rank = i
-            rank = i
-            kwargs = dict(
-                vllm_config=self.vllm_config,
-                local_rank=local_rank,
-                rank=rank,
-                distributed_init_method=distributed_init_method,
-                is_driver_worker=(not self.parallel_config)
-                or (rank % self.parallel_config.tensor_parallel_size == 0),
-            )
-            all_kwargs.append(kwargs)
-        self._run_workers("init_worker", all_kwargs)
-        self._run_workers("init_device")
-        self._run_workers("load_model",
-                          max_concurrent_workers=self.parallel_config.
- max_parallel_loading_workers) - self.driver_exec_model = make_async(self.driver_worker.execute_model) - self.pp_locks: Optional[List[asyncio.Lock]] = None - - def shutdown(self): - if (worker_monitor := getattr(self, "worker_monitor", - None)) is not None: - worker_monitor.close() - - def _driver_execute_model( - self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[List[SamplerOutput]]: - """Run execute_model in the driver worker. - - Passing None will cause the driver to stop the model execution - loop running in each of the remote workers. - """ - return self.driver_worker.execute_model(execute_model_req) - - def _run_workers( - self, - method: Union[str, Callable], - *args, - async_run_tensor_parallel_workers_only: bool = False, - max_concurrent_workers: Optional[int] = None, - **kwargs, - ) -> List[Any]: - """Runs the given method on all workers. - - Args: - async_run_tensor_parallel_workers_only: If True the method will be - run only in the remote TP workers, not the driver worker. - It will also be run asynchronously and return a list of futures - rather than blocking on the results. - """ - if isinstance(method, str): - sent_method = method - else: - sent_method = cloudpickle.dumps(method) - del method - - if max_concurrent_workers: - raise NotImplementedError( - "max_concurrent_workers is not supported yet.") - - if async_run_tensor_parallel_workers_only: - # Run only non-driver workers and just return futures. - return [ - worker.execute_method(sent_method, *args, **kwargs) - for worker in self.non_driver_workers - ] - - # Start all remote workers first. - worker_outputs = [ - worker.execute_method(sent_method, *args, **kwargs) - for worker in self.workers - ] - - driver_worker_output = run_method(self.driver_worker, sent_method, - args, kwargs) - - # Get the results of the workers. - return [driver_worker_output - ] + [output.get() for output in worker_outputs] - - def check_health(self) -> None: - """Raises an error if engine is unhealthy.""" - if self.worker_monitor is not None and not self.worker_monitor.is_alive( - ): - raise RuntimeError("Worker processes are not running") - - def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: - """Wait for futures returned from _run_workers() with - async_run_remote_workers_only to complete.""" - for result in parallel_worker_tasks: - result.get() - - async def _driver_execute_model_async( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: - if not self.tp_driver_workers: - return await self.driver_exec_model(execute_model_req) - - if self.pp_locks is None: - # This locks each pipeline parallel stage so multiple virtual - # engines can't execute on the same stage at the same time - # We create the locks here to avoid creating them in the constructor - # which uses a different asyncio loop. - self.pp_locks = [ - asyncio.Lock() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - tasks = [ - asyncio.create_task( - _run_task_with_lock(self.driver_exec_model, self.pp_locks[0], - execute_model_req)) - ] - for pp_rank, driver_worker in enumerate(self.tp_driver_workers, - start=1): - tasks.append( - asyncio.create_task( - _run_task_with_lock(driver_worker.execute_method_async, - self.pp_locks[pp_rank], - "execute_model", execute_model_req))) - results = await asyncio.gather(*tasks) - - # Only the last PP stage has the final results. 
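        # asyncio.gather preserves input order, so results[-1] comes from the
        # task created for the last tp_driver_worker (the final PP stage).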
- return results[-1] - - async def _start_worker_execution_loop(self): - coros = [ - worker.execute_method_async("start_worker_execution_loop") - for worker in self.non_driver_workers - ] - return await asyncio.gather(*coros) diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py deleted file mode 100644 index 48b3479ed7997..0000000000000 --- a/vllm/executor/multiproc_worker_utils.py +++ /dev/null @@ -1,279 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import os -import threading -import uuid -from dataclasses import dataclass -from multiprocessing import Queue -from multiprocessing.connection import wait -from multiprocessing.process import BaseProcess -from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union - -import torch - -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.utils import (_maybe_force_spawn, decorate_logs, get_mp_context, - run_method) - -logger = init_logger(__name__) - -T = TypeVar('T') - -_TERMINATE = "TERMINATE" # sentinel - -JOIN_TIMEOUT_S = 2 - - -@dataclass -class Result(Generic[T]): - """Result of task dispatched to worker""" - - task_id: uuid.UUID - value: Optional[T] = None - exception: Optional[BaseException] = None - - -class ResultFuture(threading.Event, Generic[T]): - """Synchronous future for non-async case""" - - def __init__(self): - super().__init__() - self.result: Optional[Result[T]] = None - - def set_result(self, result: Result[T]): - self.result = result - self.set() - - def get(self) -> T: - self.wait() - assert self.result is not None - if self.result.exception is not None: - raise self.result.exception - return self.result.value # type: ignore[return-value] - - -def _set_future_result(future: Union[ResultFuture, asyncio.Future], - result: Result): - if isinstance(future, ResultFuture): - future.set_result(result) - return - loop = future.get_loop() - if not loop.is_closed(): - if result.exception is not None: - loop.call_soon_threadsafe(future.set_exception, result.exception) - else: - loop.call_soon_threadsafe(future.set_result, result.value) - - -class ResultHandler(threading.Thread): - """Handle results from all workers (in background thread)""" - - def __init__(self) -> None: - super().__init__(daemon=True) - self.result_queue = get_mp_context().Queue() - self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} - - def run(self): - for result in iter(self.result_queue.get, _TERMINATE): - future = self.tasks.pop(result.task_id) - _set_future_result(future, result) - # Ensure that all waiters will receive an exception - for task_id, future in self.tasks.items(): - _set_future_result( - future, - Result(task_id=task_id, - exception=ChildProcessError("worker died"))) - - def close(self): - self.result_queue.put(_TERMINATE) - - -class WorkerMonitor(threading.Thread): - """Monitor worker status (in background thread)""" - - def __init__(self, workers: List['ProcessWorkerWrapper'], - result_handler: ResultHandler): - super().__init__(daemon=True) - self.workers = workers - self.result_handler = result_handler - self._close = False - - def run(self) -> None: - # Blocks until any worker exits - dead_sentinels = wait([w.process.sentinel for w in self.workers]) - if not self._close: - self._close = True - - # Kill / cleanup all workers - for worker in self.workers: - process = worker.process - if process.sentinel in dead_sentinels: - 
process.join(JOIN_TIMEOUT_S) - if process.exitcode is not None and process.exitcode != 0: - logger.error("Worker %s pid %s died, exit code: %s", - process.name, process.pid, process.exitcode) - # Cleanup any remaining workers - if logger: - logger.info("Killing local vLLM worker processes") - for worker in self.workers: - worker.kill_worker() - # Must be done after worker task queues are all closed - self.result_handler.close() - - for worker in self.workers: - worker.process.join(JOIN_TIMEOUT_S) - - def close(self): - if self._close: - return - self._close = True - logger.info("Terminating local vLLM worker processes") - for worker in self.workers: - worker.terminate_worker() - # Must be done after worker task queues are all closed - self.result_handler.close() - - -class ProcessWorkerWrapper: - """Local process wrapper for vllm.worker.Worker, - for handling single-node multi-GPU tensor parallel.""" - - def __init__(self, result_handler: ResultHandler, - worker_factory: Callable[[VllmConfig, int], Any], - vllm_config: VllmConfig, rank: int) -> None: - self.mp = get_mp_context() - self._task_queue = self.mp.Queue() - self.result_queue = result_handler.result_queue - self.tasks = result_handler.tasks - self.process: BaseProcess = self.mp.Process( # type: ignore[attr-defined] - target=_run_worker_process, - name="VllmWorkerProcess", - kwargs=dict( - worker_factory=worker_factory, - task_queue=self._task_queue, - result_queue=self.result_queue, - vllm_config=vllm_config, - rank=rank, - ), - daemon=True) - - self.process.start() - - def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], - method: Union[str, bytes], args, kwargs): - task_id = uuid.uuid4() - self.tasks[task_id] = future - try: - self._task_queue.put((task_id, method, args, kwargs)) - except SystemExit: - raise - except BaseException as e: - del self.tasks[task_id] - raise ChildProcessError("worker died") from e - - def execute_method(self, method: Union[str, bytes], *args, **kwargs): - future: ResultFuture = ResultFuture() - self._enqueue_task(future, method, args, kwargs) - return future - - async def execute_method_async(self, method: Union[str, bytes], *args, - **kwargs): - future = asyncio.get_running_loop().create_future() - self._enqueue_task(future, method, args, kwargs) - return await future - - def terminate_worker(self): - try: - self._task_queue.put(_TERMINATE) - except ValueError: - self.process.kill() - self._task_queue.close() - - def kill_worker(self): - self._task_queue.close() - self.process.kill() - - -def _run_worker_process( - worker_factory: Callable[[VllmConfig, int], Any], - task_queue: Queue, - result_queue: Queue, - vllm_config: VllmConfig, - rank: int, -) -> None: - """Worker process event loop""" - - # Add process-specific prefix to stdout and stderr - process_name = get_mp_context().current_process().name - decorate_logs(process_name) - - # Initialize worker - worker = worker_factory(vllm_config, rank) - del worker_factory - - # Accept tasks from the engine in task_queue - # and return task output in result_queue - logger.info("Worker ready; awaiting tasks") - try: - for items in iter(task_queue.get, _TERMINATE): - output = None - exception = None - task_id, method, args, kwargs = items - try: - output = run_method(worker, method, args, kwargs) - except SystemExit: - raise - except KeyboardInterrupt: - break - except BaseException as e: - logger.exception( - "Exception in worker %s while processing method %s.", - process_name, method) - exception = e - result_queue.put( - 
Result(task_id=task_id, value=output, exception=exception)) - except KeyboardInterrupt: - pass - except Exception: - logger.exception("Worker failed") - - # Flush TunableOp results when TunableOp is enabled and - # online (in situ) tuning is enabled. - # Offline tuning API (record_untuned_is_enabled()) only - # available in PyTorch 2.6 or later. - if torch.cuda.is_available(): - import torch.cuda.tunable as tunable - if (tunable.is_enabled() and tunable.tuning_is_enabled() - and not tunable.record_untuned_is_enabled()): - tunable.write_file() - - logger.info("Worker exiting") - - -def set_multiprocessing_worker_envs(parallel_config): - """ Set up environment variables that should be used when there are workers - in a multiprocessing environment. This should be called by the parent - process before worker processes are created""" - - _maybe_force_spawn() - - # Configure thread parallelism if OMP_NUM_THREADS isn't set - # - # Helps to avoid CPU contention. The default of spawning a thread per - # core combined with multiprocessing for each GPU can have a negative - # impact on performance. The contention is amplified when running in a - # container where CPU limits can cause throttling. - default_omp_num_threads = 1 - if "OMP_NUM_THREADS" not in os.environ and ( - current_parallelism := - torch.get_num_threads()) > default_omp_num_threads: - logger.warning( - "Reducing Torch parallelism from %d threads to %d to avoid " - "unnecessary CPU contention. Set OMP_NUM_THREADS in the " - "external environment to tune this value as needed.", - current_parallelism, default_omp_num_threads) - os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) - torch.set_num_threads(default_omp_num_threads) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 78d0ee6c1e3fc..84747575b4960 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -17,12 +17,12 @@ from vllm.executor.msgspec_utils import encode_hook from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster, ray) from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, get_ip, get_open_port, make_async) +from vllm.v1.outputs import SamplerOutput if ray is not None: from ray.actor import ActorHandle diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 3b566e88a9ec2..7a753d608a43a 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -137,10 +137,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor): def _init_executor(self) -> None: """Initialize the worker and load the model. 
""" - assert self.vllm_config.scheduler_config.delay_factor == 0.0, \ - ("ExecutorWithExternalLauncher needs deterministic " - "execution, so it" - "does not support delay_factor in scheduling") if envs.VLLM_USE_V1: assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \ ("To get deterministic execution in V1, " diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index e9db2a0dc13a8..46f49aaa013da 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -7,15 +7,7 @@ from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs, to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) -from .registry import (DummyData, InputContext, InputProcessingContext, - InputRegistry) - -INPUT_REGISTRY = InputRegistry() -""" -The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used -by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the -target model. -""" +from .registry import InputContext, InputProcessingContext __all__ = [ "DataPrompt", @@ -36,9 +28,6 @@ __all__ = [ "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", - "INPUT_REGISTRY", - "DummyData", "InputContext", "InputProcessingContext", - "InputRegistry", ] diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index f0b392e9767ae..b5316b6d0574c 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, Union import torch from transformers import BatchFeature, PretrainedConfig, ProcessorMixin @@ -15,16 +15,9 @@ from vllm.utils.jsontree import JSONTree, json_map_leaves if TYPE_CHECKING: from vllm.config import ModelConfig - from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict, - MultiModalRegistry) - from vllm.sequence import SequenceData from vllm.transformers_utils.tokenizer import AnyTokenizer else: ModelConfig = Any - MultiModalDataDict = Any - MultiModalPlaceholderDict = Any - MultiModalRegistry = Any - SequenceData = Any AnyTokenizer = Any _T = TypeVar("_T") @@ -191,61 +184,3 @@ class InputProcessingContext(InputContext): f"on data={data} with kwargs={allowed_kwargs}") raise ValueError(msg) from exc - - -class DummyData(NamedTuple): - """ - Dummy data used for profiling. - - Note: This is only used in V0. - """ - - seq_data: SequenceData - multi_modal_data: Optional[MultiModalDataDict] = None - multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None - - -class InputRegistry: - """ - Note: This is only used in V0. - """ - - def dummy_data_for_profiling( - self, - model_config: ModelConfig, - seq_len: int, - mm_registry: MultiModalRegistry, - is_encoder_data: bool = False, - ) -> DummyData: - """ - Create dummy data for profiling the memory usage of a model. - - The model is identified by ``model_config``. 
- """ - # Avoid circular import - from vllm.multimodal.cache import processor_only_cache_from_config - from vllm.sequence import SequenceData - - if not model_config.is_multimodal_model: - seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) - return DummyData(seq_data=seq_data) - - cache = processor_only_cache_from_config(model_config, mm_registry) - - # Encoder dummy data does not contain multi-modal data - if is_encoder_data: - enc_data = mm_registry.get_encoder_dummy_data(model_config, - seq_len, - cache=cache) - seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) - return DummyData(seq_data=seq_data) - - dec_data = mm_registry.get_decoder_dummy_data(model_config, - seq_len, - cache=cache) - - return DummyData( - seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), - multi_modal_data=dec_data.multi_modal_data.get_data(), - multi_modal_placeholders=dec_data.multi_modal_placeholders, - ) diff --git a/vllm/logging_utils/__init__.py b/vllm/logging_utils/__init__.py index cf690a89ae9bc..7202259ca21aa 100644 --- a/vllm/logging_utils/__init__.py +++ b/vllm/logging_utils/__init__.py @@ -2,7 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.logging_utils.formatter import NewLineFormatter +from vllm.logging_utils.log_time import logtime __all__ = [ "NewLineFormatter", + "logtime", ] diff --git a/vllm/logging_utils/log_time.py b/vllm/logging_utils/log_time.py new file mode 100644 index 0000000000000..013dd144beaf8 --- /dev/null +++ b/vllm/logging_utils/log_time.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Provides a timeslice logging decorator +""" + +import functools +import time + + +def logtime(logger, msg=None): + """ + Logs the execution time of the decorated function. + Always place it beneath other decorators. 
+ """ + + def _inner(func): + + @functools.wraps(func) + def _wrapper(*args, **kwargs): + start = time.perf_counter() + result = func(*args, **kwargs) + elapsed = time.perf_counter() - start + + prefix = f"Function '{func.__module__}.{func.__qualname__}'" \ + if msg is None else msg + logger.debug("%s: Elapsed time %.7f secs", prefix, elapsed) + return result + + return _wrapper + + return _inner diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 55dfe8088c8f3..3c094cfdb553f 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -3,13 +3,9 @@ from vllm.model_executor.parameter import (BasevLLMParameter, PackedvLLMParameter) -from vllm.model_executor.sampling_metadata import (SamplingMetadata, - SamplingMetadataCache) from vllm.model_executor.utils import set_random_seed __all__ = [ - "SamplingMetadata", - "SamplingMetadataCache", "set_random_seed", "BasevLLMParameter", "PackedvLLMParameter", diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 8a4ac214443eb..2110aa2769b93 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -1,26 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A layer that compute logits from hidden_stats.""" -import inspect -from concurrent.futures import ThreadPoolExecutor from typing import Optional import torch -import vllm.envs as envs from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.platforms import current_platform -_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None -if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None: - _logits_processor_threadpool = ThreadPoolExecutor( - envs.VLLM_LOGITS_PROCESSOR_THREADS) - @CustomOp.register("logits_processor") class LogitsProcessor(CustomOp): @@ -58,17 +49,11 @@ class LogitsProcessor(CustomOp): self, lm_head: VocabParallelEmbedding, hidden_states: torch.Tensor, - sampling_metadata: Optional[SamplingMetadata] = None, embedding_bias: Optional[torch.Tensor] = None, - prune_hidden_states: bool = True, ) -> Optional[torch.Tensor]: if self.logits_as_input: logits = hidden_states else: - if sampling_metadata is not None and prune_hidden_states: - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) - # Get the logits for the next tokens. logits = self._get_logits(hidden_states, lm_head, embedding_bias) if logits is not None: @@ -79,12 +64,6 @@ class LogitsProcessor(CustomOp): if self.scale != 1.0: logits *= self.scale - - # Apply logits processors (if any). 
- if sampling_metadata is not None and \ - sampling_metadata.seq_groups is not None: - logits = _apply_logits_processors(logits, sampling_metadata) - return logits def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor: @@ -125,75 +104,3 @@ class LogitsProcessor(CustomOp): s += f", org_vocab_size={self.org_vocab_size}" s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" return s - - -def _prune_hidden_states( - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - # NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios - # (warmup, profile_run) we might not have selected_token_indices, - # so we skip pruning. - if sampling_metadata.selected_token_indices is not None: - return hidden_states.index_select( - 0, sampling_metadata.selected_token_indices) - else: - return hidden_states - - -def _apply_logits_processors( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - found_logits_processors = False - logits_processed = 0 - logits_row_ids_and_logits_row_futures = [] - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - logits_processors = sampling_params.logits_processors - if logits_processors: - found_logits_processors = True - - for seq_id, logits_row_idx in zip(seq_ids, - seq_group.sample_indices): - logits_row = logits[logits_row_idx] - past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids - prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids - - if _logits_processor_threadpool is not None: - logits_row_ids_and_logits_row_futures.append( - (logits_row_idx, - _logits_processor_threadpool.submit( - _apply_logits_processors_single_seq, logits_row, - logits_processors, past_tokens_ids, - prompt_tokens_ids))) - else: - logits[logits_row_idx] = \ - _apply_logits_processors_single_seq( - logits_row, logits_processors, past_tokens_ids, - prompt_tokens_ids) - - logits_processed += len(seq_group.sample_indices) + len( - seq_group.prompt_logprob_indices) - - for logits_row_idx, future in logits_row_ids_and_logits_row_futures: - logits[logits_row_idx] = future.result() - - if found_logits_processors: - # verifies that no rows in logits were missed unexpectedly - assert logits_processed == logits.shape[0] - return logits - - -def _apply_logits_processors_single_seq(logits_row, logits_processors, - past_tokens_ids, - prompt_tokens_ids) -> torch.Tensor: - for logits_processor in logits_processors: - parameters = inspect.signature(logits_processor).parameters - if len(parameters) == 3: - logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids, - logits_row) - else: - logits_row = logits_processor(past_tokens_ids, logits_row) - return logits_row diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py deleted file mode 100644 index 9d93cad2420ad..0000000000000 --- a/vllm/model_executor/layers/sampler.py +++ /dev/null @@ -1,1198 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A layer that samples the next tokens from the model's outputs.""" -import itertools -from collections.abc import Iterator -from dataclasses import dataclass -from importlib.util import find_spec -from math import inf -from typing import Optional, Union - -import msgspec -import torch -import torch.nn as nn - -import vllm.envs as envs -from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs -from 
vllm.model_executor.layers.utils import apply_penalties
-from vllm.model_executor.sampling_metadata import (SamplingMetadata,
-                                                   SamplingTensors,
-                                                   SequenceGroupToSample)
-from vllm.sampling_params import SamplingType
-from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
-                           CompletionSequenceGroupOutput, SequenceOutput)
-
-if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
-    # yapf: disable
-    from flashinfer.sampling import (
-        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
-
-    # yapf: enable
-else:
-    flashinfer_top_k_top_p_sampling = None
-
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-def get_sampler() -> torch.nn.Module:
-    if envs.VLLM_USE_V1:
-        # Lazy import: the v1 package isn't distributed
-        from vllm.v1.sample.sampler import Sampler as V1Sampler
-        return V1Sampler()
-    return Sampler()
-
-
-# (num_token_ids, num_parent_ids) per sequence group.
-SampleResultType = list[tuple[list[int], list[int]]]
-
-# Types of temporary data structures used for
-# computing sample_result
-SampleMetadataType = dict[SamplingType, tuple[list[int],
-                                              list[SequenceGroupToSample]]]
-MultinomialSamplesType = dict[SamplingType, torch.Tensor]
-SampleResultsDictType = dict[int, tuple[list[int], list[int]]]
-
-
-# Encapsulates temporary data structures for computing
-# sample_result.
-#
-# * For multi-step scheduling: must be returned
-# by `Sampler.forward()` and used later to compute the pythonized
-# sample_result
-#
-# * For single-step scheduling: consumed immediately
-# inside `Sampler.forward()` to compute pythonized sample_result.
-@dataclass
-class SampleResultArgsType:
-    sample_metadata: SampleMetadataType
-    multinomial_samples: MultinomialSamplesType
-    sample_results_dict: SampleResultsDictType
-    sampling_metadata: SamplingMetadata
-    greedy_samples: Optional[torch.Tensor]
-
-
-# Union of non-deferred (single-step scheduling)
-# vs deferred (multi-step scheduling)
-# sample result types
-MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]
-
-# Abbreviation of the _sample() return type
-SampleReturnType = tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]
-
-
-class SamplerOutput(
-        msgspec.Struct,
-        omit_defaults=True,  # type: ignore[call-arg]
-        array_like=True):  # type: ignore[call-arg]
-    """For each sequence group, we generate a list of SequenceOutput objects,
-    each of which contains one possible candidate for the next token.
-
-    This data structure implements methods, so it can be used like a list, but
-    also has optional fields for device tensors.
-    """
-
-    outputs: list[CompletionSequenceGroupOutput]
-
-    # On-device tensor containing probabilities of each token.
-    sampled_token_probs: Optional[torch.Tensor] = None
-
-    # On-device tensor containing the logprobs of each token.
-    logprobs: Optional["torch.Tensor"] = None
-
-    # Holds either (1) the pythonized sampler result (single-step scheduling)
-    # or (2) what will be arguments for later deferred pythonization of the
-    # sampler result (multi-step scheduling)
-    deferred_sample_results_args: Optional[SampleResultArgsType] = None
-
-    # On-device tensor containing the sampled token ids.
-    sampled_token_ids: Optional[torch.Tensor] = None
-    # CPU tensor containing the sampled token ids. Used during multi-step to
-    # return the sampled token ids from last rank to AsyncLLMEngine to be
-    # 'broadcasted' to all other PP ranks for next step.
- sampled_token_ids_cpu: Optional[torch.Tensor] = None - - # On-device tensor containing the sampled token embeddings (embeddings - # corresponding to the sampled token ids). Used when prompt embeddings are - # specified in lieu of prompt token ids or text. - sampled_token_embeds: Optional[torch.Tensor] = None - - # Optional last hidden states from the model. - hidden_states: Optional[torch.Tensor] = None - - # Optional prefill hidden states from the model - # (used for models like EAGLE). - prefill_hidden_states: Optional[torch.Tensor] = None - - # Time taken in the forward pass for this across all workers - model_forward_time: Optional[float] = None - - # Time taken in the model execute function. This will include model forward, - # block/sync across workers, cpu-gpu sync time and sampling time. - model_execute_time: Optional[float] = None - - def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput: - return self.outputs[idx] - - def __setitem__(self, idx: int, value): - self.outputs[idx] = value - - def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]: - return iter(self.outputs) - - def __len__(self): - return len(self.outputs) - - def __eq__(self, other: object): - return isinstance(other, - self.__class__) and self.outputs == other.outputs - - def __repr__(self) -> str: - """Show the shape of a tensor instead of its values to reduce noise. - """ - sampled_token_probs_repr = ("None" if self.sampled_token_probs is None - else self.sampled_token_probs.shape) - sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else - self.sampled_token_ids.shape) - return (f"SamplerOutput(outputs={self.outputs}, " - f"sampled_token_probs={sampled_token_probs_repr}, " - f"sampled_token_ids={sampled_token_ids_repr})") - - -class Sampler(nn.Module): - """Samples the next tokens from the model's outputs. - - This layer does the following: - 1. Discard the hidden states that are not used for sampling (i.e., all - tokens except the final one in each prompt). - 2. Compute the logits for the next tokens. - 3. Apply presence, frequency and repetition penalties. - 4. Apply temperature scaling. - 5. Apply top-p and top-k truncation. - 6. Sample the next tokens. - Here, each sequence group within the batch can have different sampling - parameters (e.g., sampling method, temperature, top-p, top-k, etc.). - - The structure of the logits tensor is coupled with the seq_groups in - sampling_metadata. Typically, each sequence in each seq_group has one row in - logits for the next token to be sampled; however, for a seq_group with a - prompt request with the prompt_logprobs sampling parameter, there are rows - in logits for each token in the input prompt. - """ - - def __init__(self): - super().__init__() - - # Whether or not the SamplerOutput should have on-device tensors - # containing the sampled token ids and probabilities. This is used by - # speculative decoding and when prompt embeddings are specified. - self.include_gpu_probs_tensor = False - self.should_modify_greedy_probs_inplace = False - - def _init_sampling_tensors( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ): - """The goal here is to reuse sampling tensors between similar decode - runs. This is possible because sampling logic does not change between - decodes of the same sequences. - """ - _, vocab_size = logits.shape - - # First free any existing stored sampling tensors. - # This is necessary because some sampling tensors may - # have pinned memory. 
- self._sampling_tensors = None - - # Initialize new sampling tensors - (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype) - - self._sampling_tensors = sampling_tensors - self._do_penalties = do_penalties - self._do_top_p_top_k = do_top_p_top_k - self._do_min_p = do_min_p - - def forward( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - """ - Single-step scheduling: - * Perform GPU-side sampling computation & compute - GPU-side logprobs tensor - * Pythonize sampling result & logprobs tensor - - Multi-step scheduling: - * Perform GPU-side sampling computation & compute - GPU-side logprobs tensor - * Defer Pythonization of sampling result & logprobs - tensor - * Encapsulate arguments required for deferred Pythonization - in the - [`SamplerOutput`][vllm.model_executor.layers.sampler.SamplerOutput] - structure - - Args: - logits: (num_tokens, vocab_size). - sampling_metadata: Metadata for sampling. - """ - assert logits is not None - _, vocab_size = logits.shape - - # Prepare sampling tensors with pinned memory to avoid blocking. - if not sampling_metadata.reuse_sampling_tensors: - self._init_sampling_tensors(logits, sampling_metadata) - elif self._do_penalties: - # In this case, the sampling tensors logic depends on - # "output_tokens" of a sequence. As a result, we cannot - # reuse sampling tensors, since "output_tokens" changes - # between decode runs. - self._init_sampling_tensors(logits, sampling_metadata) - - assert self._sampling_tensors is not None - sampling_tensors = self._sampling_tensors - do_penalties = self._do_penalties - do_top_p_top_k = self._do_top_p_top_k - do_min_p = self._do_min_p - - logits = _apply_min_tokens_penalty(logits, sampling_metadata) - - # Apply presence and frequency penalties. - if do_penalties: - logits = apply_penalties(logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties) - - # Use float32 to apply temperature scaling. - # Use in-place division to avoid creating a new tensor. - logits = logits.to(torch.float) - logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) - - if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None: - logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, - sampling_tensors.top_ks) - - if do_min_p: - logits = _apply_min_p(logits, sampling_tensors.min_ps) - - # We use float32 for probabilities and log probabilities. - # Compute the probabilities. - probs = torch.softmax(logits, dim=-1, dtype=torch.float) - # Compute the log probabilities. - logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) - - # Sample the next tokens. - maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample( - probs, - logprobs, - sampling_metadata, - sampling_tensors, - include_gpu_probs_tensor=self.include_gpu_probs_tensor, - modify_greedy_probs=self._should_modify_greedy_probs_inplace, - ) - - if self.include_gpu_probs_tensor: - # Since we will defer sampler result Pythonization, - # preserve GPU-side tensors in support of later - # deferred pythonization of logprobs - assert maybe_sampled_tokens_tensor is not None - on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) - else: - # Since Pythonization has already happened, don't preserve - # GPU-side tensors. 
- on_device_tensors = None - - # Get the logprobs query results. - prompt_logprobs = None - sample_logprobs = None - if not sampling_metadata.skip_sampler_cpu_output: - # Pythonize logprobs now (GPU -> CPU); do not defer. - assert not isinstance(maybe_deferred_sample_results, - SampleResultArgsType) - prompt_logprobs, sample_logprobs = get_logprobs( - logprobs, sampling_metadata, maybe_deferred_sample_results) - - return _build_sampler_output( - maybe_deferred_sample_results, - sampling_metadata, - prompt_logprobs, - sample_logprobs, - on_device_tensors=on_device_tensors, - skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) - - @property - def _should_modify_greedy_probs_inplace(self) -> bool: - """Whether or not the sampler should modify the probability distribution - of greedily-sampled tokens such that multinomial sampling would sample - the greedily-sampled token. - - In other words, if True then we set the probability of the greedily- - sampled token to 1. - - This is used by speculative decoding, which requires that the sampling - method be encoded into the probability distribution. - """ - return self.should_modify_greedy_probs_inplace - - -def _apply_min_tokens_penalty( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens - have not been generated yet - """ - # list of indices in logits that will be set to -inf - logits_to_penalize: list[tuple[int, int]] = [] - logits_applied = 0 - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - - sample_indices = seq_group.sample_indices - logits_applied += len(sample_indices) + len( - seq_group.prompt_logprob_indices) - if not seq_group.do_sample: - continue - - start_idx = sample_indices[0] - min_tokens = sampling_params.min_tokens - token_ids_to_penalize = sampling_params.all_stop_token_ids - if min_tokens > 0 and token_ids_to_penalize: - seqs_to_penalize: list[int] = [] - for j, seq_id in enumerate(seq_ids): - seq_data = seq_group.seq_data[seq_id] - if len(seq_data.output_token_ids_array) < min_tokens: - seqs_to_penalize.append(j) - - if seqs_to_penalize: - # convert to the index into logits - seqs_to_penalize = [start_idx + j for j in seqs_to_penalize] - # itertools.product pairs each seq index with every token id - logits_to_penalize.extend( - itertools.product(seqs_to_penalize, token_ids_to_penalize)) - - if logits_to_penalize: - # use zip and * to group indices along each dimension - # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) - logits[tuple(zip(*logits_to_penalize))] = -float("inf") - - # verifies that no rows in logits were missed unexpectedly - assert logits_applied == logits.shape[0] - return logits - - -def _apply_top_k_top_p( - logits: torch.Tensor, - p: torch.Tensor, - k: torch.Tensor, -) -> torch.Tensor: - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - - # Apply top-k. - top_k_mask = logits_sort.size(1) - k.to(torch.long) - # Get all the top_k values. - top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) - top_k_mask = logits_sort < top_k_mask - logits_sort.masked_fill_(top_k_mask, -float("inf")) - - # Apply top-p. - probs_sort = logits_sort.softmax(dim=-1) - probs_sum = probs_sort.cumsum(dim=-1) - top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) - # at least one - top_p_mask[:, -1] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - - # Re-sort the probabilities. 
- logits = torch.empty_like(logits_sort).scatter_(dim=-1,
- index=logits_idx,
- src=logits_sort)
- return logits
-
-
-def _apply_min_p(
- logits: torch.Tensor,
- min_p: torch.Tensor,
-) -> torch.Tensor:
- """
- Adapted from
- https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
- """
- probs = torch.softmax(logits, dim=-1)
- top_probs, _ = probs.max(dim=-1, keepdim=True)
- scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
- tokens_to_remove = probs < scaled_min_p
- logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
-
- return logits
-
-
-def _greedy_sample(
- selected_seq_groups: list[SequenceGroupToSample],
- samples: torch.Tensor,
-) -> SampleResultType:
- """Run greedy sampling on the given samples.
-
- Args:
- selected_seq_groups: A list of sequence groups batched.
- samples: (num_selected_samples,) A tensor of samples. The length of
- samples could be smaller than selected_seq_groups if
- seq_group.do_sample is False.
- Returns:
- Tuple of (next_token_ids, parent_ids). The length of the returned list
- is the same as the length of selected_seq_groups. If the corresponding
- seq_group has do_sample=False, the tuple contains ([], []).
- """
- samples_lst = samples.tolist()
- sample_idx = 0
- results: SampleResultType = []
- for seq_group in selected_seq_groups:
- if not seq_group.do_sample:
- results.append(([], []))
- continue
-
- seq_ids = seq_group.seq_ids
- num_parent_seqs = len(seq_ids)
- assert num_parent_seqs == 1, (
- "Greedy sampling should have only one seq.")
- parent_ids = list(range(num_parent_seqs))
- next_token_ids = [samples_lst[sample_idx]]
- results.append((next_token_ids, parent_ids))
- sample_idx += num_parent_seqs
- return results
-
-
-def _random_sample(
- selected_seq_groups: list[SequenceGroupToSample],
- random_samples: torch.Tensor,
-) -> SampleResultType:
- """Run random sampling on the given samples.
-
- Args:
- selected_seq_groups: A list of sequence groups batched.
- random_samples: (num_selected_samples,) A tensor of samples. The
- length of samples could be smaller than selected_seq_groups if
- seq_group.do_sample is False.
- Returns:
- Tuple of (next_token_ids, parent_ids). The length of the returned list
- is the same as the length of selected_seq_groups. If the corresponding
- seq_group has do_sample=False, the tuple contains ([], []).
- """
- # Find the maximum n value of the prompt phase requests.
- random_samples = random_samples.cpu()
- sample_idx = 0
- results: SampleResultType = []
- for seq_group in selected_seq_groups:
- if not seq_group.do_sample:
- results.append(([], []))
- continue
-
- seq_ids = seq_group.seq_ids
- sampling_params = seq_group.sampling_params
- is_prompt = seq_group.is_prompt
- num_parent_seqs = len(seq_ids)
- if is_prompt:
- # Prompt phase.
- parent_ids = [0] * sampling_params.n
- next_token_ids = random_samples[
- sample_idx, :sampling_params.n].tolist()
- else:
- # Generation phase.
- parent_ids = list(range(num_parent_seqs))
- next_token_ids = random_samples[sample_idx:sample_idx +
- num_parent_seqs, 0].tolist()
- results.append((next_token_ids, parent_ids))
- sample_idx += num_parent_seqs
- return results
-
-
-# torch.multinomial forces a GPU<->CPU sync.
-# Therefore, we use an optimized implementation instead.
-# Note that we always sample with replacement.
-# probs will be modified in place, but this is fine, as we pass
-# in a copy already.
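The comment above motivates the `_multinomial` implementation that follows: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax is the Gumbel-max trick in disguise (`-log q` is Gumbel-distributed), so it samples exactly from Categorical(p) without the host sync that `torch.multinomial` incurs. An empirical check (illustrative, toy values):

import torch

p = torch.tensor([0.1, 0.2, 0.7])
q = torch.empty(100_000, 3).exponential_()   # q ~ Exp(1), i.i.d.
samples = (p / q).argmax(dim=1)
print(torch.bincount(samples, minlength=3) / samples.numel())
# -> approximately tensor([0.1, 0.2, 0.7])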
-def _multinomial( - probs: torch.Tensor, - num_samples: int, - seq_groups: Optional[list[SequenceGroupToSample]] = None, -) -> torch.Tensor: - if num_samples > 1: - probs = probs.repeat_interleave(num_samples, dim=0) - q = torch.empty_like(probs) - if seq_groups is None: - q.exponential_() - else: - sample_idx = 0 - for seq_group in seq_groups: - seq_ids = seq_group.seq_ids - stride = len(seq_ids) * num_samples - assert seq_group.generator is not None - q[sample_idx:sample_idx + - stride].exponential_(generator=seq_group.generator) - sample_idx += stride - return probs.div_(q).argmax(dim=1).view(-1, num_samples) - - -def _top_k_top_p_multinomial_with_flashinfer( - probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor, - num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]): - if num_samples > 1: - probs = probs.repeat_interleave(num_samples, dim=0) - top_ks = top_ks.repeat_interleave(num_samples) - top_ps = top_ps.repeat_interleave(num_samples) - batch_next_token_ids = flashinfer_top_k_top_p_sampling( - probs, - top_ks, - top_ps, - ) - return batch_next_token_ids.view(-1, num_samples) - - -def get_pythonized_sample_results( - sample_result_args: SampleResultArgsType) -> SampleResultType: - '''This function consumes GPU-side sampler results and computes - Pythonized CPU-side sampler results (GPU -> CPU sync.) - - Single-step scheduling: this function is invoked at sampling-time - for immediate Pythonization. - - Multi-step scheduling: Pythonization is deferred until after multiple - GPU-side steps have been completed. - - Args: - sample_result_args: GPU-side inputs to the Pythonization process - - Returns: - Pythonized sampler results - ''' - - ( - sample_metadata, - sampling_metadata, - greedy_samples, - multinomial_samples, - sample_results_dict, - ) = ( - sample_result_args.sample_metadata, - sample_result_args.sampling_metadata, - sample_result_args.greedy_samples, - sample_result_args.multinomial_samples, - sample_result_args.sample_results_dict, - ) - - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample(seq_groups, - multinomial_samples[sampling_type]) - sample_results_dict.update(zip(seq_group_id, sample_results)) - - return [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] - - -def _sample_with_torch( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, - include_gpu_probs_tensor: bool, - modify_greedy_probs: bool, -) -> SampleReturnType: - '''Torch-oriented _sample() implementation. 
- - Single-step scheduling: - * Perform GPU-side sampling computation - * Immediately Pythonize sampling result - - Multi-step scheduling: - * Perform GPU-side sampling computation - * Defer Pythonization & preserve GPU-side - tensors required for Pythonization - ''' - - categorized_seq_group_ids: dict[SamplingType, list[int]] = { - t: [] - for t in SamplingType - } - categorized_sample_indices = sampling_metadata.categorized_sample_indices - for i, seq_group in enumerate(sampling_metadata.seq_groups): - sampling_params = seq_group.sampling_params - sampling_type = sampling_params.sampling_type - categorized_seq_group_ids[sampling_type].append(i) - - sample_results_dict: SampleResultsDictType = {} - sample_metadata: SampleMetadataType = {} - multinomial_samples: MultinomialSamplesType = {} - greedy_samples: Optional[torch.Tensor] = None - - # Create output tensor for sampled token ids. - if include_gpu_probs_tensor: - sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1), - VLLM_INVALID_TOKEN_ID, - dtype=torch.long, - device=logprobs.device) - else: - sampled_token_ids_tensor = None - - # Counterintuitively, having two loops here is actually faster. - # The first loop can run without waiting on GPU<->CPU sync. - for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type] - num_tokens = len(sample_indices) - if num_tokens == 0: - continue - - seq_group_id = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] - sample_metadata[sampling_type] = (seq_group_id, seq_groups) - long_sample_indices = sample_indices.long() - if sampling_type == SamplingType.GREEDY: - greedy_samples = torch.argmax(logprobs[long_sample_indices], - dim=-1) - - if sampled_token_ids_tensor is not None: - # Store sampled tokens in output tensor. - sampled_token_ids_tensor[ - long_sample_indices] = greedy_samples.unsqueeze(-1) - - if modify_greedy_probs: - # If required, modify the probabilities such that sampling from - # the modified distribution would always sample the argmax - # token id. - _modify_greedy_probs_inplace(logprobs, probs, - long_sample_indices, - greedy_samples) - - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - max_n_in_batch = 1 - for seq_group in seq_groups: - if seq_group.is_prompt: - sampling_params = seq_group.sampling_params - max_n_in_batch = max(max_n_in_batch, sampling_params.n) - seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else - seq_groups) - - if flashinfer_top_k_top_p_sampling is not None: - logger.warning("FlashInfer 0.2.3+ does not support " - "per-request generators. Falling back to " - "PyTorch-native implementation.") - - multinomial_samples[sampling_type] = _multinomial( - probs[long_sample_indices], - max_n_in_batch, - seq_groups=seq_groups_arg) - - if sampled_token_ids_tensor is not None: - # Store sampled tokens in output tensor. - sampled_token_ids_tensor[long_sample_indices] = \ - multinomial_samples[sampling_type].to(torch.long) - - else: - raise ValueError(f"Unsupported sampling type: {sampling_type}") - - # Encapsulate arguments for computing Pythonized sampler - # results, whether deferred or otherwise. - maybe_deferred_args = SampleResultArgsType( - sampling_metadata=sampling_metadata, - sample_metadata=sample_metadata, - multinomial_samples=multinomial_samples, - greedy_samples=greedy_samples, - sample_results_dict=sample_results_dict) - - if not sampling_metadata.skip_sampler_cpu_output: - # GPU<->CPU sync happens here. 
- # This also converts the sampler output to a Python object. - # Return Pythonized sampler result & sampled token ids - return get_pythonized_sample_results( - maybe_deferred_args), sampled_token_ids_tensor - else: - # Defer sampler result Pythonization; return deferred - # Pythonization args & sampled token ids - return ( - maybe_deferred_args, - sampled_token_ids_tensor, - ) - - -def _sample( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, - include_gpu_probs_tensor: bool, - modify_greedy_probs: bool, -) -> SampleReturnType: - """ - Args: - probs: (num_query_tokens_in_batch, num_vocab) - logprobs: (num_query_tokens_in_batch, num_vocab) - sampling_metadata: The metadata for a batch for sampling. - sampling_tensors: Tensors that include sampling related metadata. - - Returns: - (next_token_ids, parent_seq_ids) for each seq group in a batch. - If sampling is skipped, it returns ([], []) - sampled_token_ids_tensor: A tensor of sampled token ids. - """ - return _sample_with_torch( - probs, - logprobs, - sampling_metadata, - sampling_tensors, - include_gpu_probs_tensor=include_gpu_probs_tensor, - modify_greedy_probs=modify_greedy_probs, - ) - - -def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - """ - This function calculates the ranks of the chosen tokens in a logprob tensor. - - Args: - x (torch.Tensor): 2D logprob tensor of shape (N, M) - where N is the no. of tokens and M is the vocab dim. - indices (torch.Tensor): List of chosen token indices. - - Returns: - torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens. - Each element in the returned tensor represents the rank - of the chosen token in the input logprob tensor. - """ - vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), - indices] - result = (x > vals[:, None]) - del vals - return result.sum(1).add_(1) - - -def get_logprobs( - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sample_results: SampleResultType, -) -> tuple[list[Optional[PromptLogprobs]], list[SampleLogprobs]]: - """Return sample logprobs and prompt logprobs. - - The logic consists of 3 parts. - - Select indices to compute logprob from, ranks of token ids, and - the top k token ids from logprobs. - - Compute prompt logprobs if required. - - Compute sample logprobs if required. - - Args: - logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's - logprob per vocab. Sequence groups' query tokens are batched in a - single flattened tensor. For example, assuming there are N - seq groups, it is sorted by prefill tokens for seq_group_1 (if - prompt logprob is enabled), decode tokens for seq_group_1 (if - sampling is required), prefill tokens for seq_group_2, ... - sampling_metadata: The sampling metadata. - sample_results: (num_seq_groups) The tuple of (next_token_ids, - parent_ids) for each sequence group. When beam search is enabled, - sample_results can contain different number of seq_ids from - sampling_metadata.seq_groups. It is because beam search creates - 2 * BEAM_WIDTH number of samples (whereas there are only up to - BEAM_WIDTH number of seq_ids). - - Returns: - A tuple of prompt and sample logprobs per sequence group in a batch. - """ - # The index of query token to calculate logprobs. It includes both - # prompt and sample logprob indices. - query_indices: list[int] = [] - # The next token ids to get the logprob value from. 
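The deleted `_get_ranks` above computes, for each row, one plus the number of vocab entries with a strictly higher logprob than the chosen token, i.e. its 1-based rank; the bookkeeping lists introduced above continue below. A quick check of that semantics on toy data:

import torch

x = torch.tensor([[0.1, 0.9, 0.5],
                  [2.0, 1.0, 3.0]])
indices = torch.tensor([1, 1])               # chosen token per row
vals = x[torch.arange(len(x)), indices]
ranks = (x > vals[:, None]).sum(1) + 1
print(ranks)  # tensor([1, 3]): 0.9 is best in row 0; 1.0 is third in row 1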
- next_token_ids: list[int] = []
- # The largest requested number of logprobs. We compute logprobs for up to
- # the largest num_logprobs requested in the batch. If no request asks for
- # logprobs, this stays -1.
- largest_num_logprobs = -1
-
- # Select indices to compute logprob from, ranks of token ids, and the top
- # k token ids from logprobs.
- for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
- sample_results):
- sampling_params = seq_group.sampling_params
-
- # Update indices and tokens for prompt logprobs.
- if (seq_group.is_prompt
- and sampling_params.prompt_logprobs is not None):
- largest_num_logprobs = max(largest_num_logprobs,
- sampling_params.prompt_logprobs)
- next_prompt_tokens = _get_next_prompt_tokens(seq_group)
- query_indices.extend(seq_group.prompt_logprob_indices)
- next_token_ids.extend(next_prompt_tokens)
-
- # Update indices and next tokens for sample logprobs.
- if seq_group.do_sample:
- token_ids, parent_seq_ids = sample_result
- # NOTE: We cannot directly use sample_indices because
- # sample_indices only contain parent seq_ids of a previous step.
- # The current step may have a different number of seq_ids, and
- # we can obtain it from `sample_result[1]`.
- query_idx = seq_group.sample_indices[0]
- query_indices.extend(
- [query_idx + parent_id for parent_id in parent_seq_ids])
- next_token_ids.extend(token_ids)
-
- if sampling_params.logprobs is not None:
- largest_num_logprobs = max(largest_num_logprobs,
- sampling_params.logprobs)
-
- assert len(next_token_ids) == len(query_indices)
-
- if len(query_indices) == 0:
- empty_sampled_logprob: SampleLogprobs = []
- empty_prompt_logprob: Optional[PromptLogprobs] = None
- num_seq_groups = len(sampling_metadata.seq_groups)
- return [empty_prompt_logprob
- ] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups
-
- selected_logprobs, ranks = None, None
- top_logprobs, top_token_ids = None, None
-
- # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
- # skip the whole logprob calculation.
- if largest_num_logprobs >= 0:
- query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
- next_token_ids_gpu = torch.tensor(next_token_ids,
- device=logprobs.device)
-
- # (num_selected_query_tokens, num_logprobs). Note that query_indices can
- # contain duplicates if beam search is enabled.
- selected_logprobs = logprobs[[
- query_indices_gpu,
- next_token_ids_gpu,
- ]]
- ranks = _get_ranks(
- logprobs[query_indices_gpu],
- next_token_ids_gpu,
- )
- assert selected_logprobs.shape[0] == ranks.shape[0]
-
- # We only need to compute top-k if some request sets logprobs > 0.
- if largest_num_logprobs > 0:
- # Logprobs of topk tokens for a batch of sequence groups.
- # (num_query_tokens_across_batch).
- top_logprobs, top_token_ids = torch.topk(logprobs,
- largest_num_logprobs,
- dim=-1)
- top_logprobs = top_logprobs.to('cpu')
- top_token_ids = top_token_ids.to('cpu')
-
- selected_logprobs = selected_logprobs.to('cpu')
- ranks = ranks.to('cpu')
-
- # Find prompt/sample logprobs.
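The gather above pulls one logprob per (query index, token id) pair via advanced indexing, and duplicate row indices are legal, which is what makes the beam-search case work. The same pattern on toy data:

import torch

logprobs = torch.log_softmax(torch.randn(4, 10), dim=-1)
rows = torch.tensor([0, 0, 2])   # duplicate query indices are allowed
cols = torch.tensor([3, 7, 1])
selected = logprobs[rows, cols]  # shape (3,), one value per (row, col) pair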
- prompt_logprobs_per_seq_group: list[Optional[PromptLogprobs]] = [] - sample_logprobs_per_seq_group: list[SampleLogprobs] = [] - top_logprob_idx = 0 - selected_logprobs_idx = 0 - - for seq_group, sample_result in zip(sampling_metadata.seq_groups, - sample_results): - (prompt_logprobs, top_logprob_idx, - selected_logprobs_idx) = _get_prompt_logprob_if_needed( - seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs, - selected_logprobs_idx, top_logprob_idx) - prompt_logprobs_per_seq_group.append(prompt_logprobs) - - (sampled_logprobs, top_logprob_idx, - selected_logprobs_idx) = _get_sampled_logprob_if_needed( - seq_group, sample_result, selected_logprobs, ranks, top_token_ids, - top_logprobs, selected_logprobs_idx, top_logprob_idx) - sample_logprobs_per_seq_group.append(sampled_logprobs) - - return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group - - -def _get_prompt_logprob_if_needed( - seq_group: SequenceGroupToSample, - selected_logprobs: torch.Tensor, - ranks: torch.Tensor, - top_token_ids: torch.Tensor, - top_logprobs: torch.Tensor, - selected_logprobs_idx: int, - top_logprob_idx: int, -): - """Compute the prompt logprob from a sequence group if needed.""" - sampling_params = seq_group.sampling_params - is_prompt = seq_group.is_prompt - - # Find prompt logprobs - prompt_logprobs: Optional[PromptLogprobs] = None - if is_prompt and sampling_params.prompt_logprobs is not None: - prompt_logprobs = [] - num_logprobs = sampling_params.prompt_logprobs - next_prompt_tokens = _get_next_prompt_tokens(seq_group) - # Pre-select indexes and create a list. It is faster than calling .item - # repetitively. - selected_logprob_items = selected_logprobs[ - selected_logprobs_idx:selected_logprobs_idx + - len(next_prompt_tokens)].tolist() - rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + - len(next_prompt_tokens)].tolist() - - for idx, token_id in enumerate(next_prompt_tokens): - # Calculate the prompt logprob of the real prompt tokens. - # {token_id: (logprob, rank_from_vocab)} - prompt_logprobs_dict: dict[int, tuple[float, int]] = { - token_id: (selected_logprob_items[idx], rank_items[idx]) - } - - # Add top K prompt logprobs along with its rank. - if num_logprobs > 0: - top_ids = top_token_ids[ - top_logprob_idx, :num_logprobs].tolist() - top_probs = top_logprobs[ - top_logprob_idx, :num_logprobs].tolist() - # Top K is already sorted by rank, so we can use 1 ~ - # num_logprobs + 1 for rank. - top_ranks = range(1, num_logprobs + 1) - prompt_logprobs_dict.update({ - top_id: (top_prob, rank) - for top_id, top_prob, rank in zip(top_ids, top_probs, - top_ranks) - }) - prompt_logprobs.append({ - token_id: Logprob(*logprob_and_rank) - for token_id, logprob_and_rank in prompt_logprobs_dict.items() - }) - # + 1 to go to the next prompt token. - top_logprob_idx += 1 - - # + len(next_prompt_tokens) to go to the next prompt. 
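Each entry built above maps token_id -> Logprob(logprob, rank): the actually-chosen token carries its true rank from the vocab, while the top-k entries can reuse ranks 1..k because torch.topk returns them already sorted. The dict shape with made-up numbers (the running-cursor updates resume right below):

# {token_id: (logprob, rank_from_vocab)} -- toy values, not real model output
prompt_logprobs_dict = {99: (-1.7, 5)}          # the actual next prompt token
top_ids, top_probs = [11, 7, 42], [-0.1, -1.3, -2.0]
prompt_logprobs_dict.update({
    top_id: (top_prob, rank)
    for top_id, top_prob, rank in zip(top_ids, top_probs, range(1, 4))
})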
- selected_logprobs_idx += len(next_prompt_tokens) - return prompt_logprobs, top_logprob_idx, selected_logprobs_idx - - -def _get_sampled_logprob_if_needed( - seq_group: SequenceGroupToSample, - sample_result: tuple[list[int], list[int]], - selected_logprobs: torch.Tensor, - ranks: torch.Tensor, - top_token_ids: torch.Tensor, - top_logprobs: torch.Tensor, - selected_logprobs_idx: int, - top_logprob_idx: int, -): - """Compute the sample logprob if needed.""" - seq_ids = seq_group.seq_ids - num_logprobs = seq_group.sampling_params.logprobs - sampled_logprobs: SampleLogprobs = [] - next_token_ids, parent_seq_ids = sample_result - - if seq_group.do_sample: - assert len(next_token_ids) > 0 - if num_logprobs is None: - for next_token_id in next_token_ids: - # Use a dummy logprob - sampled_logprobs.append({next_token_id: Logprob(inf)}) - else: - # Pre-select items from tensor. tolist() is faster than repetitive - # `.item()` calls. - selected_logprob_items = selected_logprobs[ - selected_logprobs_idx:selected_logprobs_idx + - len(next_token_ids)].tolist() - rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + - len(next_token_ids)].tolist() - for idx, (next_token_id, parent_id) in enumerate( - zip(next_token_ids, parent_seq_ids)): - # Get the logprob of a sampled token. - sampled_logprobs_dict = { - next_token_id: - (selected_logprob_items[idx], rank_items[idx]) - } - if num_logprobs is not None and num_logprobs > 0: - # Get top K logprobs. - top_ids = top_token_ids[top_logprob_idx + - parent_id, :num_logprobs].tolist() - top_probs = top_logprobs[ - top_logprob_idx + parent_id, :num_logprobs].tolist() - # Top K is already sorted by rank, so we can use 1 ~ - # num_logprobs + 1 for rank. - top_ranks = range(1, num_logprobs + 1) - sampled_logprobs_dict.update({ - top_id: (top_prob, rank) - for top_id, top_prob, rank in zip( - top_ids, top_probs, top_ranks) - }) - - sampled_logprobs.append({ - token_id: Logprob(*logprob_and_rank) - for token_id, logprob_and_rank in - sampled_logprobs_dict.items() - }) - - # NOTE: This part of code is not intuitive. `selected_logprobs` include - # logprobs for the current step, which has len(next_token_ids) tokens - # per sequence group. `logprobs` includes logprobs from the previous - # steps, which has len(seq_ids) tokens per sequence group. - - # Iterate to the next sequence group in a batch. - selected_logprobs_idx += len(next_token_ids) - # Iterate to the next sequence group in a batch. - top_logprob_idx += len(seq_ids) - return sampled_logprobs, top_logprob_idx, selected_logprobs_idx - - -def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, - sample_indices: torch.Tensor, - greedy_samples: torch.Tensor) -> None: - """Modify the probability distributions of the greedily-sampled tokens such - that each sampled token has a "probability" of 1.0. This is required by - speculative decoding, which depends on the sampling method being encoded - within the probability distribution for correctness. - - # Why do we only need to do this for greedy sampling? - - vLLM's sampler performs the following steps for greedy or multinomial - (random) sampling: - 1. Get logits from model. - 2. Modify logits according to per-sequence sampling parameters. - - Multiply by temperature, top-k and top-p masking, penalize tokens - according to their frequency, etc. - 3. Sample a token. - - Random sampling simply samples from the modified probability - distribution. - - Greedy sampling performs `argmax` to obtain the token with the - highest likelihood. 
- - Ignoring greedy sampling for a moment, we find that the computed probability - distribution has the following property: we can sample from it independently - and find that the token sampled by the Sampler has a frequency corresponding - to how often we see it in our sampling. In other words, for tokens sampled - with vLLM's random SamplingType, the computed probability distribution - encodes the sampling methodology completely. - - Greedy sampling does not normally have this property. vLLM modifies logits - according to sampling params, then performs `argmax`, then returns the - sampled token and the computed probability distribution. If we sample from - the distribution, we'll find the likelihood of the greedily-sampled token - is not always 1.0. - - Since lossless speculative decoding requires that the sampling methodology - be encoded within the probability distribution, we are motivated to modify - the probability distribution such that the sampled token has probability 1 - when speculative decoding is used. - - NOTE: Alternatively, we could use an extremely low temperature to achieve - greedy sampling using multinomial computation and unite the codepaths. This - has implications on the overall design of the sampler, e.g. how to record - accurate logprobs for the user, so this improvement is deferred to later. - """ - # NOTE: logprobs are not modified so they can be returned to the user. - probs[sample_indices, :] = 0 - probs[sample_indices, greedy_samples] = 1.0 - - -def _build_sampler_output( - maybe_deferred_sample_results: MaybeDeferredSampleResultType, - sampling_metadata: SamplingMetadata, - prompt_logprobs: Optional[list[Optional[PromptLogprobs]]], - sample_logprobs: Optional[list[SampleLogprobs]], - on_device_tensors: Optional[tuple[torch.Tensor, torch.Tensor, - torch.Tensor]], - skip_sampler_cpu_output: bool = False, -) -> SamplerOutput: - """Construct Python objects with the output of sampling. - - Args: - on_device_tensors: Tuple containing on-device tensors with the - probabilities used in sampling and the sampled token ids. This - allows post-processing without copies to CPU/serialization, e.g. in - speculative decoding rejection sampling. - """ - sampler_output: list[CompletionSequenceGroupOutput] = [] - - if skip_sampler_cpu_output: - assert isinstance(maybe_deferred_sample_results, SampleResultArgsType) - deferred_sample_results_args = maybe_deferred_sample_results - else: - assert prompt_logprobs is not None - assert sample_logprobs is not None - assert not isinstance(maybe_deferred_sample_results, - SampleResultArgsType) - assert len(sampling_metadata.seq_groups) \ - == len(maybe_deferred_sample_results) \ - == len(prompt_logprobs) \ - == len(sample_logprobs) - deferred_sample_results_args = None - - for (seq_group, sample_result, group_prompt_logprobs, - group_sample_logprobs) in zip(sampling_metadata.seq_groups, - maybe_deferred_sample_results, - prompt_logprobs, sample_logprobs): - seq_ids = seq_group.seq_ids - next_token_ids, parent_ids = sample_result - seq_outputs: list[SequenceOutput] = [] - for parent_id, next_token_id, logprobs in zip( - parent_ids, next_token_ids, group_sample_logprobs): - seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, - logprobs)) - sampler_output.append( - CompletionSequenceGroupOutput(seq_outputs, - group_prompt_logprobs)) - - # If not specified, store None values in SamplerOutput. 
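The in-place rewrite performed by the deleted `_modify_greedy_probs_inplace` turns each greedy row into a one-hot distribution, so multinomial sampling of those rows can only ever return the argmax token, which is the property speculative decoding's rejection sampler relies on. A toy check of that invariant:

import torch

probs = torch.softmax(torch.randn(4, 10), dim=-1)
sample_indices = torch.tensor([1, 3])            # rows sampled greedily
greedy = probs[sample_indices].argmax(dim=-1)
probs[sample_indices, :] = 0
probs[sample_indices, greedy] = 1.0
resampled = torch.multinomial(probs[sample_indices], 1).squeeze(1)
assert torch.equal(resampled, greedy)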
- if on_device_tensors is not None: - (sampled_token_probs, logprobs_tensor, - sampled_token_ids) = on_device_tensors - else: - sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, - None) - - return SamplerOutput( - outputs=sampler_output, - sampled_token_probs=sampled_token_probs, - sampled_token_ids=sampled_token_ids, - logprobs=logprobs_tensor, - deferred_sample_results_args=deferred_sample_results_args) - - -def _get_next_prompt_tokens( - seq_group: SequenceGroupToSample) -> tuple[int, ...]: - """Get a list of next prompt tokens to compute logprob from a - given sequence group. - - It is used to compute prompt logprob. Imagine you have logprob for each - query token. Query token needs to know the next prompt token id to compute - prompt logprob. This is a helper to obtain next prompt token ids. - - This API has to be used only when the caller knows seq_group is in prefill - stage. - - Returns: - A list of next prompt tokens to compute logprob. - """ - assert seq_group.is_prompt, ( - "Caller should ensure the sequence group is in a prefill stage.") - seq_ids = seq_group.seq_ids - query_len = seq_group.query_len - assert query_len is not None - # prompt has only 1 seq id. - assert len(seq_ids) == 1 - seq_data = seq_group.seq_data[seq_ids[0]] - computed_len = seq_data.get_num_computed_tokens() - prompt_tokens = seq_data.prompt_token_ids - # +1 because we are looking for a next prompt token. - next_token_index_start = computed_len + 1 - next_token_index_end = min(computed_len + query_len + 1, - len(prompt_tokens)) - next_prompt_tokens = prompt_tokens[ - next_token_index_start:next_token_index_end] - return next_prompt_tokens diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 58296131fadb9..13f4eebf1038e 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -672,21 +672,15 @@ def tensorize_vllm_model(engine_args: "EngineArgs", ) as stream: stream.write(encryption_params.key) - from vllm import LLMEngine - from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + assert envs.VLLM_USE_V1 - if not envs.VLLM_USE_V1: - engine = LLMEngine.from_engine_args(engine_args) - engine.model_executor.collective_rpc( - "save_tensorized_model", - kwargs={"tensorizer_config": tensorizer_config.to_serializable()}, - ) - else: - engine = V1LLMEngine.from_vllm_config(engine_config) - engine.collective_rpc( - "save_tensorized_model", - kwargs={"tensorizer_config": tensorizer_config.to_serializable()}, - ) + from vllm.v1.engine.llm_engine import LLMEngine + + engine = LLMEngine.from_vllm_config(engine_config) + engine.collective_rpc( + "save_tensorized_model", + kwargs={"tensorizer_config": tensorizer_config.to_serializable()}, + ) def tensorize_lora_adapter(lora_path: str, diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index f2c66763d0816..a72086da18c4d 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -11,6 +11,7 @@ import tempfile import time from collections import defaultdict from collections.abc import Generator +from contextlib import contextmanager from pathlib import Path from typing import Any, Callable, Optional, Union @@ -98,6 +99,49 @@ def get_lock(model_name_or_path: Union[str, Path], return lock +@contextmanager +def atomic_writer(filepath: Union[str, Path], + mode: str = 'w', + encoding: Optional[str] = None): + 
""" + Context manager that provides an atomic file writing routine. + + The context manager writes to a temporary file and, if successful, + atomically replaces the original file. + + Args: + filepath (str or Path): The path to the file to write. + mode (str): The file mode for the temporary file (e.g., 'w', 'wb'). + encoding (str): The encoding for text mode. + + Yields: + file object: A handle to the temporary file. + """ + # Create a temporary file in the same directory as the target file + # to ensure it's on the same filesystem for an atomic replace. + temp_dir = os.path.dirname(filepath) + temp_fd, temp_path = tempfile.mkstemp(dir=temp_dir) + + try: + # Open the temporary file for writing + with os.fdopen(temp_fd, mode=mode, encoding=encoding) as temp_file: + yield temp_file + + # If the 'with' block completes successfully, + # perform the atomic replace. + os.replace(temp_path, filepath) + + except Exception: + logger.exception( + "Error during atomic write. Original file '%s' not modified", + filepath) + raise + finally: + # Clean up the temporary file if it still exists. + if os.path.exists(temp_path): + os.remove(temp_path) + + def maybe_download_from_modelscope( model: str, revision: Optional[str] = None, diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py index f6400b05e110a..6dab4ed14345f 100644 --- a/vllm/model_executor/models/apertus.py +++ b/vllm/model_executor/models/apertus.py @@ -48,7 +48,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -566,10 +565,8 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index be82c2fd59644..1ee378af76c9f 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -399,11 +399,10 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): inputs_embeds=inputs_embeds) return model_output - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata) -> Optional[torch.Tensor]: + def compute_logits(self, + hidden_states: torch.Tensor) -> Optional[torch.Tensor]: # Compute final logits from hidden states (last pipeline rank only) - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index b6dd559968415..55d16fd75cebb 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -30,7 +30,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from 
vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -456,10 +455,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index a7cb6b35a4ab4..35c1adbdd00b6 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -19,7 +19,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -644,10 +643,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 687c82ded9d0a..0f05f9b4efcd6 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -16,7 +16,6 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import ( get_optimal_tiled_canvas) from vllm.config import VllmConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, @@ -464,7 +463,5 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index ae25033410407..db8d0a8710471 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -46,7 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, row_parallel_weight_loader) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP, 
SupportsQuant @@ -421,10 +420,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 5f6025abf315c..82cd4a26a1baa 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -51,7 +51,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -623,10 +622,8 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 397089f31cdf6..584981ef3ebfd 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType @@ -571,10 +570,8 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index a3131aa3812ef..b7455fba62c02 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -12,7 +12,6 @@ from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig, from vllm.config import CacheConfig, VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -704,10 +703,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return 
self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 4c37622b049c8..30816f72a2678 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP, SupportsQuant @@ -355,10 +354,8 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 7a56236483749..79d648d749c6a 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, row_parallel_weight_loader) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -1046,10 +1045,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) # Disallow image tokens which does not include special # begin-image and end-image tokens diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 1fc2da3e4d7ca..879508400222f 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -27,7 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig @@ -437,10 +436,8 @@ class ChatGLMBaseModel(nn.Module): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git 
a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index 179cc2af8eb3f..6d67eb68d51a8 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -21,7 +21,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, @@ -478,7 +477,5 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 7f87e31abdcd3..f3929ef3b5938 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -46,7 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, row_parallel_weight_loader) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -448,15 +447,14 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: is_not_lora = hasattr(self.model.embed_tokens, 'weight') if is_not_lora: logits = self.logits_processor(self.model.embed_tokens, - hidden_states, sampling_metadata) + hidden_states) else: logits = self.logits_processor(self.model.embed_tokens.base_layer, - hidden_states, sampling_metadata) + hidden_states) return logits diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 003cf4563a22f..f863b1da5505c 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -24,7 +24,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -462,10 +461,8 @@ class DbrxForCausalLM(nn.Module, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 59c9921881497..ffc843fe033cc 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py 
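The hunks below repeat the same mechanical migration applied to every model file in this diff: compute_logits loses its sampling_metadata parameter and the LogitsProcessor is invoked with (lm_head, hidden_states) only. The pattern in isolation, with stand-in objects rather than any specific model (illustrative only):

from typing import Optional
import torch

class _Model:
    def __init__(self):
        self.lm_head = torch.randn(32, 8)                  # (vocab, hidden)
        self.logits_processor = lambda head, h: h @ head.T

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.logits_processor(self.lm_head, hidden_states)

print(_Model().compute_logits(torch.randn(2, 8)).shape)  # torch.Size([2, 32])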
@@ -49,7 +49,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -488,10 +487,8 @@ class DeepseekForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py index 2770ddebc48ab..ed7e7614800fc 100644 --- a/vllm/model_executor/models/deepseek_eagle.py +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -19,7 +19,6 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer, DeepseekV3ForCausalLM) -from vllm.model_executor.sampling_metadata import SamplingMetadata from .utils import AutoWeightsLoader, maybe_prefix @@ -222,10 +221,8 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 8fbf16d206a86..92f311ab465b5 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -15,7 +15,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .deepseek_v2 import (DeepseekV2DecoderLayer, @@ -124,15 +123,13 @@ class DeepSeekMultiTokenPredictor(nn.Module): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> torch.Tensor: current_step_idx = (spec_step_idx % self.num_mtp_layers) mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)] logits = self.logits_processor(mtp_layer.shared_head.head, - mtp_layer.shared_head(hidden_states), - sampling_metadata) + mtp_layer.shared_head(hidden_states)) return logits @@ -161,11 +158,9 @@ class DeepSeekMTP(nn.Module, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: - return self.model.compute_logits(hidden_states, sampling_metadata, - spec_step_idx) + return self.model.compute_logits(hidden_states, spec_step_idx) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git 
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 636554bd648f2..a99a6679a5696 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -56,7 +56,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import cdiv, direct_register_custom_op
@@ -914,10 +913,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index d7ae8206baca5..c8ed759d2e972 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -15,7 +15,6 @@ from transformers import BatchFeature

 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models.transformers import replace_linear_class
@@ -647,10 +646,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py
index 20555e48b73d4..2a09234b59ed1 100644
--- a/vllm/model_executor/models/dots1.py
+++ b/vllm/model_executor/models/dots1.py
@@ -52,7 +52,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -534,10 +533,8 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py
index ebab018ed67e7..d262e9e9da50e 100644
--- a/vllm/model_executor/models/ernie45_moe.py
+++ b/vllm/model_executor/models/ernie45_moe.py
@@ -49,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -591,10 +590,8 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index 3396c67f42b7b..74b358034ef3d 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -39,7 +39,6 @@ from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
-from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import QuickGELU
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -234,8 +233,9 @@ class Ernie4_5_VisionAttention(nn.Module):
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                    for x in (q, k, v))
         if rotary_pos_emb is not None:
-            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+            qk_concat = torch.cat([q, k], dim=0)
+            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
+            q, k = torch.chunk(qk_rotated, 2, dim=0)

         if self.is_flash_attn_backend:
             # from vllm_flash_attn.flash_attn_interface import (
@@ -261,8 +261,8 @@ class Ernie4_5_VisionAttention(nn.Module):
                 causal=False)

             context_layer = rearrange(output,
-                                      "(b s) ... -> b s ...",
-                                      b=batch_size)
+                                      "(b s) h d -> s b (h d)",
+                                      b=batch_size).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
             outputs = []
@@ -281,6 +281,8 @@ class Ernie4_5_VisionAttention(nn.Module):
                 output_i = rearrange(output_i, "b h s d -> b s h d ")
                 outputs.append(output_i)
             context_layer = torch.cat(outputs, dim=1)
+            context_layer = rearrange(context_layer,
+                                      "b s h d -> s b (h d)").contiguous()
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@@ -291,8 +293,8 @@ class Ernie4_5_VisionAttention(nn.Module):
             context_layer = xops.memory_efficient_attention_forward(
                 q, k, v, attn_bias=attn_bias, p=0, scale=None)

-        context_layer = rearrange(context_layer,
-                                  "b s h d -> s b (h d)").contiguous()
+            context_layer = rearrange(context_layer,
+                                      "b s h d -> s b (h d)").contiguous()

         output, _ = self.proj(context_layer)
         return output
@@ -1289,11 +1291,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
         """compute logits"""
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def _vision_forward(
         self,
diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py
index 7f791852ceb91..f55016f7ccb36 100644
--- a/vllm/model_executor/models/ernie45_vl_moe.py
+++ b/vllm/model_executor/models/ernie45_vl_moe.py
@@ -48,7 +48,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .ernie45_moe import Ernie4_5_MoeMLP
@@ -587,10 +586,8 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

     def load_weights(self, weights: Iterable[tuple[str,
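The vision-attention change above batches the rotary embedding over query and key in one call instead of two. A small self-contained sketch of the concat-then-chunk equivalence; `rotate` is a hypothetical stand-in for `apply_rotary_pos_emb_vision`, but like the real transform it acts elementwise along the batch dimension, which is exactly what makes the split lossless:

```python
import torch

def rotate(x: torch.Tensor, rot: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for apply_rotary_pos_emb_vision.
    return x * torch.cos(rot) + x.flip(-1) * torch.sin(rot)

b, s, h, d = 2, 4, 3, 8
q, k = torch.randn(b, s, h, d), torch.randn(b, s, h, d)
rot = torch.randn(1, s, 1, d)  # broadcasts over batch and heads

# One rotary launch over [2b, s, h, d] instead of two over [b, s, h, d].
qk_rotated = rotate(torch.cat([q, k], dim=0), rot)
q2, k2 = torch.chunk(qk_rotated, 2, dim=0)

assert torch.allclose(q2, rotate(q, rot))
assert torch.allclose(k2, rotate(k, rot))
```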
diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py
index 57c5348874378..288fbe736c32f 100644
--- a/vllm/model_executor/models/ernie_mtp.py
+++ b/vllm/model_executor/models/ernie_mtp.py
@@ -33,11 +33,9 @@ from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
@@ -139,12 +137,10 @@ class ErnieMultiTokenPredictor(nn.Module):
         self,
         hidden_states: torch.Tensor,
         lm_head: ParallelLMHead,
-        sampling_metadata: SamplingMetadata,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
         self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
-        logits = self.logits_processor(lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(lm_head, hidden_states)
         return logits

@@ -160,7 +156,6 @@ class ErnieMTP(nn.Module, SupportsPP):
         self.lm_head = ParallelLMHead(self.config.vocab_size,
                                       self.config.hidden_size,
                                       prefix=maybe_prefix(prefix, "lm_head"))
-        self.sampler = get_sampler()

         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
@@ -182,19 +177,10 @@ class ErnieMTP(nn.Module, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
         spec_step_idx: int = 0,
     ) -> Optional[torch.Tensor]:
         return self.model.compute_logits(hidden_states, self.lm_head,
-                                         sampling_metadata, spec_step_idx)
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+                                         spec_step_idx)

     def load_weights(self,
                      weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py
index f503fb0f9364a..5dafcd595e4a5 100644
--- a/vllm/model_executor/models/exaone.py
+++ b/vllm/model_executor/models/exaone.py
@@ -49,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -534,10 +533,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py
index 9f7d57d938140..c78eedff66700 100644
--- a/vllm/model_executor/models/exaone4.py
+++ b/vllm/model_executor/models/exaone4.py
@@ -45,7 +45,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -517,10 +516,8 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
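`ErnieMTP.sample` and the `get_sampler()` member disappear because sampling no longer lives on the model: the engine consumes the raw logits returned by `compute_logits`. As a hedged illustration only (the real V1 sampler also handles temperature, top-p/top-k, and penalties), greedy selection as a standalone step over returned logits looks like:

```python
import torch

# Toy greedy sampler decoupled from the model, for illustration only.
def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
    return logits.argmax(dim=-1)

next_tokens = greedy_sample(torch.randn(2, 32))
assert next_tokens.shape == (2,)
```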
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 42c378e5c389a..0c50056d1c527 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -46,7 +46,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import RWConfig

@@ -496,10 +495,8 @@ class FalconForCausalLM(nn.Module, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py
index 757051b3b1447..83efdd2e433fc 100644
--- a/vllm/model_executor/models/falcon_h1.py
+++ b/vllm/model_executor/models/falcon_h1.py
@@ -33,7 +33,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
@@ -675,10 +674,8 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 90af859ab92ec..53e9e6fe6e460 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -29,7 +29,6 @@ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,

 from vllm.config import VllmConfig
 from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.models.persimmon import PersimmonForCausalLM
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -389,10 +388,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
         logits = self.language_model.logits_processor(
-            self.language_model.lm_head, hidden_states, sampling_metadata)
+            self.language_model.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 12eb27503870c..c19425b6cb6d6 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -412,10 +411,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.model.embed_tokens, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.model.embed_tokens, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index 0bdb6c6bf7ae9..3f76e1e7d42a2 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -409,10 +408,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.model.embed_tokens, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.model.embed_tokens, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 7246308d59028..77c0ef8cb91d2 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from ...attention.layers.encoder_only_attention import EncoderOnlyAttention
@@ -542,10 +541,8 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.model.embed_tokens, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.model.embed_tokens, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index bee9fbd2c084a..0630ee07c347e 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -14,7 +14,6 @@ from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -704,10 +703,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
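Note that the Gemma families pass `self.model.embed_tokens` where other models pass an `lm_head`: they tie input and output embeddings, so the embedding matrix doubles as the output projection. A minimal sketch of that weight tying:

```python
import torch
import torch.nn as nn

# Weight tying: the input embedding matrix also serves as the output head,
# so no separate lm_head weight is stored.
embed_tokens = nn.Embedding(32, 16)          # vocab 32, hidden 16
token_ids = torch.tensor([[1, 2, 3]])

hidden = embed_tokens(token_ids)             # [1, 3, 16] (toy "model")
logits = hidden @ embed_tokens.weight.t()    # [1, 3, 32]
assert logits.shape == (1, 3, 32)
```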
diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
index ffec3408702c9..f4d288fd887e9 100644
--- a/vllm/model_executor/models/gemma3n.py
+++ b/vllm/model_executor/models/gemma3n.py
@@ -43,7 +43,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsQuant
@@ -814,10 +813,8 @@ class Gemma3nForCausalLM(nn.Module):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: Optional[SamplingMetadata],
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.model.embed_tokens, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.model.embed_tokens, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py
index 8d3079aee0dfb..2acdba54a257d 100644
--- a/vllm/model_executor/models/gemma3n_mm.py
+++ b/vllm/model_executor/models/gemma3n_mm.py
@@ -25,7 +25,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -685,10 +684,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py
index 5e2908a82c418..b9d5e24e9f6fa 100644
--- a/vllm/model_executor/models/glm4.py
+++ b/vllm/model_executor/models/glm4.py
@@ -40,7 +40,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -289,10 +288,8 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
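The multimodal wrappers here (and the vision-language models below) all shrink to a one-argument delegation onto their inner language model. A toy sketch of the pattern, with hypothetical `ToyLM`/`ToyVLM` classes:

```python
from typing import Optional

import torch
import torch.nn as nn

class ToyLM(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.lm_head = nn.Linear(16, 32, bias=False)

    def compute_logits(
            self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
        return self.lm_head(hidden_states)

class ToyVLM(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.language_model = ToyLM()

    def compute_logits(
            self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
        # Pure delegation: the wrapper adds no logits logic of its own.
        return self.language_model.compute_logits(hidden_states)

assert ToyVLM().compute_logits(torch.randn(2, 16)).shape == (2, 32)
```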
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index cbf327ce02b6b..56ec634386909 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -52,7 +52,6 @@ from vllm.distributed import (get_tensor_model_parallel_world_size,
                               parallel_state)
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
-from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -315,8 +314,10 @@ class Glm4vVisionAttention(nn.Module):
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                    for x in (q, k, v))
         if rotary_pos_emb is not None:
-            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+            # [2 * b, s, heads, head_dim]
+            qk_concat = torch.cat([q, k], dim=0)
+            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
+            q, k = torch.chunk(qk_rotated, 2, dim=0)

         if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
@@ -341,8 +342,8 @@ class Glm4vVisionAttention(nn.Module):
             )

             context_layer = rearrange(output,
-                                      "(b s) ... -> b s ...",
-                                      b=batch_size)
+                                      "(b s) h d -> s b (h d)",
+                                      b=batch_size).contiguous()
         elif self.attn_backend == _Backend.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
             outputs = []
@@ -361,6 +362,8 @@ class Glm4vVisionAttention(nn.Module):
                 output_i = rearrange(output_i, "b h s d -> b s h d ")
                 outputs.append(output_i)
             context_layer = torch.cat(outputs, dim=1)
+            context_layer = rearrange(context_layer,
+                                      "b s h d -> s b (h d)").contiguous()
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@@ -371,9 +374,8 @@ class Glm4vVisionAttention(nn.Module):
             context_layer = xops.memory_efficient_attention_forward(
                 q, k, v, attn_bias=attn_bias, p=0, scale=None)
-
-        context_layer = rearrange(context_layer,
-                                  "b s h d -> s b (h d)").contiguous()
+            context_layer = rearrange(context_layer,
+                                      "b s h d -> s b (h d)").contiguous()

         output, _ = self.proj(context_layer)
         return output
@@ -1651,10 +1653,8 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
index 1acbd18091fb3..947c6ce62f551 100644
--- a/vllm/model_executor/models/glm4_moe.py
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -51,7 +51,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -703,10 +702,8 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py
index 322c5619c1783..c572978e62206 100644
--- a/vllm/model_executor/models/glm4_moe_mtp.py
+++ b/vllm/model_executor/models/glm4_moe_mtp.py
@@ -38,7 +38,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name
@@ -155,15 +154,13 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
         mtp_layer = self.layers[str(self.mtp_start_layer_idx +
                                     current_step_idx)]
         logits = self.logits_processor(mtp_layer.shared_head.head,
-                                       mtp_layer.shared_head(hidden_states),
-                                       sampling_metadata)
+                                       mtp_layer.shared_head(hidden_states))
         return logits

@@ -192,11 +189,9 @@ class Glm4MoeMTP(nn.Module, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
         spec_step_idx: int = 0,
     ) -> Optional[torch.Tensor]:
-        return self.model.compute_logits(hidden_states, sampling_metadata,
-                                         spec_step_idx)
+        return self.model.compute_logits(hidden_states, spec_step_idx)

     def load_weights(self,
                      weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 0f6521e44e6be..24274db148bd7 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from ..layers.pooler import DispatchPooler, Pooler
@@ -307,10 +306,8 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 745d0b7759991..162018450e7c0 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -329,10 +328,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index 77df6ae6f30c8..698387fab946c 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
@@ -329,10 +328,9 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata, self.lm_head.bias)
+                                       self.lm_head.bias)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index e97db188e27eb..7570aefb6e96e 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -40,7 +40,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
@@ -321,10 +320,8 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.embed_out, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.embed_out, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index b49fd0d8f88af..4fe59f91124dd 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -24,7 +24,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.utils import cdiv
@@ -670,10 +669,8 @@ class GptOssForCausalLM(nn.Module, SupportsPP):
         return self.model(input_ids, positions, intermediate_tensors,
                           inputs_embeds)

-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
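GPT-J is the one model above whose logits-processor call keeps an extra positional argument: `self.lm_head.bias` simply slides left into the slot that `sampling_metadata` vacated. A hedged sketch of projecting with an optional, separately stored bias (`project` is a hypothetical helper, not vLLM API):

```python
from typing import Optional

import torch
import torch.nn as nn

# Hypothetical helper sketching GPT-J's call shape: the bias lives on the
# head but is applied as a separate final add rather than fused in.
def project(weight: torch.Tensor, hidden: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    logits = hidden @ weight.t()
    return logits if bias is None else logits + bias

head = nn.Linear(16, 32, bias=True)
logits = project(head.weight, torch.randn(2, 16), head.bias)
assert logits.shape == (2, 32)
```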
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index 4f9cc2532bd8c..795b38e724eab 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -48,7 +48,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -463,11 +462,9 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                                   inputs_embeds)
         return model_output

-    def compute_logits(
-        self, hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+    def compute_logits(self,
+                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def make_empty_intermediate_tensors(
diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py
index 221023f1fb657..a5849184339b1 100644
--- a/vllm/model_executor/models/granite_speech.py
+++ b/vllm/model_executor/models/granite_speech.py
@@ -37,7 +37,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -776,12 +775,8 @@ class GraniteSpeechForConditionalGeneration(
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(
-            hidden_states,
-            sampling_metadata,
-        )
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(
         self,
diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py
index da16c72000c0e..07200fef4799d 100644
--- a/vllm/model_executor/models/granitemoe.py
+++ b/vllm/model_executor/models/granitemoe.py
@@ -48,7 +48,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -511,11 +510,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                                   inputs_embeds)
         return hidden_states

-    def compute_logits(
-        self, hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+    def compute_logits(self,
+                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def make_empty_intermediate_tensors(
diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index 79c6d8146ba9c..e89a1a4a0f7d3 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -32,7 +32,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType
@@ -672,10 +671,8 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py
index 0b568a4b22685..a5d118f084e6c 100644
--- a/vllm/model_executor/models/granitemoeshared.py
+++ b/vllm/model_executor/models/granitemoeshared.py
@@ -25,7 +25,6 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE
@@ -311,11 +310,9 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                                   inputs_embeds)
         return hidden_states

-    def compute_logits(
-        self, hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+    def compute_logits(self,
+                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def make_empty_intermediate_tensors(
diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py
index a591134383371..996e41fe84ff3 100644
--- a/vllm/model_executor/models/grok1.py
+++ b/vllm/model_executor/models/grok1.py
@@ -46,7 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -528,10 +527,8 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py
index 4110c8a1fd08d..8a23a6b45bc70 100644
--- a/vllm/model_executor/models/hunyuan_v1.py
+++ b/vllm/model_executor/models/hunyuan_v1.py
@@ -54,7 +54,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
@@ -1004,10 +1003,8 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def make_empty_intermediate_tensors(
diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py
index 870addd0dcbca..54167f9f10995 100644
--- a/vllm/model_executor/models/hyperclovax_vision.py
+++ b/vllm/model_executor/models/hyperclovax_vision.py
@@ -31,7 +31,6 @@ from transformers.modeling_utils import no_init_weights
 from vllm.config import VllmConfig
 from vllm.inputs import InputProcessingContext
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.cache import BaseMultiModalProcessorCache
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -962,10 +961,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(
         self,
diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py
index 9153a0e2c1e5a..18446d126b51a 100644
--- a/vllm/model_executor/models/idefics3.py
+++ b/vllm/model_executor/models/idefics3.py
@@ -31,7 +31,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -738,10 +737,8 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         return hidden_states

-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py
index 19a3ef1a3b800..8fdf70e35a2b8 100644
--- a/vllm/model_executor/models/interfaces_base.py
+++ b/vllm/model_executor/models/interfaces_base.py
@@ -13,11 +13,9 @@ from vllm.utils import supports_kw
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
     from vllm.model_executor.layers.pooler import Pooler
-    from vllm.model_executor.sampling_metadata import SamplingMetadata
 else:
     VllmConfig = Any
     Pooler = Any
-    SamplingMetadata = Any

 logger = init_logger(__name__)
@@ -100,7 +98,6 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
     def compute_logits(
         self,
         hidden_states: T,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[T]:
         """Return `None` if TP rank > 0."""
         ...
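With the parameter gone from `VllmModelForTextGeneration`, any object exposing the one-argument method satisfies the protocol structurally. A condensed, hedged sketch (the real protocol is generic over the tensor type `T`; this fixes it to `torch.Tensor` for brevity):

```python
from typing import Optional, Protocol, runtime_checkable

import torch

@runtime_checkable
class TextGeneration(Protocol):

    def compute_logits(
            self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
        ...

class Toy:

    def compute_logits(self, hidden_states):
        return hidden_states

# runtime_checkable isinstance only checks that the method exists.
assert isinstance(Toy(), TextGeneration)
```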
diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py
index 118cce810a1f2..892188c047228 100644
--- a/vllm/model_executor/models/intern_vit.py
+++ b/vllm/model_executor/models/intern_vit.py
@@ -25,7 +25,6 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -164,23 +163,15 @@ class InternParallelAttention(nn.Module):
                                                  self.tp_size)
         self.scale = self.head_dim**-0.5
-        if use_data_parallel:
-            self.qkv = ReplicatedLinear(
-                self.embed_dim,
-                3 * self.head_dim * self.num_heads,
-                bias=config.qkv_bias,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv",
-            )
-        else:
-            self.qkv = QKVParallelLinear(
-                self.embed_dim,
-                self.head_dim,
-                num_dummy_heads + self.num_heads,
-                bias=config.qkv_bias,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv",
-            )
+        self.qkv = QKVParallelLinear(
+            self.embed_dim,
+            self.head_dim,
+            num_dummy_heads + self.num_heads,
+            bias=config.qkv_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+            disable_tp=use_data_parallel,
+        )

         self.qk_normalization = config.qk_normalization
@@ -192,20 +183,13 @@ class InternParallelAttention(nn.Module):
                                    eps=config.layer_norm_eps,
                                    var_hidden_size=self.embed_dim)

-        if use_data_parallel:
-            self.proj = ReplicatedLinear(
-                self.dummy_dim,
-                self.embed_dim,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-            )
-        else:
-            self.proj = RowParallelLinear(
-                self.dummy_dim,
-                self.embed_dim,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-            )
+        self.proj = RowParallelLinear(
+            self.dummy_dim,
+            self.embed_dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.proj",
+            disable_tp=use_data_parallel,
+        )

         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)
@@ -236,72 +220,6 @@ class InternParallelAttention(nn.Module):
         return out

-class InternSdpaAttention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        *,
-        num_dummy_heads: int = 0,
-    ) -> None:
-        super().__init__()
-
-        self.config = config
-        self.embed_dim = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.embed_dim // self.num_heads
-        if self.head_dim * self.num_heads != self.embed_dim:
-            raise ValueError(
-                f'embed_dim must be divisible by num_heads '
-                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
-                f' {self.num_heads}).')
-
-        # Additional dummy heads are used to enable TP for common GPU counts.
-        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
-
-        self.scale = self.head_dim**-0.5
-        self.qkv = nn.Linear(self.embed_dim,
-                             3 * self.dummy_dim,
-                             bias=config.qkv_bias)
-
-        self.qk_normalization = config.qk_normalization
-
-        if self.qk_normalization:
-            self.q_norm = RMSNorm(self.dummy_dim,
-                                  eps=config.layer_norm_eps,
-                                  var_hidden_size=self.embed_dim)
-            self.k_norm = RMSNorm(self.dummy_dim,
-                                  eps=config.layer_norm_eps,
-                                  var_hidden_size=self.embed_dim)
-
-        self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
-
-        # Use unified MultiHeadAttention with automatic backend selection
-        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
-                                       self.scale)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        B, N, C = x.shape
-        qkv = self.qkv(x)
-        q, k, v = qkv.chunk(3, dim=-1)
-
-        q = q.view(B, N, self.num_heads, self.head_dim)
-        k = k.view(B, N, self.num_heads, self.head_dim)
-        v = v.view(B, N, self.num_heads, self.head_dim)
-
-        if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
-            k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
-
-        # Use unified MultiHeadAttention with automatic backend selection
-        x = self.attn(q, k, v)
-
-        x = self.proj(x)
-        return x
-
-
 class InternMLP(nn.Module):

     def __init__(
@@ -315,20 +233,18 @@ class InternMLP(nn.Module):
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(config.hidden_size,
-                           config.intermediate_size,
-                           bias=True,
-                           quant_config=quant_config,
-                           prefix=f"{prefix}.fc1")
-        cls_fc2 = (ReplicatedLinear
-                   if use_data_parallel else RowParallelLinear)
-        self.fc2 = cls_fc2(config.intermediate_size,
-                           config.hidden_size,
-                           bias=True,
-                           quant_config=quant_config,
-                           prefix=f"{prefix}.fc2")
+        self.fc1 = ColumnParallelLinear(config.hidden_size,
+                                        config.intermediate_size,
+                                        bias=True,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1",
+                                        disable_tp=use_data_parallel)
+        self.fc2 = RowParallelLinear(config.intermediate_size,
+                                     config.hidden_size,
+                                     bias=True,
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2",
+                                     disable_tp=use_data_parallel)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
@@ -385,19 +301,19 @@ class InternVisionEncoderLayer(nn.Module):
         use_data_parallel: bool = False,
     ):
         # fallback to sdpa attention if tp unavailable
-        # tp_size = get_tensor_model_parallel_world_size()
         tp_size = (1 if use_data_parallel else
                    get_tensor_model_parallel_world_size())
         num_heads = config.num_attention_heads
-        if (num_heads + num_dummy_heads) % tp_size == 0:
-            return InternParallelAttention(config,
-                                           quant_config=quant_config,
-                                           num_dummy_heads=num_dummy_heads,
-                                           prefix=prefix,
-                                           use_data_parallel=use_data_parallel)
-
-        return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
+        # if the number of heads is not divisible by tp_size,
+        # we also disable Attention's TP
+        use_data_parallel = (use_data_parallel
+                             or (num_heads + num_dummy_heads) % tp_size != 0)
+        return InternParallelAttention(config,
+                                       quant_config=quant_config,
+                                       num_dummy_heads=num_dummy_heads,
+                                       prefix=prefix,
+                                       use_data_parallel=use_data_parallel)

     def forward(
         self,
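The InternViT rewrite collapses the `ReplicatedLinear`-vs-parallel branching (and the whole `InternSdpaAttention` fallback class) into single parallel layers carrying a `disable_tp` flag; when the head count is not divisible by the TP size, the layer is constructed with TP disabled instead of swapping classes. A toy stand-in sketching what the flag controls (`ParallelLinear` here is hypothetical; vLLM's real layers also handle the gather/reduce collectives):

```python
import torch.nn as nn

class ParallelLinear(nn.Linear):

    def __init__(self, in_features: int, out_features: int, *,
                 tp_size: int = 1, disable_tp: bool = False):
        # With TP enabled each rank holds out_features / tp_size columns;
        # with disable_tp the layer is replicated in full on every rank.
        shards = 1 if disable_tp else tp_size
        assert out_features % shards == 0, "columns must divide across ranks"
        super().__init__(in_features, out_features // shards, bias=False)

fc_tp = ParallelLinear(16, 64, tp_size=4)                   # 16 cols/rank
fc_dp = ParallelLinear(16, 64, tp_size=4, disable_tp=True)  # full 64 cols
assert (fc_tp.out_features, fc_dp.out_features) == (16, 64)
```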
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index ce94328797ed6..221ff08b43843 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -29,7 +29,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -358,10 +357,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.output, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.output, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py
index b59d1b88cf5ce..ba72c288b2b12 100644
--- a/vllm/model_executor/models/interns1.py
+++ b/vllm/model_executor/models/interns1.py
@@ -21,7 +21,6 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.interns1_vit import InternS1VisionModel
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems, NestedTensors)
@@ -812,10 +811,8 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 6a5c565b52e85..f4004e518e3ba 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -25,7 +25,6 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                    InternVisionPatchModel)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -1399,10 +1398,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py
index 4fee8c32fd581..0eb1578b43610 100644
--- a/vllm/model_executor/models/jais.py
+++ b/vllm/model_executor/models/jais.py
@@ -42,7 +42,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import JAISConfig
@@ -332,10 +331,8 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 5b8fbc7226866..12a49029195ff 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -32,7 +32,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaMLP as JambaMLP
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType
@@ -581,10 +580,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index afe33b4d4ad26..2e5e276cc1c7d 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -21,7 +21,6 @@ from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
-from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
@@ -1556,10 +1555,8 @@ class BaseKeyeModule(nn.Module):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py
index 94a5933a61416..f554077935bf3 100644
--- a/vllm/model_executor/models/kimi_vl.py
+++ b/vllm/model_executor/models/kimi_vl.py
@@ -67,7 +67,6 @@ from vllm.model_executor.models.interfaces import (SupportsMultiModal,
                                                    SupportsPP)
 from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
 from vllm.model_executor.models.utils import merge_multimodal_embeddings
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems, NestedTensors)
@@ -484,10 +483,8 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
         return hidden_states

     def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata,
                        **kwargs) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata, **kwargs)
+        logits = self.logits_processor(self.lm_head, hidden_states, **kwargs)
         return logits

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py
index 927f78c4e4b45..dd97afbeb668a 100644
--- a/vllm/model_executor/models/lfm2.py
+++ b/vllm/model_executor/models/lfm2.py
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
@@ -542,10 +541,8 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                                   inputs_embeds)
         return hidden_states

-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index f8ea2111fed57..1b03cbef501b3 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -48,7 +48,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
@@ -601,10 +600,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 7027138dfcb17..fb10af6c53c90 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -21,7 +21,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                               LlamaForCausalLM)
-from vllm.v1.sample.metadata import SamplingMetadata

 from .utils import AutoWeightsLoader, maybe_prefix

@@ -244,10 +243,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         if self.draft_id_to_target_id is None:
             assert logits.shape[1] == self.config.vocab_size, \
                 "Expected logits to have shape " \
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 4f15e1b5762ea..e2d7b9f23b28a 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -20,7 +20,6 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.cache import BaseMultiModalProcessorCache
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -760,10 +759,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index beb3c33100595..c9133fde14552 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -13,7 +13,6 @@ from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)

 from vllm.config import VllmConfig
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig
 from vllm.multimodal.parse import ImageSize
@@ -563,10 +562,8 @@ model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens].
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index cf9852de633f3..610fb188d57d2 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -13,7 +13,6 @@ from transformers import (BatchFeature, LlavaNextVideoConfig,
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.models.clip import CLIPVisionModel
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -464,10 +463,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index 46d54452a52d8..cee9ddaf94cc4 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -14,7 +14,6 @@ from transformers.models.llava_onevision.modeling_llava_onevision import (

 from vllm.config import VllmConfig
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -934,10 +933,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index 9d1017dac8aa1..36141a5d50641 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -26,7 +26,6 @@ from vllm.model_executor.models.interfaces import (HasInnerState,
                                                    IsAttentionFree, SupportsPP)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType

@@ -299,10 +298,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py
index b1a4138cb8f6c..9c3108146d2e5 100644
--- a/vllm/model_executor/models/mamba2.py
+++ b/vllm/model_executor/models/mamba2.py
@@ -30,7 +30,6 @@ from vllm.model_executor.models.interfaces import (HasInnerState,
                                                    IsAttentionFree)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType

@@ -335,10 +334,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py
index 6ba8ad372c95a..0ae59dc8dfc23 100644
--- a/vllm/model_executor/models/medusa.py
+++ b/vllm/model_executor/models/medusa.py
@@ -2,18 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Iterable
-from typing import Optional

 import torch
 import torch.nn as nn

 from vllm.config import VllmConfig
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata

 from .utils import maybe_prefix

@@ -105,12 +102,13 @@ class Medusa(nn.Module):
         return [block(hidden_states) for block in self.blocks]

     def compute_logits(
-            self, hidden_states: list[torch.Tensor],
-            sampling_metadata: SamplingMetadata) -> list[torch.Tensor]:
+        self,
+        hidden_states: list[torch.Tensor],
+    ) -> list[torch.Tensor]:
         logits_lst: list[torch.Tensor] = []

         for hs, lm_head in zip(hidden_states, self.lm_heads):
-            _logits = self.logits_processor(lm_head, hs, sampling_metadata)
+            _logits = self.logits_processor(lm_head, hs)

             if _logits is None:
                 # _logits should only be None on rank > 0, in which case
@@ -130,57 +128,6 @@ class Medusa(nn.Module):

         return logits_lst

-    def sample(
-        self,
-        logits: list[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> list[SamplerOutput]:
-        logits = torch.stack(logits, dim=0).float()
-        logprobs = torch.log_softmax(logits, dim=-1)
-        token_ids = logits.argmax(-1)  # support only top-1 for now
-        probs = torch.softmax(logits, dim=-1)
-
-        token_id_list = []
-        token_prob_list = []
-        token_logprob_list = []
-
-        for idx, seq_group in enumerate(sampling_metadata.seq_groups):
-            token_id_list.append(token_ids[:, seq_group.sample_indices])
-            token_prob_list.append(probs[:, seq_group.sample_indices])
-            token_logprob_list.append(logprobs[:, seq_group.sample_indices])
-
-        outputs: list[Optional[SamplerOutput]] = []
-        for idx in range(len(sampling_metadata.seq_groups)):
-            outputs.append(
-                SamplerOutput(
-                    outputs=None,
-                    sampled_token_probs=token_prob_list[idx].squeeze(1),
-                    logprobs=token_logprob_list[idx].squeeze(1),
-                    sampled_token_ids=token_id_list[idx].squeeze(1),
-                ))
-
-        return outputs
-
-    def generate_proposals(
-        self,
-        previous_hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[list[SamplerOutput]]:
-        # During preemption, we may receive an empty tensor (batch_size=0)
-        if previous_hidden_states.size(0) == 0:
-            # Return None to signal the Top1Proposer that no proposals
-            # were generated for this batch, allowing it to handle this
-            # special case appropriately
-            return None
-
-        return self.sample(
-            logits=self.compute_logits(
-                hidden_states=self.forward(previous_hidden_states),
-                sampling_metadata=sampling_metadata,
-            ),
-            sampling_metadata=sampling_metadata,
-        )
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
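The medusa.py hunk above deletes the V0-only `sample` and `generate_proposals` methods outright rather than porting them to V1. The core of the removed path was a greedy, top-1 pick per speculative head; a self-contained restatement in plain PyTorch (no vLLM types; `head_logits` holds one [batch, vocab] tensor per Medusa head):

    import torch

    def greedy_medusa_proposals(head_logits: list[torch.Tensor]) -> torch.Tensor:
        # Stack per-head logits into [num_heads, batch, vocab].
        logits = torch.stack(head_logits, dim=0).float()
        # Top-1 only, mirroring the argmax in the removed sample() method.
        return logits.argmax(dim=-1)  # [num_heads, batch] proposed token ids

The removed code additionally sliced per sequence group and packaged probabilities into `SamplerOutput` objects; under V1 that bookkeeping lives in the engine's sampler rather than in the model.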
diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py
index 140800dd41c76..82648ba668ca5 100644
--- a/vllm/model_executor/models/midashenglm.py
+++ b/vllm/model_executor/models/midashenglm.py
@@ -42,7 +42,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -784,9 +783,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.decoder.compute_logits(hidden_states, sampling_metadata)
+        return self.decoder.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py
index ea5292d0df202..d256c1f3eed7b 100644
--- a/vllm/model_executor/models/mimo.py
+++ b/vllm/model_executor/models/mimo.py
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix

@@ -183,9 +182,7 @@ class MiMoForCausalLM(Qwen2ForCausalLM, nn.Module):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
         hidden_states = self.model.norm(hidden_states)
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits
diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py
index 09194e9f95d0e..b4abe458e4771 100644
--- a/vllm/model_executor/models/mimo_mtp.py
+++ b/vllm/model_executor/models/mimo_mtp.py
@@ -34,7 +34,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .utils import maybe_prefix

@@ -140,12 +139,10 @@ class MiMoMultiTokenPredictor(nn.Module):
         self,
         hidden_states: torch.Tensor,
         lm_head: ParallelLMHead,
-        sampling_metadata: SamplingMetadata,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
         self.mtp_layers[str(self.mtp_start_layer_idx + spec_step_idx)]
-        logits = self.logits_processor(lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(lm_head, hidden_states)
         return logits

@@ -178,11 +175,10 @@ class MiMoMTP(nn.Module):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
         spec_step_idx: int = 0,
     ) -> Optional[torch.Tensor]:
         return self.model.compute_logits(hidden_states, self.lm_head,
-                                         sampling_metadata, spec_step_idx)
+                                         spec_step_idx)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py
index 240c23ea2b25d..0986ea07406a9 100644
--- a/vllm/model_executor/models/minicpm.py
+++ b/vllm/model_executor/models/minicpm.py
@@ -51,7 +51,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -583,10 +582,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py
index 848a97b8bb2a0..2af0d546ce63d 100644
--- a/vllm/model_executor/models/minicpm_eagle.py
+++ b/vllm/model_executor/models/minicpm_eagle.py
@@ -39,7 +39,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -376,10 +375,8 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 9b2d84e32151a..a17c4f004d75c 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -50,7 +50,6 @@ from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
 from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems, NestedTensors)
@@ -1194,9 +1193,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.llm.compute_logits(hidden_states, sampling_metadata)
+        return self.llm.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index 6ce883be0a83c..1d2c7dea811e0 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import maybe_prefix
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import HasInnerState, IsHybrid
@@ -742,10 +741,8 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):

         return hidden_states

-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states.float(),
-                                       sampling_metadata)
+    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states.float())
         return logits

diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py
index cc7db849a28bf..b2f020f3323e8 100644
--- a/vllm/model_executor/models/minimax_vl_01.py
+++ b/vllm/model_executor/models/minimax_vl_01.py
@@ -14,7 +14,6 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig
 from vllm.sequence import IntermediateTensors
@@ -420,10 +419,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py
index d15776a39362d..94e3d7234b6f4 100644
--- a/vllm/model_executor/models/mistral3.py
+++ b/vllm/model_executor/models/mistral3.py
@@ -20,7 +20,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.cache import BaseMultiModalProcessorCache
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -606,10 +605,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 8b3474d809532..bebf0b5adac52 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -49,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
@@ -594,10 +593,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index 2f0e8a2a5e575..131a66b713235 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.utils import initialize_model
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems, NestedTensors)
@@ -856,10 +855,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def separate_weights(
         self,
diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py
index c6a97388dc188..d057eb49a62d1 100644
--- a/vllm/model_executor/models/mlp_speculator.py
+++ b/vllm/model_executor/models/mlp_speculator.py
@@ -8,9 +8,7 @@ import torch
 import torch.nn as nn

 from vllm.config import VllmConfig
-from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -141,55 +139,57 @@ class MLPSpeculator(nn.Module):
         self.config = config
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 config.vocab_size, 1.0)
-        self.sampler = get_sampler()

-    def generate_proposals(
-        self,
-        input_ids: torch.Tensor,
-        previous_hidden_states: torch.Tensor,
-        num_predict_tokens: int,
-        sampling_metadata: SamplingMetadata,
-    ) -> list[SamplerOutput]:
-        if num_predict_tokens > self.max_speculative_tokens:
-            raise ValueError(f"Max speculative tokens for model is "
-                             f"{self.max_speculative_tokens}, but "
-                             f"{num_predict_tokens} were requested")
+    # NOTE(woosuk): This method is commented out because it is old code
+    # using V0. We should either port it to V1 or remove it.

-        # b x 1 x d
-        previous_hidden_states = previous_hidden_states.unsqueeze(1)
+    # def generate_proposals(
+    #     self,
+    #     input_ids: torch.Tensor,
+    #     previous_hidden_states: torch.Tensor,
+    #     num_predict_tokens: int,
+    #     sampling_metadata: SamplingMetadata,
+    # ) -> list[SamplerOutput]:
+    #     if num_predict_tokens > self.max_speculative_tokens:
+    #         raise ValueError(f"Max speculative tokens for model is "
+    #                          f"{self.max_speculative_tokens}, but "
+    #                          f"{num_predict_tokens} were requested")

-        if self.scale_input:
-            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
+    #     # b x 1 x d
+    #     previous_hidden_states = previous_hidden_states.unsqueeze(1)

-        # b x 1
-        last_tokens = input_ids.unsqueeze(1)
+    #     if self.scale_input:
+    #         previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2

-        next_tokens = []
+    #     # b x 1
+    #     last_tokens = input_ids.unsqueeze(1)

-        for head_index in range(num_predict_tokens):
+    #     next_tokens = []

-            # Project and predict
-            z = self.emb[head_index](last_tokens)  # b k d
-            states = self.proj[head_index](previous_hidden_states)
+    #     for head_index in range(num_predict_tokens):

-            # Weighted add of state_weight*state and emb_weight*z
-            # Let subsequent LN take care of denominator
-            # state_weight is close to 1, so shouldn't be any precision issues
-            states.add_(z, alpha=self.emb_weight / self.state_weight)
+    #         # Project and predict
+    #         z = self.emb[head_index](last_tokens)  # b k d
+    #         states = self.proj[head_index](previous_hidden_states)

-            states = self.activation(self.ln[head_index](states))  # b k d
-            previous_hidden_states = states
-            # TODO: not yet supporting top_k_tokens_per_head
-            states = states.flatten(0, 1)
+    #         # Weighted add of state_weight*state and emb_weight*z
+    #         # Let subsequent LN take care of denominator
+    #         # state_weight is close to 1, so shouldn't be any precision issues
+    #         states.add_(z, alpha=self.emb_weight / self.state_weight)

-            logits = self.logits_processor(self.head[head_index], states,
-                                           sampling_metadata)
+    #         states = self.activation(self.ln[head_index](states))  # b k d
+    #         previous_hidden_states = states
+    #         # TODO: not yet supporting top_k_tokens_per_head
+    #         states = states.flatten(0, 1)

-            output = self.sampler(logits, sampling_metadata)
-            last_tokens = output.sampled_token_ids
-            next_tokens.append(output)
+    #         logits = self.logits_processor(self.head[head_index], states,
+    #                                        sampling_metadata)

-        return next_tokens
+    #         output = self.sampler(logits, sampling_metadata)
+    #         last_tokens = output.sampled_token_ids
+    #         next_tokens.append(output)
+
+    #     return next_tokens

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
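Unlike Medusa, the mlp_speculator.py hunk above keeps its V0 `generate_proposals` loop as commented-out code, with a NOTE to port or remove it later. The recurrence at its core is easy to lose in the comment noise; a distilled sketch of a single prediction step, with stand-in modules for the per-head `self.emb`, `self.proj`, and `self.ln` from the commented code:

    import torch
    import torch.nn as nn

    def speculator_step(last_tokens: torch.Tensor, prev_hidden: torch.Tensor,
                        emb: nn.Embedding, proj: nn.Linear, ln: nn.LayerNorm,
                        activation: nn.Module, emb_weight: float,
                        state_weight: float) -> torch.Tensor:
        z = emb(last_tokens)        # b x k x d token embedding
        states = proj(prev_hidden)  # b x k x d projected previous state
        # Weighted add of state_weight*state and emb_weight*z; the LayerNorm
        # that follows absorbs the denominator.
        states.add_(z, alpha=emb_weight / state_weight)
        return activation(ln(states))  # becomes prev_hidden for the next head

Each step's output feeds the next head, so the speculator predicts several tokens ahead from one base-model hidden state.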
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 2475fe1316097..201bf83cac581 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -26,7 +26,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather)
-from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
                                                    SiluAndMul)
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -1527,10 +1526,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,

         return hidden_states

-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index 48ac91fa6dde0..64d669e8ac3e1 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -25,7 +25,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
@@ -320,10 +319,8 @@ class MPTForCausalLM(nn.Module, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 4f8652c006941..ae50f1aefc6f7 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -37,7 +37,6 @@ from vllm.model_executor.models.utils import (flatten_bn,
                                               init_vllm_registered_model,
                                               maybe_prefix,
                                               merge_multimodal_embeddings)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargs, MultiModalKwargsItems,
@@ -1192,10 +1191,8 @@ class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         adapter_dict = dict(self.mlp1.named_parameters())
diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py
index 21f785e4b91af..6bb2f7392cb49 100644
--- a/vllm/model_executor/models/nemotron.py
+++ b/vllm/model_executor/models/nemotron.py
@@ -45,7 +45,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import NemotronConfig

@@ -498,10 +497,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index 1e1f0524bd063..ff571541a60a5 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -54,7 +54,6 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory,
     make_layers, maybe_prefix)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import NemotronHConfig
 from vllm.utils import LayerBlockType
@@ -622,10 +621,8 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py
index f8e38dcd80b5a..d474c8db41b2f 100644
--- a/vllm/model_executor/models/nemotron_nas.py
+++ b/vllm/model_executor/models/nemotron_nas.py
@@ -44,7 +44,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.llama import LlamaAttention, LlamaMLP
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import HasNoOps, SupportsLoRA, SupportsPP
@@ -468,10 +467,8 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py
index acda2027401d9..3abbff8c717d4 100644
--- a/vllm/model_executor/models/nemotron_vl.py
+++ b/vllm/model_executor/models/nemotron_vl.py
@@ -26,7 +26,6 @@ from vllm.model_executor.models.internvl import (
     BaseInternVLProcessingInfo, InternVLImageEmbeddingInputs,
     InternVLImageInputs, InternVLImagePixelInputs, InternVLProcessor)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import NestedTensors
@@ -632,10 +631,8 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index 7be3c16528b52..9fa8760073c15 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -45,7 +45,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -391,10 +390,8 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py
index 3e4c580a11211..2e0b1fb2a13f7 100644
--- a/vllm/model_executor/models/olmo2.py
+++ b/vllm/model_executor/models/olmo2.py
@@ -54,7 +54,6 @@ from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import Olmo3Config

@@ -427,10 +426,8 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py
index 892e967e4a21f..77ece544d4900 100644
--- a/vllm/model_executor/models/olmoe.py
+++ b/vllm/model_executor/models/olmoe.py
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP

@@ -471,10 +470,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
                                   inputs_embeds)
         return hidden_states

-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 365aab205b211..4c3ce9f61efb3 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
@@ -399,10 +398,8 @@ class OPTForCausalLM(nn.Module, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py
index 944a9151d75d3..586fea343d6f9 100644
--- a/vllm/model_executor/models/orion.py
+++ b/vllm/model_executor/models/orion.py
@@ -28,7 +28,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
@@ -339,10 +338,8 @@ class OrionForCausalLM(nn.Module, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py
index f1bb18716b40d..052e143b27f6e 100644
--- a/vllm/model_executor/models/ovis.py
+++ b/vllm/model_executor/models/ovis.py
@@ -39,7 +39,6 @@ from vllm.model_executor.models.siglip import SiglipVisionModel
 from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
                                               init_vllm_registered_model,
                                               maybe_prefix)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -558,9 +557,8 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.llm.compute_logits(hidden_states, sampling_metadata)
+        logits = self.llm.compute_logits(hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py
index 5e4758ef8ea5d..f18e38ce154d2 100644
--- a/vllm/model_executor/models/ovis2_5.py
+++ b/vllm/model_executor/models/ovis2_5.py
@@ -19,7 +19,6 @@ from vllm.model_executor.models.siglip2navit import Siglip2NavitModel
 from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
                                               init_vllm_registered_model,
                                               maybe_prefix)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -630,9 +629,8 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.llm.compute_logits(hidden_states, sampling_metadata)
+        logits = self.llm.compute_logits(hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index d6eec77ebcee5..aef5102304614 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -9,7 +9,6 @@ from transformers import BatchFeature, PaliGemmaConfig

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputs, MultiModalKwargsItems,
@@ -403,10 +402,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py
index 3e854e4d561ff..23fb7bb85215c 100644
--- a/vllm/model_executor/models/persimmon.py
+++ b/vllm/model_executor/models/persimmon.py
@@ -44,7 +44,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
@@ -334,10 +333,8 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index 6f39afbecf35b..9cf288e850057 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -59,7 +59,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -346,10 +345,9 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata, self.lm_head.bias)
+                                       self.lm_head.bias)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 4522c7043d01a..a2b201fe4228d 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -29,7 +29,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -681,10 +680,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py
index 25df9e9261d91..d2a3a8cc04969 100644
--- a/vllm/model_executor/models/phi4_multimodal.py
+++ b/vllm/model_executor/models/phi4_multimodal.py
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems, NestedTensors)
@@ -1451,10 +1450,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py
index c4548ee168bd7..ae153558e37aa 100644
--- a/vllm/model_executor/models/phi4flash.py
+++ b/vllm/model_executor/models/phi4flash.py
@@ -29,7 +29,6 @@ from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
                                                    SupportsV0Only)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .utils import make_layers, maybe_prefix
@@ -695,18 +694,12 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        # If the shape is the same, it means that we have already
-        # prune hidden states manually.
-        prune_hidden_states = hidden_states.size(
-            0) != sampling_metadata.selected_token_indices.size(0)
         processed_logits = self.logits_processor(
             self.lm_head,
             hidden_states,
-            sampling_metadata,
             self.embedding_bias,
-            prune_hidden_states=prune_hidden_states)
+        )
         return processed_logits

     def load_weights(
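phi4flash.py is the one model in this stretch where `compute_logits` did more than thread the argument through: the removed V0 lines used `sampling_metadata` to decide whether hidden states still needed pruning. The deleted check, restated readably from the hunk above:

    # Removed V0 logic: prune only when the rows had not yet been gathered
    # down to the selected (sampled) token positions.
    prune_hidden_states = (hidden_states.size(0) !=
                           sampling_metadata.selected_token_indices.size(0))

Under V1 the model runner is expected to hand `compute_logits` already-gathered rows, which is presumably what makes both the check and the `prune_hidden_states` flag on the processor call disposable; note the class still carries `SupportsV0Only` in its bases.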
diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py
index b3fc55dab6eca..47b5ad55ab2d0 100644
--- a/vllm/model_executor/models/phi4mm.py
+++ b/vllm/model_executor/models/phi4mm.py
@@ -18,7 +18,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems, NestedTensors)
@@ -1257,10 +1256,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py
index 01d16f1f2c387..3ce67ce37a7ab 100644
--- a/vllm/model_executor/models/phimoe.py
+++ b/vllm/model_executor/models/phimoe.py
@@ -47,7 +47,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -667,10 +666,8 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                                   inputs_embeds)
         return hidden_states

-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 142d3251bc67a..7b197844c8b63 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalUUIDDict, NestedTensors)
@@ -480,10 +479,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index ef96d272adfb5..33ee1cf44afd1 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
     mamba_chunk_scan_combined)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -53,7 +52,6 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
 from vllm.model_executor.models.utils import (
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
     make_layers, maybe_prefix)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -932,7 +930,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 self.config.vocab_size)
-        self.sampler = get_sampler()

         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
@@ -1024,20 +1021,10 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 7470948499005..e0c08a6a88271 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -30,7 +30,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -282,10 +281,8 @@ class QWenBaseModel(nn.Module):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index e13e87b93429d..c536b0f60c30d 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -49,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import is_interleaved

@@ -510,10 +509,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py
index a7e71309b6074..5f27230c913b4 100644
--- a/vllm/model_executor/models/qwen2_5_omni_thinker.py
+++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -50,7 +50,6 @@ from vllm.model_executor.models.qwen2_5_vl import (
 from vllm.model_executor.models.qwen2_audio import (
     Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths)
 from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalDataDict, MultiModalFieldConfig,
@@ -955,10 +954,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index dbf486374bcf3..73b27572a8ebd 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -43,7 +43,6 @@ from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
-from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 # yapf: disable
@@ -1256,10 +1255,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        return self.language_model.compute_logits(hidden_states,
-                                                  sampling_metadata)
+        return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
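A pattern that repeats across the multimodal wrappers in this patch (LLaVA, MiniCPM-V, Pixtral, the Qwen-VL family): the wrapper owns no LM head, so `compute_logits` is pure delegation and the signature change simply ripples through to the wrapped text model. A minimal sketch; `HypotheticalVLM` is illustrative, not a class in the tree:

    from typing import Optional

    import torch
    import torch.nn as nn

    class HypotheticalVLM(nn.Module):
        """Illustrative wrapper: logits always come from the text backbone."""

        def __init__(self, language_model: nn.Module):
            super().__init__()
            self.language_model = language_model  # e.g. a Qwen2ForCausalLM

        def compute_logits(
            self,
            hidden_states: torch.Tensor,
        ) -> Optional[torch.Tensor]:
            # Vision/audio towers never participate in the logits path.
            return self.language_model.compute_logits(hidden_states)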
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index c797b71b5d2e1..762ab42e5929e 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -34,7 +34,6 @@ from transformers.models.qwen2_audio import (Qwen2AudioConfig, from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (AudioItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, @@ -481,10 +480,8 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 6c6276a930453..6a9acaf2c3fe0 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -51,7 +51,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -546,10 +545,8 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 7f361678ba72e..b3c42c2572566 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -46,7 +46,6 @@ from vllm.config import VllmConfig from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils from vllm.logger import init_logger -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -377,8 +376,10 @@ class Qwen2VisionAttention(nn.Module): q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: - q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + # [2 * b, s, heads, head_dim] + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: if self.attn_backend == _Backend.ROCM_AITER_FA: @@ -402,8 +403,8 @@ class Qwen2VisionAttention(nn.Module): causal=False) context_layer = rearrange(output, - "(b s) ... 
-> b s ...", - b=batch_size) + "(b s) h d -> s b (h d)", + b=batch_size).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. outputs = [] @@ -422,6 +423,8 @@ class Qwen2VisionAttention(nn.Module): output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -432,8 +435,8 @@ class Qwen2VisionAttention(nn.Module): context_layer = xops.memory_efficient_attention_forward( q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() output, _ = self.proj(context_layer) return output @@ -1523,10 +1526,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index dddb47048a1fc..ae72fd30c3993 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP @@ -328,10 +327,8 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 029309c49efd4..0661b3707ff44 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -54,7 +54,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP @@ -690,10 +689,8 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits 
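Aside from the signature change, the `qwen2_vl.py` hunk above also fuses the two vision rotary-embedding calls into one: `q` and `k` are concatenated along the batch dimension, rotated in a single call, and split back. A standalone sketch of that trick; `apply_rotary` here is a simplified rotate-half rotary standing in for `apply_rotary_pos_emb_vision`:

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # x: [b, s, heads, head_dim]; freqs: [s, head_dim // 2].
    emb = torch.cat((freqs, freqs), dim=-1)   # [s, head_dim]
    cos = emb.cos()[None, :, None, :]         # broadcast over batch and heads
    sin = emb.sin()[None, :, None, :]
    return x * cos + rotate_half(x) * sin


b, s, heads, head_dim = 2, 16, 4, 8
q = torch.randn(b, s, heads, head_dim)
k = torch.randn(b, s, heads, head_dim)
freqs = torch.randn(s, head_dim // 2)

# One rotary invocation over [2 * b, s, heads, head_dim] instead of two:
qk = torch.cat([q, k], dim=0)
qk = apply_rotary(qk, freqs)
q, k = torch.chunk(qk, 2, dim=0)
assert q.shape == (b, s, heads, head_dim)
```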
def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index ce917f92bd2e5..24cebc5bfdd82 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -53,7 +53,6 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -1208,10 +1207,8 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + return self.logits_processor(self.lm_head, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index c755eeb9b4eaa..c054339842e64 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -19,7 +19,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen3_next import (Qwen3NextDecoderLayer, Qwen3NextRMSNorm) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig @@ -266,11 +265,9 @@ class Qwen3NextMTP(nn.Module, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: - return self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + return self.logits_processor(self.lm_head, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 17375ff0959d7..aa28c07ddcebe 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -45,7 +45,6 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_pp_group from vllm.logger import init_logger -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -270,6 +269,7 @@ class Qwen3_VisionTransformer(nn.Module): self.temporal_patch_size = vision_config.temporal_patch_size self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes self.use_data_parallel = use_data_parallel + self.num_grid_per_side = int(self.num_position_embeddings**0.5) # NOTE: This is used for creating empty tensor for all_gather for # DP ViT. 
Here out_hidden_size is enlarged due to deepstack @@ -377,82 +377,68 @@ class Qwen3_VisionTransformer(nn.Module): rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb - def fast_pos_embed_interpolate(self, grid_thw): - num_grid_per_side = int(self.num_position_embeddings**0.5) + def fast_pos_embed_interpolate(self, + grid_thw: list[list[int]]) -> torch.Tensor: - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] + num_grid_per_side = self.num_grid_per_side + m_size = self.spatial_merge_size + hidden_dim = self.pos_embed.embedding_dim + outputs = [] for t, h, w in grid_thw: h_idxs = torch.linspace(0, num_grid_per_side - 1, h, - dtype=torch.float32) + dtype=torch.float32, + device=self.device) w_idxs = torch.linspace(0, num_grid_per_side - 1, w, - dtype=torch.float32) + dtype=torch.float32, + device=self.device) - h_idxs_floor = h_idxs.to(torch.long) - w_idxs_floor = w_idxs.to(torch.long) - h_idxs_ceil = torch.clamp(h_idxs.to(torch.long) + 1, - max=num_grid_per_side - 1) - w_idxs_ceil = torch.clamp(w_idxs.to(torch.long) + 1, - max=num_grid_per_side - 1) + h_floor = h_idxs.to(torch.long) + w_floor = w_idxs.to(torch.long) + h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1) + w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1) - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor + dh = h_idxs - h_floor + dw = w_idxs - w_floor - idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T + - w_idxs_floor[None]).flatten().tolist() * t) - idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T + - w_idxs_ceil[None]).flatten().tolist() * t) - idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T + - w_idxs_floor[None]).flatten().tolist() * t) - idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T + - w_idxs_ceil[None]).flatten().tolist() * t) + w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1) + w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1) + w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1) + w11 = (dh[:, None] * dw[None, :]).reshape(-1) - weight_list[0].extend( - ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t) - weight_list[1].extend( - ((1 - dh)[None].T * dw[None]).flatten().tolist() * t) - weight_list[2].extend( - (dh[None].T * (1 - dw)[None]).flatten().tolist() * t) - weight_list[3].extend( - (dh[None].T * dw[None]).flatten().tolist() * t) + idx00 = (h_floor[:, None] * num_grid_per_side + + w_floor[None, :]).reshape(-1) + idx01 = (h_floor[:, None] * num_grid_per_side + + w_ceil[None, :]).reshape(-1) + idx10 = (h_ceil[:, None] * num_grid_per_side + + w_floor[None, :]).reshape(-1) + idx11 = (h_ceil[:, None] * num_grid_per_side + + w_ceil[None, :]).reshape(-1) - device = self.pos_embed.weight.device - dtype = self.pos_embed.weight.dtype + indices = torch.stack([idx00, idx01, idx10, idx11], dim=0) + weights = torch.stack([w00, w01, w10, w11], + dim=0).to(dtype=self.dtype, + device=self.device) + weights = weights.unsqueeze(-1) - p0 = self.pos_embed( - torch.tensor( - idx_list[0], dtype=torch.long, device=device)) * torch.tensor( - weight_list[0], dtype=dtype, device=device)[:, None] - p1 = self.pos_embed( - torch.tensor( - idx_list[1], dtype=torch.long, device=device)) * torch.tensor( - weight_list[1], dtype=dtype, device=device)[:, None] - p2 = self.pos_embed( - torch.tensor( - idx_list[2], dtype=torch.long, device=device)) * torch.tensor( - weight_list[2], dtype=dtype, device=device)[:, None] - p3 = self.pos_embed( - torch.tensor( - idx_list[3], dtype=torch.long, 
device=device)) * torch.tensor( - weight_list[3], dtype=dtype, device=device)[:, None] + embeds = self.pos_embed(indices) + weighted_embeds = embeds * weights + p0, p1, p2, p3 = weighted_embeds.unbind(dim=0) + combined = p0 + p1 + p2 + p3 - patch_pos_embeds = p0 + p1 + p2 + p3 - patch_pos_embeds = patch_pos_embeds.split( - [t * h * w for t, h, w in grid_thw]) - patch_pos_embeds_permute = [] - m_size = self.spatial_merge_size - for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw): - pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size, - m_size, -1).permute(0, 1, 3, 2, 4, - 5).flatten(0, 4) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds + combined = combined.view(h * w, hidden_dim) + repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous() + repeated = repeated.view(t, h // m_size, m_size, w // m_size, + m_size, hidden_dim) + repeated = repeated.permute(0, 1, 3, 2, 4, + 5).reshape(-1, hidden_dim) + outputs.append(repeated) + + return torch.cat(outputs, dim=0) def compute_attn_mask_seqlen( self, @@ -477,12 +463,9 @@ class Qwen3_VisionTransformer(nn.Module): hidden_states = hidden_states + pos_embeds rotary_pos_emb = self.rot_pos_emb(grid_thw) - if isinstance(grid_thw, list): - grid_thw_tensor = torch.tensor(grid_thw, - device=hidden_states.device, - dtype=torch.int32) - else: - grid_thw_tensor = grid_thw + grid_thw_tensor = torch.tensor(grid_thw, + device=self.device, + dtype=torch.int32) cu_seqlens = torch.repeat_interleave( grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], @@ -1224,7 +1207,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, grid_thw_list, rope_type="rope_3d") else: - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + image_embeds = self.visual(pixel_values, + grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync @@ -1508,10 +1492,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -1526,4 +1508,4 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, language_model="language_model", connector="model.visual.merger", tower_model="model.visual.", - ) \ No newline at end of file + ) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 76f2bd087624c..5dc5d545bb9c5 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -4,7 +4,9 @@ Whenever you add an architecture to this page, please also update `tests/models/registry.py` with example HuggingFace models for it. 
""" +import hashlib import importlib +import json import os import pickle import subprocess @@ -12,16 +14,19 @@ import sys import tempfile from abc import ABC, abstractmethod from collections.abc import Set -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from functools import lru_cache +from pathlib import Path from typing import Callable, Optional, TypeVar, Union import torch.nn as nn import transformers +from vllm import envs from vllm.config import (ModelConfig, iter_architecture_defaults, try_match_architecture_defaults) from vllm.logger import init_logger +from vllm.logging_utils import logtime from vllm.transformers_utils.dynamic_module import ( try_get_class_from_dynamic_module) @@ -421,10 +426,91 @@ class _LazyRegisteredModel(_BaseRegisteredModel): module_name: str class_name: str - # Performed in another process to avoid initializing CUDA + @staticmethod + def _get_cache_dir() -> Path: + return Path(envs.VLLM_CACHE_ROOT) / "modelinfos" + + def _get_cache_filename(self) -> str: + cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-") + return f"{cls_name}.json" + + def _load_modelinfo_from_cache(self, + module_hash: str) -> _ModelInfo | None: + try: + try: + modelinfo_path = self._get_cache_dir( + ) / self._get_cache_filename() + with open(modelinfo_path, encoding="utf-8") as file: + mi_dict = json.load(file) + except FileNotFoundError: + logger.debug(("Cached model info file " + "for class %s.%s not found"), self.module_name, + self.class_name) + return None + + if mi_dict["hash"] != module_hash: + logger.debug(("Cached model info file " + "for class %s.%s is stale"), self.module_name, + self.class_name) + return None + + # file not changed, use cached _ModelInfo properties + return _ModelInfo(**mi_dict["modelinfo"]) + except Exception: + logger.exception(("Cached model info " + "for class %s.%s error. "), self.module_name, + self.class_name) + return None + + def _save_modelinfo_to_cache(self, mi: _ModelInfo, + module_hash: str) -> None: + """save dictionary json file to cache""" + from vllm.model_executor.model_loader.weight_utils import atomic_writer + try: + modelinfo_dict = { + "hash": module_hash, + "modelinfo": asdict(mi), + } + cache_dir = self._get_cache_dir() + cache_dir.mkdir(parents=True, exist_ok=True) + modelinfo_path = cache_dir / self._get_cache_filename() + with atomic_writer(modelinfo_path, encoding='utf-8') as f: + json.dump(modelinfo_dict, f, indent=2) + except Exception: + logger.exception("Error saving model info cache.") + + @logtime(logger=logger, msg="Registry inspect model class") def inspect_model_cls(self) -> _ModelInfo: - return _run_in_subprocess( + model_path = Path( + __file__).parent / f"{self.module_name.split('.')[-1]}.py" + + assert model_path.exists(), \ + f"Model {self.module_name} expected to be on path {model_path}" + with open(model_path, "rb") as f: + module_hash = hashlib.md5(f.read()).hexdigest() + + mi = self._load_modelinfo_from_cache(module_hash) + if mi is not None: + logger.debug(("Loaded model info " + "for class %s.%s from cache"), self.module_name, + self.class_name) + return mi + else: + logger.debug(("Cache model info " + "for class %s.%s miss. 
" + "Loading model instead."), self.module_name, + self.class_name) + + # Performed in another process to avoid initializing CUDA + mi = _run_in_subprocess( lambda: _ModelInfo.from_model_cls(self.load_model_cls())) + logger.debug("Loaded model info for class %s.%s", self.module_name, + self.class_name) + + # save cache file + self._save_modelinfo_to_cache(mi, module_hash) + + return mi def load_model_cls(self) -> type[nn.Module]: mod = importlib.import_module(self.module_name) diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py index e3c7c700f8fa1..a217c820fedf0 100644 --- a/vllm/model_executor/models/seed_oss.py +++ b/vllm/model_executor/models/seed_oss.py @@ -47,7 +47,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -472,10 +471,8 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 9857ccdcbe2d4..893ce4497c319 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -22,7 +22,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -897,10 +896,8 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 94c862258b7ad..c774171b9dcd2 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -47,7 +47,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -495,10 +494,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): inputs_embeds) return model_output - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - 
logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 9e880ebd50813..e4dfe8d5a9a3b 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -42,7 +42,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -332,10 +331,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 62ff9b6182755..7f379ab95a03d 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -43,7 +43,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -339,10 +338,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index b8733fa5e6129..0cce0c78f8dc6 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -26,11 +26,9 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -391,7 +389,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() @@ -407,20 +404,10 @@ class 
Step3TextForCausalLM(nn.Module, SupportsPP): inputs_embeds) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: qkv_params_mapping = [ diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 2ba5f94ea3b88..f667266b77bfa 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property from itertools import product from math import ceil, sqrt from typing import Any, Literal, Optional, TypedDict, Union @@ -24,8 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors) @@ -897,13 +894,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - @cached_property - def sampler(self): - if hasattr(self.language_model, "sampler"): - return self.language_model.sampler - - return get_sampler() - @property def device(self): return next(self.parameters()).device @@ -1064,17 +1054,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - return self.language_model.sample(logits, sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index c66867315e553..67cf3ccf315d1 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -23,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.llava import LlavaDummyInputsBuilder -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems @@ -638,10 +637,8 @@ class 
TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 3bd4d10316ec6..475a68bc642b9 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, MultiModalUUIDDict, @@ -798,10 +797,8 @@ class TransformersForCausalLM(TransformersBase): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index f1f11c5fe8f00..12ae9487ad9dc 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -18,7 +18,6 @@ from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.model_loader import DefaultModelLoader from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors) @@ -616,10 +615,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): inputs_embeds=inputs_embeds) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 16a97389cd21b..b33e8d09c4be1 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -30,7 +30,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys # yapf: disable from vllm.model_executor.models.whisper import WhisperEncoder # yapf: enable -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, MultiModalUUIDDict, @@ -454,10 +453,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, def 
compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) @classmethod def get_speech_to_text_config(cls, model_config: ModelConfig, diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 41ae7b129782d..de3e4f0592a62 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -936,10 +935,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, return WhisperAudioInputs(input_features=input_features) - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.proj_out, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.proj_out, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index e601bc3adb6e9..4350e38e02f96 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid @@ -1036,7 +1035,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: """Compute logits for next token prediction. 
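The `registry.py` hunks earlier in this patch replace the unconditional subprocess inspection with a JSON cache keyed by an MD5 hash of the model's source file. A minimal sketch of that load-or-recompute pattern, assuming a hypothetical `cached_inspect` helper rather than vLLM's exact method layout; the temp-file rename approximates what vLLM's `atomic_writer` does:

```python
import hashlib
import json
from pathlib import Path
from typing import Any, Callable


def cached_inspect(source: Path, cache_file: Path,
                   inspect_fn: Callable[[], dict[str, Any]]) -> dict[str, Any]:
    module_hash = hashlib.md5(source.read_bytes()).hexdigest()
    try:
        record = json.loads(cache_file.read_text(encoding="utf-8"))
        if record["hash"] == module_hash:
            # Cache hit: the module file has not changed since the last run.
            return record["modelinfo"]
    except (FileNotFoundError, KeyError, json.JSONDecodeError):
        pass  # missing or corrupt cache entry: fall through and recompute
    info = inspect_fn()  # the expensive path (a subprocess in vLLM)
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    tmp = cache_file.with_suffix(".tmp")
    tmp.write_text(json.dumps({"hash": module_hash, "modelinfo": info},
                              indent=2),
                   encoding="utf-8")
    tmp.replace(cache_file)  # atomic rename, like vLLM's atomic_writer
    return info
```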
@@ -1047,8 +1045,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): Returns: Logits for next token prediction """ - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py deleted file mode 100644 index 2315f9dad5a5a..0000000000000 --- a/vllm/model_executor/sampling_metadata.py +++ /dev/null @@ -1,597 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from array import array -from dataclasses import dataclass -from typing import Optional - -import torch - -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, - SequenceGroupMetadata) -from vllm.utils import (PyObjectCache, async_tensor_h2d, - is_pin_memory_available, make_tensor_with_pad) - -_SAMPLING_EPS = 1e-5 - - -@dataclass -class SequenceGroupToSample: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # Sequence ids for the sequence group in a previous step. - seq_ids: list[int] - sampling_params: SamplingParams - # seq_id -> sequence data. - seq_data: dict[int, SequenceData] - # The length of the sequence (all tokens seen in the past + new token to - # compute attention) of the sequence group. None if it is in a decode - # stage. - seq_len: Optional[int] - # The length of new query tokens to compute in the current step. None if it - # is in a decode stage. The length of query_len <= seq_len if chunked - # prefill is enabled. - query_len: Optional[int] - # A random number generator for sampling. - generator: Optional[torch.Generator] - # True if the sequence group is in prefill stage. False if it is in a - # decode stage. - is_prompt: bool - # Query token indices from logits. to compute prompt logprob. Empty if - # prompt logprob is not required. - prompt_logprob_indices: list[int] - # Sample token indices from logits. Empty if sampling is not required. 
- sample_indices: list[int] - - @property - def do_sample(self): - return len(self.sample_indices) > 0 - - def __post_init__(self): - if len(self.prompt_logprob_indices) > 0: - assert self.sampling_params.prompt_logprobs is not None - if self.is_prompt: - assert self.seq_len is not None - assert self.query_len is not None - - -def gen_seq_group_to_sample_builder(num_seqs: int): - return lambda: SequenceGroupToSample( - seq_ids=[0] * num_seqs, - sampling_params=None, - seq_data=None, # type: ignore - seq_len=0, - query_len=0, - generator=None, - is_prompt=True, - prompt_logprob_indices=[], - sample_indices=[], - ) - - -class SamplingMetadataCache: - """Used to cache SamplingMetadata objects between scheduler iterations""" - - def __init__(self): - self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {} - - def get_cached_seq_group_to_sample(self, num_seqs): - if num_seqs not in self._seq_group_to_sample_cache: - self._seq_group_to_sample_cache[num_seqs] = PyObjectCache( - gen_seq_group_to_sample_builder(num_seqs)) - - obj = self._seq_group_to_sample_cache[num_seqs].get_object() - return obj - - def reset(self): - for cache in self._seq_group_to_sample_cache.values(): - cache.reset() - - -class SamplingMetadata: - """Metadata for input sequences. Used in sampler. - - The usage is as follows; - ``` - hidden_states = execute_model(...) - logits = hidden_states[sampling_metadata.selected_token_indices] - sample(logits) - - def sample(logits): - # Use categorized_sample_indices for sampling.... - ``` - - Args: - seq_groups: List of batched sequence groups. - selected_token_indices: (num_query_tokens_to_logprob). Indices to find - logits from the initial model output hidden states. - categorized_sample_indices: SamplingType -> token indices to sample. - Each token indices is 2D tensor of (num_indices, num_indices) where - the first item means the sample index within the returned logit - (before pruning padding), and the second item means the sample - index after pruning using selected_token_indices. - For example, if the returned logit is [1, 2, 3], and we select - [1, 2] for sampling, the pruned logit will be [2, 3]. In this case, - The first tuple is [1, 2] (sampled index within original logit), - and the second tuple is [0, 1] (sampled index within pruned logit). - num_prompts: Number of prompt sequence groups in seq_groups. - skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU - serialization of token outputs. - reuse_sampling_tensors: Indicates if we want to reuse sampling - tensors that are part of the sampler forward pass. Currently, - it is mainly used for multi-step decode. 
- - """ - - def __init__( - self, - seq_groups: list[SequenceGroupToSample], - selected_token_indices: torch.Tensor, - categorized_sample_indices: dict[SamplingType, torch.Tensor], - num_prompts: int, - skip_sampler_cpu_output: bool = False, - reuse_sampling_tensors: bool = False, - ) -> None: - self.seq_groups = seq_groups - self.selected_token_indices = selected_token_indices - self.categorized_sample_indices = categorized_sample_indices - self.num_prompts = num_prompts - self.skip_sampler_cpu_output = skip_sampler_cpu_output - self.reuse_sampling_tensors = reuse_sampling_tensors - - @staticmethod - def prepare( - seq_group_metadata_list: list[SequenceGroupMetadata], - seq_lens: list[int], - query_lens: list[int], - device: str, - pin_memory: bool, - generators: Optional[dict[str, torch.Generator]] = None, - cache: Optional[SamplingMetadataCache] = None, - ) -> "SamplingMetadata": - ( - seq_groups, - selected_token_indices, - categorized_sample_indices, - num_prompts, - ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, - device, generators, cache) - selected_token_indices = async_tensor_h2d( - selected_token_indices, - dtype=torch.long, - target_device=device, - pin_memory=pin_memory, - ) - categorized_sample_indices = { - t: - async_tensor_h2d( - seq_ids, - dtype=torch.int, - target_device=device, - pin_memory=pin_memory, - ) - for t, seq_ids in categorized_sample_indices.items() - } - - sampling_metadata = SamplingMetadata( - seq_groups=seq_groups, - selected_token_indices=selected_token_indices, - categorized_sample_indices=categorized_sample_indices, - num_prompts=num_prompts, - ) - return sampling_metadata - - def __repr__(self) -> str: - return ( - "SamplingMetadata(" - f"seq_groups={self.seq_groups}, " - f"selected_token_indices={self.selected_token_indices}, " - f"categorized_sample_indices={self.categorized_sample_indices})") - - -def _prepare_seq_groups( - seq_group_metadata_list: list[SequenceGroupMetadata], - seq_lens: list[int], - query_lens: list[int], - device: str, - generators: Optional[dict[str, torch.Generator]] = None, - cache: Optional[SamplingMetadataCache] = None, -) -> tuple[ - list[SequenceGroupToSample], - list[int], - dict[SamplingType, list[int]], - int, -]: - """Prepare sequence groups and indices for sampling. - - Args: - seq_group_metadata_list: A list of sequence group to batch. - seq_lens: A list of sequence lens per sequence group. - Index of prompt len should match with seq_group_metadata_list. - query_lens: A list of query lengths. Prompt lens include the length - of entire prompt tokens, and it could be shorter. - device: A device to use for random number generators, - `SequenceGroupToSample.generator`. - generators: A store of per-request random number generators used - for seeded requests. - - Returns: - seq_groups: A list of sequence group to sample. - selected_token_indices: See the definition from `SamplingMetadata`. - categorized_sample_indices: See the definition from `SamplingMetadata`. - num_prompts: Total number of prompts from `seq_group_metadata_list`. - """ - # Batched sequence groups for the current model forward stsep. - seq_groups: list[SequenceGroupToSample] = [] - # A list of token indices to sample/compute logprob. It is used to - # prune the outcome logits from the model for the performance. - selected_token_indices: list[int] = [] - # Used for selected_token_indices. 
- model_output_idx = 0 - - # Sampling type -> ( - # indices to sample/prompt logprob within pruned output logits, - # indices to sample within pruned logits) - categorized_sample_indices: dict[SamplingType, list[int]] = { - t: [] - for t in SamplingType - } - # Index of logits to compute logprob. Logits include both prompt logprob - # and sample logprob indices. - logit_idx = 0 - # Total number of prompts from given sequence groups. - num_prompts = 0 - - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = seq_group_metadata.seq_data.keys() - - if cache is not None: - sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids)) - - for j, seq_id in enumerate(seq_ids): - sample_obj.seq_ids[j] = seq_id - - sample_obj.prompt_logprob_indices.clear() - sample_obj.sample_indices.clear() - - sampling_params = seq_group_metadata.sampling_params - is_prompt = seq_group_metadata.is_prompt - generator: Optional[torch.Generator] = None - # If the current seq group is in decode stage, it is None. - seq_len: Optional[int] = None - query_len: Optional[int] = None - prompt_logprob_indices: list[int] = (sample_obj.prompt_logprob_indices - if cache is not None else []) - sample_indices: list[int] = (sample_obj.sample_indices - if cache is not None else []) - do_sample = seq_group_metadata.do_sample - - if seq_group_metadata.is_prompt: - if sampling_params.seed is not None: - generator = torch.Generator(device=device).manual_seed( - sampling_params.seed) - if generators is not None: - generators[seq_group_metadata.request_id] = generator - - num_prompts += 1 - num_prefill_sample = len(seq_ids) - assert num_prefill_sample == 1 - assert query_lens is not None and seq_lens is not None - query_len, seq_len = query_lens[i], seq_lens[i] - # If we need sampling, exclude num_prefill_sample tokens from - # prompt logprob. - prompt_logprob_len = (query_len - num_prefill_sample - if do_sample else query_len) - sample_len = num_prefill_sample if do_sample else 0 - else: - # Decode - prompt_logprob_len = 0 - query_len = query_lens[i] if query_lens is not None and len( - query_lens) > 0 else 1 - sample_len = len(seq_ids) * query_len if do_sample else 0 - - if sampling_params.seed is not None and generators is not None: - generator = generators.get(seq_group_metadata.request_id) - - # Update indices to select from the model output. - """ - This blocks computes selected_token_indices which is used in the - following way. - - hidden_states = model(...) - logits = hidden_states[selected_token_indices] - """ - - if sampling_params.prompt_logprobs is not None: - selected_token_indices.extend( - range(model_output_idx, model_output_idx + prompt_logprob_len)) - model_output_idx += prompt_logprob_len - if do_sample: - selected_token_indices.extend( - range(model_output_idx, model_output_idx + sample_len)) - model_output_idx += sample_len - - # We now find indices for logprob computation and sampling. - """ - This block computes categorized_sample_indices which is used in the - following way. - - hidden_states = model(...) - logits = hidden_states[selected_token_indices] - def sample(logits): - # Use categorized_sample_indices for sampling. - # prompt_logprob_indices to find prompt logprob indices. - # sample_indices to find sample indices. 
- """ - - if sampling_params.prompt_logprobs is not None: - prompt_logprob_indices.extend( - range(logit_idx, logit_idx + prompt_logprob_len)) - logit_idx += prompt_logprob_len - if do_sample: - sample_indices.extend(range(logit_idx, logit_idx + sample_len)) - categorized_sample_indices[sampling_params.sampling_type].extend( - list(range(logit_idx, logit_idx + sample_len))) - logit_idx += sample_len - - if cache is not None: - sample_obj.sampling_params = sampling_params - sample_obj.seq_data = seq_group_metadata.seq_data - sample_obj.seq_len = seq_len - sample_obj.query_len = query_len - sample_obj.generator = generator - sample_obj.is_prompt = is_prompt - else: - sample_obj = SequenceGroupToSample( - seq_ids=list(seq_ids), - sampling_params=sampling_params, - seq_data=seq_group_metadata.seq_data, - seq_len=seq_len, - query_len=query_len, - generator=generator, - is_prompt=is_prompt, - prompt_logprob_indices=list(prompt_logprob_indices), - sample_indices=list(sample_indices), - ) - - seq_groups.append(sample_obj) - - if cache is not None: - cache.reset() - - return (seq_groups, selected_token_indices, categorized_sample_indices, - num_prompts) - - -@dataclass -class SamplingTensors: - """Tensors for sampling.""" - - temperatures: torch.Tensor - top_ps: torch.Tensor - top_ks: torch.Tensor - min_ps: torch.Tensor - presence_penalties: torch.Tensor - frequency_penalties: torch.Tensor - repetition_penalties: torch.Tensor - prompt_tokens: torch.Tensor - output_tokens: torch.Tensor - - @classmethod - def from_sampling_metadata( - cls, - sampling_metadata: "SamplingMetadata", - vocab_size: int, - device: torch.device, - dtype: torch.dtype, - ) -> tuple["SamplingTensors", bool, bool, bool]: - prompt_tokens: list[array] = [] - output_tokens: list[array] = [] - top_ks: list[int] = [] - temperatures: list[float] = [] - top_ps: list[float] = [] - min_ps: list[float] = [] - presence_penalties: list[float] = [] - frequency_penalties: list[float] = [] - repetition_penalties: list[float] = [] - do_penalties = False - do_top_p_top_k = False - do_min_p = False - - assert sampling_metadata.seq_groups is not None - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - temperature = sampling_params.temperature - p = sampling_params.presence_penalty - f = sampling_params.frequency_penalty - r = sampling_params.repetition_penalty - top_p = sampling_params.top_p - min_p = sampling_params.min_p - - # k should not be greater than the vocab size. - top_k = min(sampling_params.top_k, vocab_size) - top_k = vocab_size if top_k < 1 else top_k - if temperature < _SAMPLING_EPS: - # NOTE: Zero temperature means deterministic sampling - # (i.e., greedy sampling or beam search). - # Set the temperature to 1 to avoid division by zero. 
- temperature = 1.0 - if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS - or top_k != vocab_size): - do_top_p_top_k = True - if not do_min_p and min_p > _SAMPLING_EPS: - do_min_p = True - if not do_penalties and (abs(p) >= _SAMPLING_EPS - or abs(f) >= _SAMPLING_EPS - or abs(r - 1.0) >= _SAMPLING_EPS): - do_penalties = True - - is_prompt = seq_group.is_prompt - if is_prompt and sampling_params.prompt_logprobs is not None: - # For tokens in the prompt that we only need to get - # their logprobs - query_len = seq_group.query_len - assert query_len is not None - prefill_len = len(seq_group.prompt_logprob_indices) - temperatures += [temperature] * prefill_len - top_ps += [top_p] * prefill_len - top_ks += [top_k] * prefill_len - min_ps += [min_p] * prefill_len - presence_penalties += [0] * prefill_len - frequency_penalties += [0] * prefill_len - repetition_penalties += [1] * prefill_len - - if seq_group.do_sample: - sample_lens = len(seq_group.sample_indices) - assert sample_lens >= len(seq_ids) - temperatures += [temperature] * sample_lens - top_ps += [top_p] * sample_lens - top_ks += [top_k] * sample_lens - min_ps += [min_p] * sample_lens - presence_penalties += [p] * sample_lens - frequency_penalties += [f] * sample_lens - repetition_penalties += [r] * sample_lens - - if do_penalties: - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - if (seq_group.is_prompt - and sampling_params.prompt_logprobs is not None): - prefill_len = len(seq_group.prompt_logprob_indices) - prompt_tokens.extend( - array(VLLM_TOKEN_ID_ARRAY_TYPE) - for _ in range(prefill_len)) - output_tokens.extend( - array(VLLM_TOKEN_ID_ARRAY_TYPE) - for _ in range(prefill_len)) - if seq_group.do_sample: - for seq_id in seq_ids: - seq_data = seq_group.seq_data[seq_id] - prompt_tokens.append(seq_data.prompt_token_ids_array) - output_tokens.append(seq_data.output_token_ids_array) - - sampling_tensors = SamplingTensors.from_lists( - temperatures, - top_ps, - top_ks, - min_ps, - presence_penalties, - frequency_penalties, - repetition_penalties, - prompt_tokens, - output_tokens, - vocab_size, - device, - dtype, - ) - return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) - - @classmethod - def from_lists( - cls, - temperatures: list[float], - top_ps: list[float], - top_ks: list[int], - min_ps: list[float], - presence_penalties: list[float], - frequency_penalties: list[float], - repetition_penalties: list[float], - prompt_tokens: list[array], - output_tokens: list[array], - vocab_size: int, - device: torch.device, - dtype: torch.dtype, - ) -> "SamplingTensors": - # Note that the performance will be very bad without - # pinned memory. 
- pin_memory = is_pin_memory_available() - - do_penalties = prompt_tokens or output_tokens - - if do_penalties: - prompt_t = make_tensor_with_pad( - prompt_tokens, - vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=pin_memory, - ) - output_t = make_tensor_with_pad( - output_tokens, - vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=pin_memory, - ) - else: - empty_tensor = torch.empty(0, device=device, dtype=torch.long) - prompt_t = empty_tensor - output_t = empty_tensor - - temperatures_t = torch.tensor( - temperatures, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - top_ps_t = torch.tensor( - top_ps, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - min_ps_t = torch.tensor( - min_ps, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - presence_penalties_t = torch.tensor( - presence_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - frequency_penalties_t = torch.tensor( - frequency_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - repetition_penalties_t = torch.tensor( - repetition_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - top_ks_t = torch.tensor( - top_ks, - device="cpu", - dtype=torch.int, - pin_memory=pin_memory, - ) - # Because the memory is pinned, we can do non-blocking - # transfer to device. - - return cls( - temperatures=temperatures_t.to(device=device, non_blocking=True), - top_ps=top_ps_t.to(device=device, non_blocking=True), - top_ks=top_ks_t.to(device=device, non_blocking=True), - min_ps=min_ps_t.to(device=device, non_blocking=True), - presence_penalties=presence_penalties_t.to(device=device, - non_blocking=True), - frequency_penalties=frequency_penalties_t.to(device=device, - non_blocking=True), - repetition_penalties=repetition_penalties_t.to(device=device, - non_blocking=True), - prompt_tokens=prompt_t.to(device=device, non_blocking=True), - output_tokens=output_t.to(device=device, non_blocking=True), - ) diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index b7d4cd298e24f..7ffa732cf3708 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -15,7 +15,7 @@ is used by model runners to dispatch data processing according to the target model. Info: - [mm_processing](../../../design/mm_processing.html) + [mm_processing](../../../design/mm_processing.md) """ __all__ = [ diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index ef8f1b2e17b47..e0edb3e883ed6 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -2,14 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from collections.abc import Sequence from pathlib import Path -from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar - -if TYPE_CHECKING: - from vllm.sequence import SequenceGroupMetadata - -from .inputs import MultiModalKwargs, PlaceholderRange +from typing import Generic, NamedTuple, TypeVar _T = TypeVar("_T") @@ -53,120 +47,6 @@ class MultiModalPlaceholderMap: self.dest_ranges = [] self.dest_len = 0 - @classmethod - def from_seq_group( - cls, seq_group: "SequenceGroupMetadata", positions: range - ) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]: - """ - Returns the multi-modal items that intersect with the portion of a - prompt (``seq_group``) represented by ``positions``, as well as a - ``MultiModalPlaceholderMap`` that relates the multi-modal embedding - vectors to their corresponding placeholders. 
- - Examples: - - ``` - Prompt: |AAAA BBBB What's in these images?| - Positions: |.................................| - - images = [A, B] - src_ranges = [(0, 4), (4, 8)] - dest_ranges = [(0, 4), (5, 9)] - - Prompt: |AAAA BBBB What's in these images?| - Positions: | ..... | - - images = [A, B] - src_ranges = [(2, 4), (4, 6)] - dest_ranges = [(0, 2), (3, 5)] - - Prompt: |AAAA BBBB What's in these images?| - Positions: | ......... | - - images = [B] - src_ranges = [(0, 4)] - dest_ranges = [(0, 4)] - - Prompt: |AAAA BBBB What's in these images?| - Positions: | .......................| - - images = [] - src_ranges = [] - dest_ranges = [] - ``` - """ - seq_mm_data = seq_group.multi_modal_data - seq_mm_placeholders = seq_group.multi_modal_placeholders - - if not seq_mm_data or not seq_mm_placeholders: - return MultiModalKwargs(), {} - - placeholder_maps = dict[str, MultiModalPlaceholderMap]() - - for modality, placeholders in seq_mm_placeholders.items(): - placeholder_map = MultiModalPlaceholderMap() - - if positions: - placeholder_map.append_items_from_seq_group( - positions, - # Dummy, since we don't care about intersecting items - [None] * len(placeholders), - placeholders, - ) - - placeholder_maps[modality] = placeholder_map - - return seq_mm_data, placeholder_maps - - def append_items_from_seq_group( - self, - positions: range, - multi_modal_items: list[_T], - multi_modal_placeholders: Sequence[PlaceholderRange], - ) -> list[_T]: - """ - Adds the multi-modal items that intersect ``positions`` to this - placeholder map and returns the intersecting items. - """ - intersecting_items = [] - - if len(multi_modal_items) != len(multi_modal_placeholders): - raise ValueError( - "Multi-modal placeholders and items must have the same length." - ) - for placeholder_dict, mm_item in zip(multi_modal_placeholders, - multi_modal_items): - placeholder = range( - placeholder_dict.offset, - placeholder_dict.offset + placeholder_dict.length, - ) - intersection = range( - max(positions.start, placeholder.start), - min(positions.stop, placeholder.stop), - ) - - if not intersection: - # Skip this multi-modal item.
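# [Editor's sketch] The intersection arithmetic above, run standalone with the
# numbers from the first docstring example (whole prompt scheduled, image A at
# positions 0..3). Empty ranges are falsy, which is what the skip check relies on.
_positions = range(0, 33)
_placeholder = range(0, 4)
_intersection = range(max(_positions.start, _placeholder.start),
                      min(_positions.stop, _placeholder.stop))
assert _intersection  # non-empty range -> the item is kept
_src_len = 0  # running offset into the flattened embedding tensor
_dest = range(_intersection.start - _positions.start,
              _intersection.stop - _positions.start)
_src = range(_intersection.start - _placeholder.start + _src_len,
             _intersection.stop - _placeholder.start + _src_len)
assert (_src, _dest) == (range(0, 4), range(0, 4))  # matches src/dest_ranges[0]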
- continue - - token_embedding_range = range( - intersection.start - positions.start, - intersection.stop - positions.start, - ) - - multimodal_embedding_range = range( - intersection.start - placeholder.start + self.src_len, - intersection.stop - placeholder.start + self.src_len, - ) - - intersecting_items.append(mm_item) - self.dest_ranges.append(token_embedding_range) - self.src_ranges.append(multimodal_embedding_range) - self.src_len += len(placeholder) - - self.dest_len += len(positions) - return intersecting_items - def extend(self, other: "MultiModalPlaceholderMap"): """ Adds the placeholders from another ``MultiModalPlaceholderMap`` to this diff --git a/vllm/outputs.py b/vllm/outputs.py index 64bcfd472f2ad..4d8206bb2d830 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time from collections.abc import MutableSequence from collections.abc import Sequence as GenericSequence from dataclasses import dataclass @@ -14,9 +13,7 @@ from vllm.logger import init_logger from vllm.logprobs import PromptLogprobs, SampleLogprobs from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalPlaceholderDict -from vllm.sampling_params import RequestOutputKind -from vllm.sequence import (RequestMetrics, SequenceGroup, SequenceGroupBase, - SequenceStatus) +from vllm.sequence import RequestMetrics logger = init_logger(__name__) @@ -171,170 +168,6 @@ class RequestOutput: else: self.outputs.append(next_completion) - @classmethod - def from_seq_group( - cls, seq_group: SequenceGroup, use_cache: bool, - seq_id_to_seq_group: dict[str, SequenceGroupBase] - ) -> Optional["RequestOutput"]: - finished = seq_group.is_finished() - - if seq_group.request_id in seq_id_to_seq_group: - group: SequenceGroupBase = seq_id_to_seq_group[ - seq_group.request_id] - assembled_seq_group = group.maybe_assemble_group(seq_group) - if finished: - group.finish_seq(seq_group) - if assembled_seq_group is None: - return None - - # clear finished seq in seq_id_to_seq_group - if len(group.to_be_finished) == 0: - for sub_request_id in list(group.seq_id_to_index.keys()): - if sub_request_id in seq_id_to_seq_group: - del seq_id_to_seq_group[sub_request_id] - - return cls.from_seq_group(assembled_seq_group, use_cache, - seq_id_to_seq_group) - - sampling_params = seq_group.sampling_params - if sampling_params is None: - raise ValueError( - "Sampling parameters are missing for a CompletionRequest.") - - if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( - not finished): - return None - - # Init cache (if needed) - if use_cache and seq_group.cached_request_output is None: - seq_group.cached_request_output = RequestOutput( # type: ignore - request_id="", - prompt=None, - prompt_token_ids=[], - prompt_logprobs=None, - outputs=[], - finished=False) - - top_n_seqs = seq_group.get_seqs() - - # Create the outputs. - # NOTE: We need to omit logprobs here explicitly because the sequence - always has the logprobs of the sampled tokens even if the - logprobs are not requested.
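# [Editor's sketch] The delta-streaming idea used below in isolation: with
# RequestOutputKind.DELTA only tokens produced since the previous call are
# returned, tracked by a per-sequence offset instead of re-sending everything.
_all_output_tokens = [11, 12, 13, 14, 15]
_last_offset = 3  # tokens already streamed to the client
_num_new = len(_all_output_tokens) - _last_offset
_delta = _all_output_tokens[-_num_new:] if _num_new else []
assert _delta == [14, 15]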
- include_logprobs = sampling_params.logprobs is not None - text_buffer_length = sampling_params.output_text_buffer_length - delta = sampling_params.output_kind == RequestOutputKind.DELTA - - outputs = [] - include_prompt = True - # num_cached_tokens should be the same for all the sequences - num_cached_tokens = None - for i, seq in enumerate(top_n_seqs): - output_text = seq.get_output_text_to_return( - text_buffer_length, delta) - - output_token_ids = seq.get_output_token_ids_to_return(delta) - num_output_tokens = 1 if isinstance(output_token_ids, - int) else len(output_token_ids) - num_cached_tokens = seq.data.get_num_cached_tokens() - - output_logprobs = seq.output_logprobs if include_logprobs else None - - if delta: - # Slice logprobs delta if applicable - if output_logprobs: - # num_output_tokens can be 0 when n > 1 and request finishes - # before the others - if num_output_tokens > 0: - output_logprobs = output_logprobs[-num_output_tokens:] - else: - output_logprobs = None - # Don't include prompt if this is after the first output - # containing decode token ids - if include_prompt and seq.get_output_len() > num_output_tokens: - include_prompt = False - - if use_cache: - # Get cached output object - cached_outputs = seq_group.cached_request_output.outputs # type: ignore - if i >= len(cached_outputs): - cached_outputs.append( - CompletionOutput(index=i, - text="", - token_ids=[], - cumulative_logprob=None, - logprobs=None, - finish_reason=None, - stop_reason=None)) - output = cached_outputs[i] - - # Init cached output object - assert output.index == i - output.text = output_text - - if isinstance(output_token_ids, int): - output.token_ids.clear() - output.token_ids.append(output_token_ids) - else: - output.token_ids = output_token_ids - - output.cumulative_logprob = seq.get_cumulative_logprob() \ - if include_logprobs else None - output.logprobs = output_logprobs - output.finish_reason = SequenceStatus.get_finished_reason( - seq.status) - output.stop_reason = seq.stop_reason - - else: - output = CompletionOutput( - top_n_seqs.index(seq), output_text, [output_token_ids] - if isinstance(output_token_ids, int) else output_token_ids, - seq.get_cumulative_logprob() if include_logprobs else None, - output_logprobs, - SequenceStatus.get_finished_reason(seq.status), - seq.stop_reason) - - outputs.append(output) - - # Every sequence in the sequence group should have the same prompt. 
- if include_prompt: - prompt = seq_group.prompt - prompt_token_ids = seq_group.prompt_token_ids - encoder_prompt = seq_group.encoder_prompt - encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids - prompt_logprobs = seq_group.prompt_logprobs - else: - prompt = None - prompt_token_ids = None - encoder_prompt = None - encoder_prompt_token_ids = None - prompt_logprobs = None - finished_time = time.time() if finished else None - seq_group.set_finished_time(finished_time) - - init_kwargs = { - "request_id": seq_group.request_id, - "prompt": prompt, - "prompt_token_ids": prompt_token_ids, - "prompt_logprobs": prompt_logprobs, - "outputs": outputs, - "finished": finished, - "metrics": seq_group.metrics, - "lora_request": seq_group.lora_request, - "encoder_prompt": encoder_prompt, - "encoder_prompt_token_ids": encoder_prompt_token_ids, - "num_cached_tokens": num_cached_tokens, - "multi_modal_placeholders": seq_group.multi_modal_placeholders - } - - if use_cache: - request_output = seq_group.cached_request_output - request_output.__init__(**init_kwargs) # type: ignore - else: - request_output = cls(**init_kwargs) # type: ignore - - return request_output - def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " @@ -371,19 +204,6 @@ class PoolingRequestOutput(Generic[_O]): self.finished = finished self.outputs = outputs - @staticmethod - def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput": - pooled_data = seq_group.pooled_data - assert pooled_data is not None - - data = pooled_data.to(dtype=torch.float32, device="cpu") - output = PoolingOutput(data) - prompt_token_ids = seq_group.prompt_token_ids - finished = seq_group.is_finished() - - return PoolingRequestOutput(seq_group.request_id, output, - prompt_token_ids, finished) - def __repr__(self): return (f"{type(self).__name__}(request_id={self.request_id!r}, " f"outputs={self.outputs!r}, " @@ -391,19 +211,6 @@ class PoolingRequestOutput(Generic[_O]): f"finished={self.finished})") -class RequestOutputFactory: - - @staticmethod - def create(seq_group: SequenceGroup, - seq_id_to_seq_group: dict[str, SequenceGroupBase], - use_cache: bool = False): - if seq_group.pooled_data is not None: - return PoolingRequestOutput.from_seq_group(seq_group) - else: - return RequestOutput.from_seq_group(seq_group, use_cache, - seq_id_to_seq_group) - - @dataclass class EmbeddingOutput: """The output data of one embedding output of a request. diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 544e091491bf5..cd41832bc2ea4 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -126,10 +126,6 @@ class CpuPlatform(Platform): """ torch.cpu.set_device(device) - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return False - @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 87d8f2b7481bb..c263e2afe83b0 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -96,16 +96,6 @@ class CudaPlatformBase(Platform): def get_device_total_memory(cls, device_id: int = 0) -> int: raise NotImplementedError - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - if enforce_eager and not envs.VLLM_USE_V1: - logger.warning( - "To see benefits of async output processing, enable CUDA " - "graph. 
Since enforce-eager is enabled, async output " - "processor cannot be used") - return False - return True - @classmethod def is_fully_connected(cls, device_ids: list[int]) -> bool: raise NotImplementedError diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 53fc762dce540..c43580ac5da13 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -275,13 +275,6 @@ class Platform: """Get the total memory of a device in bytes.""" raise NotImplementedError - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - """ - Check if the current platform supports async output. - """ - raise NotImplementedError - @classmethod def inference_mode(cls): """A device-specific wrapper of `torch.inference_mode`. diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 4f540fe965e22..dce2924ac7a9d 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -310,16 +310,6 @@ class RocmPlatform(Platform): device_props = torch.cuda.get_device_properties(device_id) return device_props.total_memory - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - if enforce_eager and not envs.VLLM_USE_V1: - logger.warning( - "To see benefits of async output processing, enable CUDA " - "graph. Since enforce-eager is enabled, async output " - "processor cannot be used") - return False - return True - @classmethod def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: from vllm.config.compilation import CUDAGraphMode diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 4e4db116abca0..9852d948bc4b8 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -75,10 +75,6 @@ class TpuPlatform(Platform): def get_device_total_memory(cls, device_id: int = 0) -> int: raise NotImplementedError - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return False - @classmethod def get_punica_wrapper(cls) -> str: return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 67ef058df10f1..4d3bef4b42947 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -98,10 +98,6 @@ class XPUPlatform(Platform): device_props = torch.xpu.get_device_properties(device_id) return device_props.total_memory - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return True - @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/profiler/layerwise_profile.py b/vllm/profiler/layerwise_profile.py index 2f9ebe531cbb1..41136f738c286 100644 --- a/vllm/profiler/layerwise_profile.py +++ b/vllm/profiler/layerwise_profile.py @@ -353,8 +353,8 @@ class layerwise_profile(profile): Args: num_running_seqs (Optional[int], optional): When given, - num_running_seqs will be passed to LayerProfileResults for metadata - update. Defaults to None. + num_running_seqs will be passed to LayerProfileResults + for metadata update. Defaults to None. 
""" super().__init__( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], diff --git a/vllm/sequence.py b/vllm/sequence.py index 24114c0bb792e..a6c194fbac0b2 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,28 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Sequence and its related classes.""" -import copy -import enum -from abc import ABC, abstractmethod -from array import array -from collections import defaultdict -from collections.abc import Mapping -from collections.abc import Sequence as GenericSequence -from dataclasses import dataclass, field -from functools import reduce -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Union import msgspec import torch -from vllm.inputs import SingletonInputs -from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs -from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict -from vllm.pooling_params import PoolingParams -from vllm.sampling_params import RequestOutputKind, SamplingParams - if TYPE_CHECKING: - from vllm.lora.request import LoRARequest from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorOutput) else: @@ -34,50 +19,6 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l" VLLM_INVALID_TOKEN_ID = -1 -def array_full(token_id: int, count: int): - """[`array`][] equivalent of [numpy.full][].""" - return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count - - -class SequenceStatus(enum.IntEnum): - """Status of a sequence.""" - WAITING = 0 - RUNNING = 1 - SWAPPED = 2 - # Note: anything after SWAPPED (2) will be considered - # as a finished status. - FINISHED_STOPPED = 3 - FINISHED_LENGTH_CAPPED = 4 - FINISHED_ABORTED = 5 - FINISHED_IGNORED = 6 - - @staticmethod - def is_finished(status: "SequenceStatus") -> bool: - return status > SequenceStatus.SWAPPED - - @staticmethod - def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: - if status == SequenceStatus.FINISHED_STOPPED: - finish_reason = "stop" - elif status == SequenceStatus.FINISHED_LENGTH_CAPPED: - finish_reason = "length" - elif status == SequenceStatus.FINISHED_ABORTED: - finish_reason = "abort" - elif status == SequenceStatus.FINISHED_IGNORED: - # The ignored sequences are the sequences whose prompt lengths - # are longer than the model's length cap. Therefore, the stop - # reason should also be "length" as in OpenAI API. - finish_reason = "length" - else: - finish_reason = None - return finish_reason - - -class SequenceStage(enum.Enum): - PREFILL = enum.auto() - DECODE = enum.auto() - - @dataclass class RequestMetrics: """Metrics associated with a request. @@ -107,971 +48,12 @@ class RequestMetrics: model_execute_time: Optional[float] = None -class SequenceDataDelta( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Delta SequenceData to send to workers per step.""" - # A new token to be appended to existing SequenceData. - new_output_token_ids: list[int] - # Overwriting existing `cumulative_logprob` - new_cumulative_logprob: float - # Overwriting existing `num_computed_tokens`. - new_num_computed_tokens: int - # Overwriting existing `stage`. - new_stage: SequenceStage - - -class SequenceData(msgspec.Struct, - omit_defaults=True): # type: ignore[call-arg] - """Data associated with a sequence.""" - # NOTE: we cannot use Union[list, array] because msgspec cannot support - # union of 2 list types. 
- _prompt_token_ids: array - _output_token_ids: array = msgspec.field( - default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) - - _prompt_embeds: Optional[torch.Tensor] = None - _output_embeds: Optional[torch.Tensor] = None - - ### The below fields should not be passed as an argument ### - _cumulative_logprob: float = 0.0 - _prompt_token_ids_tuple: tuple[int, - ...] = msgspec.field(default_factory=tuple) - # The number of tokens that are computed (that run against the model). - _num_computed_tokens: int = 0 - # The number of tokens with prefix cache hit. - _num_cached_tokens: int = 0 - _stage: SequenceStage = SequenceStage.PREFILL - _cached_all_token_ids: list[int] = msgspec.field(default_factory=list) - _cached_all_token_embeds: Optional[torch.Tensor] = None - - # It is used to get delta input. It is reset when `get_delta_and_reset` - # is called. - _new_appended_tokens: list[int] = msgspec.field(default_factory=list) - - # It is used to compute mrope_position_ids. - _mrope_position_delta: Optional[int] = None - - @staticmethod - def from_prompt_token_counts( - *token_counts: tuple[int, int]) -> "SequenceData": - """ - Construct a [`SequenceData`][vllm.sequence.SequenceData] instance - by concatenating prompt token sequences. - - Each tuple represents one token sequence, expressed in the form - `(token_id, count)`. - """ - if len(token_counts) == 0: - return SequenceData.from_seqs([]) - - prompt_token_ids_arr = reduce( - array.__iadd__, - (array_full(token_id, count) for token_id, count in token_counts), - ) - - return SequenceData(prompt_token_ids_arr) - - @staticmethod - def from_seqs( - prompt_token_ids: GenericSequence[int], - output_token_ids: Optional[GenericSequence[int]] = None, - *, - prompt_embeds: Optional[torch.Tensor] = None, - ) -> "SequenceData": - """ - Construct a [`SequenceData`][vllm.sequence.SequenceData] instance - from prompt and output token sequences. - """ - prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - prompt_token_ids) - - if output_token_ids is None: - return SequenceData(prompt_token_ids_arr, - _prompt_embeds=prompt_embeds) - - output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - output_token_ids) - - return SequenceData(prompt_token_ids_arr, - _output_token_ids=output_token_ids_arr, - _prompt_embeds=prompt_embeds) - - def __post_init__(self) -> None: - assert self._prompt_token_ids.typecode == "l" - assert self._output_token_ids.typecode == "l" - self._prompt_token_ids_tuple: tuple[int, ...] 
= tuple( - self._prompt_token_ids) - self._update_cached_all_tokens() - if self._prompt_embeds is not None: - self._update_cached_all_token_embeds() - - def _update_cached_all_tokens(self): - assert isinstance(self._prompt_token_ids, array) - assert isinstance(self._output_token_ids, array) - self._cached_all_token_ids: list[int] = list(self._prompt_token_ids + - self._output_token_ids) - - def _update_cached_all_token_embeds(self): - assert isinstance(self._prompt_embeds, torch.Tensor) - self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds - if self._output_embeds is not None: - self._cached_all_token_embeds = torch.cat( - (self._cached_all_token_embeds, self._output_embeds), dim=0) - - @property - def cumulative_logprob(self) -> float: - """The cumulative log probability of the output.""" - return self._cumulative_logprob - - @property - def prompt_token_ids(self) -> tuple[int, ...]: - """The token IDs of the prompt.""" - return self._prompt_token_ids_tuple - - @prompt_token_ids.setter - def prompt_token_ids(self, new_prompt_token_ids) -> None: - raise NotImplementedError - - @property - def prompt_token_ids_array(self) -> array: - """Return the prompt token ids in array type. - - Note that the array is in "l" type, and its element width - may not match torch.long on all platforms. So beware of the usage. - """ - return self._prompt_token_ids - - @property - def output_token_ids(self) -> tuple[int, ...]: - """The token IDs of the output.""" - return tuple(self._output_token_ids) - - @output_token_ids.setter - def output_token_ids(self, - new_output_token_ids: GenericSequence[int]) -> None: - self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - new_output_token_ids) - self._update_cached_all_tokens() - - @property - def output_embeds(self) -> Optional[torch.Tensor]: - return self._output_embeds - - @output_embeds.setter - def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None: - self._output_embeds = new_output_token_embeds - self._update_cached_all_token_embeds() - - @property - def output_token_ids_array(self) -> array: - """Return the output token ids in array type. - - Note that the array is in "l" type, and its element width - may not match torch.long on all platforms. So beware of the usage. 
- """ - assert isinstance(self._output_token_ids, array) - return self._output_token_ids - - @property - def prompt_embeds(self) -> Optional[torch.Tensor]: - return self._prompt_embeds - - @prompt_embeds.setter - def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None: - self._prompt_embeds = prompt_embeds - self._update_cached_all_token_embeds() - - @property - def mrope_position_delta(self) -> Optional[int]: - return self._mrope_position_delta - - @mrope_position_delta.setter - def mrope_position_delta(self, new_mrope_position_delta): - self._mrope_position_delta = new_mrope_position_delta - - def append_token_id(self, - token_id: int, - logprob: float, - token_embed: Optional[torch.Tensor] = None) -> None: - self._output_token_ids.append(token_id) - self._new_appended_tokens.append(token_id) - self._cached_all_token_ids.append(token_id) - self._cumulative_logprob += logprob - if token_embed is not None: - # Do not pass in with batch or sequence dimensions - assert token_embed.ndim == 1 - token_embed = token_embed.detach().cpu().unsqueeze(0) - if self._output_embeds is None: - self._output_embeds = token_embed - else: - self._output_embeds = torch.cat( - (self._output_embeds, token_embed), dim=0) - assert self._cached_all_token_embeds is not None - self._cached_all_token_embeds = torch.cat( - (self._cached_all_token_embeds, - token_embed.to(device=self._cached_all_token_embeds.device)), - dim=0) - - def get_len(self) -> int: - return len(self._output_token_ids) + len(self._prompt_token_ids) - - def get_prompt_len(self) -> int: - return len(self._prompt_token_ids) - - def get_output_len(self) -> int: - return len(self._output_token_ids) - - def get_token_ids(self) -> list[int]: - return self._cached_all_token_ids - - def get_token_embeddings(self) -> Optional[torch.Tensor]: - return self._cached_all_token_embeds - - def get_prefix_token_ids( - self, num_tokens: int - ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]: - """Get prefix tokens, and make the return value hashable""" - prompt_length = self.get_prompt_len() - if num_tokens > prompt_length: - return (self._prompt_token_ids_tuple, - tuple(self._output_token_ids[:num_tokens - prompt_length])) - else: - return (self._prompt_token_ids_tuple[:num_tokens], None) - - def get_num_computed_tokens(self) -> int: - """Return the number of prefill tokens that are already computed.""" - return self._num_computed_tokens - - def update_num_computed_tokens(self, num_new_computed_tokens: int): - """Update number of tokens computed so far.""" - self._num_computed_tokens += num_new_computed_tokens - assert self._num_computed_tokens <= self.get_len(), ( - self._num_computed_tokens, self.get_len()) - # If all tokens are computed, it means it is in decoding phase. - if self.get_num_uncomputed_tokens() == 0: - self._stage = SequenceStage.DECODE - - def get_num_cached_tokens(self) -> int: - """Return the number of tokens with prefix cache hit.""" - return self._num_cached_tokens - - def update_num_cached_tokens(self, num_cached_tokens: int): - """Update the number of tokens with prefix cache hit.""" - self._num_cached_tokens = num_cached_tokens - - def reset_state_for_recompute(self) -> None: - """Reset the number of computed tokens from this sequence. It is - supposed to be called when a sequence needs to be started from - the beginning again (e.g., sequence is preempted). 
- """ - self._num_computed_tokens = 0 - self._stage = SequenceStage.PREFILL - self._new_appended_tokens = [] - - def get_num_uncomputed_tokens(self) -> int: - """Return the number of prefill tokens that are not computed.""" - # we use `get_len()` which includes prompt_len + output_len instead - # of prompt_len here. This is because during recompute we need to - # prefill for both prompt and output. - return self.get_len() - self.get_num_computed_tokens() - - def get_last_token_id(self) -> int: - if not self._output_token_ids: - return self._prompt_token_ids[-1] - return self._output_token_ids[-1] - - def get_prompt_token_ids(self) -> tuple[int, ...]: - return self.prompt_token_ids - - def get_output_token_ids(self) -> tuple[int, ...]: - return self.output_token_ids - - def get_delta_and_reset(self) -> SequenceDataDelta: - delta = SequenceDataDelta(self._new_appended_tokens, - self._cumulative_logprob, - self.get_num_computed_tokens(), self.stage) - # Reset delta state. - self._new_appended_tokens = [] - return delta - - def apply_delta(self, delta: SequenceDataDelta): - self._num_computed_tokens = delta.new_num_computed_tokens - self._cumulative_logprob = delta.new_cumulative_logprob - self._stage = delta.new_stage - self._output_token_ids.extend(delta.new_output_token_ids) - self._cached_all_token_ids.extend(delta.new_output_token_ids) - - @property - def stage(self) -> SequenceStage: - return self._stage - - def __repr__(self) -> str: - return (f"SequenceData(" - f"prompt_token_ids={self._prompt_token_ids}, " - f"prompt_embeds.shape=" - f"{getattr(self._prompt_embeds, 'shape', None)}, " - f"output_token_ids={self.output_token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}, " - f"get_num_computed_tokens={self.get_num_computed_tokens()})") - - -class Sequence: - """Stores the data, status, and block information of a sequence. - - The sequence is constructed from the - [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (for decoder-only) - or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] - (for encoder-decoder) instance passed in through the `inputs` - constructor argument. - - Args: - seq_id: The ID of the sequence. - inputs: The inputs of the sequence. - block_size: The block size of the sequence. Should be the same as the - block size used by the block manager and cache engine. - eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. - lora_request: LoRA request. 
- """ - - def __init__( - self, - seq_id: int, - inputs: SingletonInputs, - block_size: int, - eos_token_id: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - ) -> None: - self.seq_id = seq_id - self.inputs = inputs - self.block_size = block_size - self.eos_token_id = eos_token_id - self.lora_request = lora_request - - self.data = SequenceData.from_seqs( - self.prompt_token_ids, - prompt_embeds=self.inputs["prompt_embeds"] - if self.inputs["type"] == "embeds" else None) - self.output_logprobs: SampleLogprobs = [] - self.output_text = "" - - self.status = SequenceStatus.WAITING - self.stop_reason: Union[int, str, None] = None - - # These are used to keep track of delta outputs - self._last_output_token_ids_offset: int = 0 - self._last_output_text_offset: int = 0 - - # Used for incremental detokenization - self.prefix_offset = 0 - self.read_offset = 0 - # Input + output tokens - self.tokens: Optional[list[str]] = None - - @property - def n_blocks(self) -> int: - return (self.get_len() + self.block_size - 1) // self.block_size - - @property - def prompt(self) -> Optional[str]: - if self.inputs["type"] == "embeds": - return None - return self.inputs.get("prompt") - - @property - def prompt_token_ids(self) -> list[int]: - if self.inputs["type"] == "embeds": - return [0] * len(self.inputs["prompt_embeds"]) - return self.inputs["prompt_token_ids"] - - @property - def multi_modal_data(self) -> MultiModalKwargs: - if self.inputs["type"] == "multimodal": - return self.inputs["mm_kwargs"].get_data() - - return MultiModalKwargs() - - @property - def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - if self.inputs["type"] == "multimodal": - return self.inputs["mm_placeholders"] - - return {} - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - def get_output_text_to_return(self, buffer_length: int, - delta: bool) -> str: - """If delta is True, only new text since the last call to - this method is returned""" - - # We return the full output text if the sequence is finished. - truncate = buffer_length and not self.is_finished() - if not delta: - return self.output_text[:-buffer_length] if truncate else ( - self.output_text) - length = len(self.output_text) - if truncate: - length -= buffer_length - last_offset = self._last_output_text_offset - if last_offset < length: - self._last_output_text_offset = length - return self.output_text[last_offset:length] - return "" - - def get_output_token_ids_to_return( - self, delta: bool) -> Union[GenericSequence[int], int]: - """If delta is True, only new tokens since the last call to - this method are returned""" - if not delta: - return self.get_output_token_ids() - - output_len = self.get_output_len() - - # Get the number of new tokens - num_new_tokens = output_len - self._last_output_token_ids_offset - self._last_output_token_ids_offset = output_len - - # Return new tokens - if num_new_tokens == 1: - # Optimization for single decode token case - # (which is what we have most of the time) - return self.data._cached_all_token_ids[-1] - - if num_new_tokens == 0: - return [] - - return self.data._cached_all_token_ids[-num_new_tokens:] - - def hash_of_block(self, logical_idx: int) -> int: - # TODO This can produce incorrect hash when block size > prompt size - - # Compute the number of tokens in the sequence - # TODO: The current hashing function is O(L^2). We should optimize - # this in the future. 
- num_tokens = self.num_hashed_tokens_of_block(logical_idx) - hashed_tokens = self.data.get_prefix_token_ids(num_tokens) - return hash((hashed_tokens, self.lora_int_id)) - - def extra_hash(self) -> Optional[int]: - """ - This function computes an extra hash for a sequence, specifically - designed for prefix caching mode. The final sequence hash is determined - by applying token_ids from the sequence's blocks. - """ - if self.lora_int_id == 0: - return None - - # NOTE: If there are additional factors influencing the block aside from - # token_ids, include them as input parameters to the hash. - return hash(self.lora_int_id) - - def num_hashed_tokens_of_block(self, logical_idx: int): - return logical_idx * self.block_size + self.block_size - - def reset_state_for_recompute(self): - """Reset the sequence states for recomputation.""" - self.data.reset_state_for_recompute() - - def append_token_id(self, - token_id: int, - logprobs: dict[int, Logprob], - token_embed: Optional[torch.Tensor] = None) -> None: - assert token_id in logprobs - self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id].logprob, - token_embed) - - def get_len(self) -> int: - return self.data.get_len() - - def get_prompt_len(self) -> int: - return self.data.get_prompt_len() - - def get_output_len(self) -> int: - return self.data.get_output_len() - - def get_token_ids(self) -> list[int]: - return self.data.get_token_ids() - - def get_prompt_token_ids(self) -> tuple[int, ...]: - return self.data.get_prompt_token_ids() - - def get_last_token_id(self) -> int: - return self.data.get_last_token_id() - - def get_output_token_ids(self) -> tuple[int, ...]: - return self.data.get_output_token_ids() - - def get_cumulative_logprob(self) -> float: - return self.data.cumulative_logprob - - def is_finished(self) -> bool: - return SequenceStatus.is_finished(self.status) - - def fork(self, new_seq_id: int) -> "Sequence": - new_seq = copy.deepcopy(self) - new_seq.seq_id = new_seq_id - return new_seq - - def get_num_new_tokens(self) -> int: - """Get the number of new tokens to be computed. - - Returns: - The new number of tokens to be computed. I.e., 1 for decode, or - the remaining prompt size for prefill. - """ - if self.data.stage == SequenceStage.DECODE: - return 1 - return self.data.get_num_uncomputed_tokens() - - def get_num_computed_tokens(self) -> int: - return self.data.get_num_computed_tokens() - - def is_prefill(self) -> bool: - return self.data.stage == SequenceStage.PREFILL - - def __repr__(self) -> str: - return (f"Sequence(seq_id={self.seq_id}, " - f"status={self.status.name}, " - f"num_blocks={self.n_blocks})") - - -class SequenceGroupState(msgspec.Struct, - omit_defaults=True): # type: ignore[call-arg] - """Mutable state tied to a specific sequence group""" - - # for multi-step decoding - num_steps: int = 1 - current_step: int = 0 - - @property - def remaining_steps(self) -> int: - return self.num_steps - self.current_step - - -class SequenceGroup: - """A group of sequences that are generated from the same prompt. - - Args: - request_id: The ID of the request. - seqs: The list of sequences. - sampling_params: The sampling parameters used to generate the outputs. - arrival_time: The arrival time of the request. - lora_request: LoRA request. - pooling_params: The parameters used to generate the pooler - for a pooling model. - pooled_data: The extracted hidden states from a pooling model. - encoder_seq: Optional, the single encoder sequence. 
Should be None - unless you are working with an encoder/decoder model. - trace_headers: OpenTelemetry trace headers. - priority: User-defined priority of the request. - draft_size: The number of speculative tokens plus one from the target - model; equal to max number of tokens a step can generate - for single-draft speculative decoding but larger than - that for multi-draft SD (currently not supported). - """ - - def __init__(self, - request_id: str, - seqs: list[Sequence], - arrival_time: float, - sampling_params: Optional[SamplingParams] = None, - lora_request: Optional[LoRARequest] = None, - pooling_params: Optional[PoolingParams] = None, - pooled_data: Optional[torch.Tensor] = None, - encoder_seq: Optional[Sequence] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - draft_size: int = 1) -> None: - self.request_id = request_id - self.seqs = seqs - self.first_seq = seqs[0] - self.arrival_time = arrival_time - self.is_single_seq = len(seqs) == 1 - self.seqs_dict = {seq.seq_id: seq for seq in seqs} - - self.sampling_params = sampling_params - self.metrics = RequestMetrics(arrival_time=arrival_time, - last_token_time=arrival_time, - first_scheduled_time=None, - first_token_time=None, - time_in_queue=None) - self.last_token_latency = 0.0 - self.lora_request = lora_request - self.prompt_logprobs: Optional[PromptLogprobs] = None - self.state = SequenceGroupState() - self.pooling_params = pooling_params - self.pooled_data = pooled_data - self.encoder_seq = encoder_seq - self.trace_headers = trace_headers - self.priority = priority - - self.cached_request_output = None - - @property - def prompt(self) -> Optional[str]: - return self.first_seq.prompt - - @property - def prompt_token_ids(self) -> list[int]: - return self.first_seq.prompt_token_ids - - @property - def encoder_prompt(self) -> Optional[str]: - # There are either 0 or 1 encoder sequences - # If one is present, its prompt is distinct - # from the decoder's. - return (self.encoder_seq.prompt - if self.encoder_seq is not None else None) - - @property - def encoder_prompt_token_ids(self) -> Optional[list[int]]: - # There are either 0 or 1 encoder sequences - # If one is present, its prompt token ids are - # distinct from the decoder's. - return (self.encoder_seq.prompt_token_ids - if self.encoder_seq is not None else None) - - @property - def multi_modal_data(self) -> MultiModalKwargs: - if self.first_seq.multi_modal_data: - return self.first_seq.multi_modal_data - elif self.encoder_seq is not None: - return self.encoder_seq.multi_modal_data - return MultiModalKwargs() - - @property - def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - if self.first_seq.multi_modal_data: - return self.first_seq.multi_modal_placeholders - elif self.encoder_seq is not None: - return self.encoder_seq.multi_modal_placeholders - return {} - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - def set_last_token_time(self, now: float) -> None: - """Sets the last token time for Request level timings.""" - # If still in prefill phase, assertion fails. 
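# [Editor's sketch] The per-token latency bookkeeping implemented below, in
# miniature: TTFT is stamped once at the first output token, and every later
# token measures its latency against the previous last_token_time (TPOT).
_arrival, _first_tok, _second_tok = 10.0, 10.8, 10.95
_last_token_time = _arrival
_ttft = _first_tok - _arrival  # time to first token
_last_token_time = _first_tok
_tpot = _second_tok - _last_token_time  # inter-token latency
assert round(_ttft, 2) == 0.8 and round(_tpot, 2) == 0.15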
- assert not self.is_prefill(), ( - "seq_group.set_last_token_time() should not be called " - "if the seq_group is in prefill phase.") - self.last_token_latency = now - self.metrics.last_token_time - self.metrics.last_token_time = now - - def get_last_token_latency(self) -> float: - """Returns the latency of the last token.""" - assert not self.is_prefill(), ( - "seq_group.get_last_token_latency() should not be called " - "if the seq_group is in prefill phase.") - return self.last_token_latency - - def maybe_set_first_token_time(self, time: float) -> None: - """Sets the first token time for Request level timings.""" - # Note: in a case where a sequence_group is swapped and - recomputed, the time between iterations is counted - in TPOT, rather than recalculating TTFT (since, from the - POV of the user, there is simply a long generation delay). - if (self.metrics.first_token_time is None - and self.first_seq.get_output_len() == 1): - self.metrics.first_token_time = time - - def maybe_set_first_scheduled_time(self, time: float) -> None: - """Sets the first scheduled time and time in queue for Request - level timings.""" - if self.metrics.first_scheduled_time is None: - self.metrics.first_scheduled_time = time - self.metrics.time_in_queue = time - self.metrics.arrival_time - - def set_finished_time(self, time: Optional[float]) -> None: - """Sets the finished time for Request level timings.""" - self.metrics.finished_time = time - - def get_max_num_running_seqs(self) -> int: - """The maximum number of sequences running in parallel in the remaining - lifetime of the request.""" - if self.is_single_seq: - return 0 if self.first_seq.is_finished() else 1 - return self.num_seqs() - self.num_finished_seqs() - - def get_seqs( - self, - status: Optional[SequenceStatus] = None, - ) -> list[Sequence]: - if status is None: - return self.seqs - - if self.is_single_seq: - return self.seqs if self.first_seq.status == status else [] - - return [seq for seq in self.seqs if seq.status == status] - - def is_encoder_decoder(self) -> bool: - return self.encoder_seq is not None - - def get_encoder_seq(self) -> Optional[Sequence]: - return self.encoder_seq - - def get_finished_seqs(self) -> list[Sequence]: - if self.is_single_seq: - return self.seqs if self.first_seq.is_finished() else [] - - return [seq for seq in self.seqs if seq.is_finished()] - - def update_num_computed_tokens(self, num_new_computed_tokens: int): - """Update number of tokens computed so far.""" - for seq in self.seqs: - if not seq.is_finished(): - seq.data.update_num_computed_tokens(num_new_computed_tokens) - - def get_num_uncomputed_tokens(self) -> int: - num_uncomputed_tokens = 0 - for seq in self.seqs: - if not seq.is_finished(): - num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() - return num_uncomputed_tokens - - def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: - # Optimization. We don't need to call get_seqs if we don't need to - filter by states. 
- if status is None: - return len(self.seqs) - - if self.is_single_seq: - return 1 if self.seqs[0].status == status else 0 - - return len(self.get_seqs(status)) - - def num_finished_seqs(self) -> int: - if self.is_single_seq: - return 1 if self.seqs[0].is_finished() else 0 - return len(self.get_finished_seqs()) - - def is_finished(self) -> bool: - if self.is_single_seq: - return self.first_seq.is_finished() - return all(seq.is_finished() for seq in self.seqs) - - def is_prefill(self) -> bool: - return self.first_seq.is_prefill() - - def __repr__(self) -> str: - return (f"SequenceGroup(request_id={self.request_id}, " - f"sampling_params={self.sampling_params}, " - f"num_seqs={len(self.seqs)})") - - def uses_prompt_embeds(self) -> bool: - """Returns True if the sequence group uses input embeds.""" - return any(seq.data.prompt_embeds is not None for seq in self.seqs) - - -class SequenceGroupMetadataDelta( - msgspec.Struct, - tag=True, # type: ignore[call-arg] - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Delta of SequenceGroupMetadata. - - After sending the first SequenceGroupMetadata, vLLM scheduler - only sends delta to reduce the data payload size. - """ - seq_data_delta: dict[int, SequenceDataDelta] - request_id: str - block_tables: dict[int, list[int]] - is_prompt: bool - do_sample: bool = True - token_chunk_size: Optional[int] = None - computed_block_nums: Optional[list[int]] = None - state: Optional[SequenceGroupState] = msgspec.field( - default_factory=lambda: SequenceGroupState()) - - -class SequenceGroupMetadata( - msgspec.Struct, - tag=True, # type: ignore[call-arg] - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Metadata for a sequence group. Used to create `AttentionMetadata`. - - Attributes: - request_id: The ID of the request. - is_prompt: Whether the request is at prompt stage. - seq_data: The sequence data. (Seq id -> sequence data) - sampling_params: The sampling parameters used to generate the outputs. - block_tables: The block tables. (Seq id -> list of physical block - numbers) - do_sample: True if sampling is required. Sampling is not required when - e.g., prefill is chunked, and the current iteration only computes - query tokens for prefill, we don't need sampling. - pooling_params: Pooling parameters. - lora_request: LoRA request. - computed_block_nums: The block numbers that are already computed, - used in prefix caching. - state: Internal state tied to this sequence group. - token_type_ids: Token type IDs. - multi_modal_data: Multi modal data. - multi_modal_placeholders: Multi modal placeholders. - encoder_seq_data: Optional sequence data for encoder prompt - (SequenceGroup.encoder_seq). Should be None - unless you are working with an encoder/decoder - model. - cross_block_table: Optional cross-attention block table associated - with the encoder prompt - (SequenceGroup.encoder_seq). Should be None - unless you are working with an encoder/decoder - model. 
- """ - - request_id: str - is_prompt: bool - seq_data: dict[int, SequenceData] - sampling_params: Optional[SamplingParams] - block_tables: dict[int, list[int]] - do_sample: bool = True - pooling_params: Optional[PoolingParams] = None - lora_request: Optional[LoRARequest] = None - computed_block_nums: Optional[list[int]] = None - state: Optional[SequenceGroupState] = msgspec.field( - default_factory=lambda: SequenceGroupState()) - multi_modal_data: Optional[MultiModalKwargs] = None - multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None - encoder_seq_data: Optional[SequenceData] = None - cross_block_table: Optional[list[int]] = None - token_chunk_size: Optional[int] = None - - ### Stateful fields that are lazily defined. ### - # The number of speculative tokens adopted in this request. - # None means specuative decoding is not used. - # Zero means speculative decoding is disabled for some reasons. - # TODO: We should maintain this states out of the sequence group. - num_speculative_tokens: Optional[int] = None - - def __post_init__(self): - if self.seq_data is not None and self.token_chunk_size is None: - if self.is_prompt: - self.token_chunk_size = next(iter( - self.seq_data.values())).get_len() - else: - self.token_chunk_size = 1 - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - # Multi-Step Chunked-Prefill property - @property - def is_single_step_prompt(self) -> bool: - # do_sample is true, only when the token_chunk_size matches the - # num_uncomputed_tokens of the sequence. This indicates that - # the prompt will finish processing in a single `execute_model` - # step. - return self.is_prompt and self.do_sample - - def get_first_seq_id(self) -> int: - # This is an efficient way of fetching the seq_id when - # we know this SequenceGroup has only one sequence. - return next(iter(self.seq_data)) - - def apply_delta(self, - sequence_group_metadata_delta: SequenceGroupMetadataDelta): - for id, delta in sequence_group_metadata_delta.seq_data_delta.items(): - self.seq_data[id].apply_delta(delta) - assert self.request_id == sequence_group_metadata_delta.request_id - self.block_tables = sequence_group_metadata_delta.block_tables - self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size - self.do_sample = sequence_group_metadata_delta.do_sample - self.is_prompt = sequence_group_metadata_delta.is_prompt - - def finish_step(self) -> None: - assert self.state is not None - assert self.state.current_step < self.state.num_steps, \ - f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa - self.state.current_step += 1 - - -class SequenceOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """The model output associated with a sequence. - - Attributes: - parent_seq_id: The ID of the parent sequence (for forking in beam - search). - output_token: The output token ID. - logprobs: The logprobs of the output token. - (Token id -> logP(x_i+1 | x_0, ..., x_i)) - output_embed: Optional output embedding tensor. 
- """ - parent_seq_id: int - output_token: int - logprobs: dict[int, Logprob] - output_embed: Optional[torch.Tensor] = None - - def __repr__(self) -> str: - output_embed_shape = \ - self.output_embed.shape if self.output_embed is not None else None - return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " - f"output_token={self.output_token}, " - f"output_embed.shape={output_embed_shape}, " - f"logprobs={self.logprobs})") - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SequenceOutput): - raise NotImplementedError() - equal = (self.parent_seq_id == other.parent_seq_id - and self.output_token == other.output_token) - log_probs_equal = other.logprobs == self.logprobs - return equal and log_probs_equal - - -class SequenceGroupOutput(ABC): - """The base class for model outputs associated with a sequence group.""" - - @abstractmethod - def __repr__(self) -> str: - pass - - @abstractmethod - def __eq__(self, other: object) -> bool: - pass - - -class CompletionSequenceGroupOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """The model output associated with a completion sequence group.""" - __metaclass__ = SequenceGroupOutput - samples: list[SequenceOutput] - # Prompt logprob for each prompt query token. - prompt_logprobs: Optional[PromptLogprobs] - step_index: Optional[int] = 0 - - def __repr__(self) -> str: - return (f"CompletionSequenceGroupOutput(samples={self.samples}, " - f"prompt_logprobs={self.prompt_logprobs})") - - def __eq__(self, other: object) -> bool: - if not isinstance(other, CompletionSequenceGroupOutput): - raise NotImplementedError() - return (self.samples == other.samples - and self.prompt_logprobs == other.prompt_logprobs) - - class PoolingSequenceGroupOutput( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] array_like=True, # type: ignore[call-arg] ): """The model output associated with a pooling sequence group.""" - __metaclass__ = SequenceGroupOutput # Annotated as Any to be compatible with msgspec # The actual type is in SequenceGroup.pooled_data data: Any @@ -1161,305 +143,9 @@ class PoolerOutput( self.__class__) and self.outputs == other.outputs -def get_all_seq_ids( - seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]: - """Given a list of SequenceGroupMetadata, create a list of all - sequence ids. - """ - return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data] - - -def get_all_seq_ids_and_request_ids( - seq_group_metadata_list: list[SequenceGroupMetadata] -) -> tuple[list[int], dict[str, set[int]]]: - """Given a list of SequenceGroupMetadata, create a list of all - sequence ids. - """ - seq_ids: list[int] = [] - request_id_seq_ids_mapping: defaultdict[str, set[int]] = defaultdict(set) - for sg in seq_group_metadata_list: - for seq_id in sg.seq_data: - seq_ids.append(seq_id) - request_id_seq_ids_mapping[sg.request_id].add(seq_id) - return seq_ids, request_id_seq_ids_mapping - - -class HiddenStates(msgspec.Struct, array_like=True, - omit_defaults=True): # type: ignore[call-arg] - """Hidden states corresponding to in-progress sequences. - Used in speculative decoding to pass hidden states from - the target model to the proposer model. - - seq_ids are the sequence ids of each entry of the batch - dimension of the hidden_states tensor""" - # Scorer hidden states. For prefill step, it is used for hidden states of - # all tokens, whereas for decode step, it is used for last accepted tokens. 
- hidden_states: torch.Tensor - # The sequence group metadata list. Only needed for decode step. - seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None - # Scorer hidden states of the 2nd last token proposed by the proposer ( - # irrespective of whether it was accepted or not). Only used for cases when - # last proposed token is accepted (i.e., in case of bonus tokens). For the - # case of no bonus tokens, these are ignored. - second_last_token_hidden_states: Optional[torch.Tensor] = None - - _seq_ids: list[int] = msgspec.field(default_factory=list) - - def __post_init__(self): - if self.seq_group_metadata_list is not None: - assert len(self.seq_group_metadata_list) == len(self.hidden_states) - self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) - - @property - def seq_ids(self) -> list[int]: - return self._seq_ids - - def update(self, - hidden_states: torch.Tensor, - seq_group_metadata_list: list[SequenceGroupMetadata], - second_last_token_hidden_states: Optional[torch.Tensor] = None): - """Update hidden states from target model invocation. Only used for - decode steps""" - assert len(seq_group_metadata_list) == len(hidden_states) - self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) - self.hidden_states = torch.cat([self.hidden_states, hidden_states]) - - if self.second_last_token_hidden_states is not None: - # Adding dummy hidden_states to this to maintain same shape - self.second_last_token_hidden_states = torch.cat([ - self.second_last_token_hidden_states, - torch.zeros_like(hidden_states) - if second_last_token_hidden_states is None else - second_last_token_hidden_states - ]) - - def prune(self, - seq_group_metadata_list: list[SequenceGroupMetadata]) -> None: - """Prune to provided list of sequence ids. Only used for decode steps. - """ - # Currently this prunes all seq_ids not present in - # seq_group_metadata_list which might cause problems where a sequence - # may be "paused" then "resumed" later. This should only prune sequences - # which are confirmed to be aborted. - seq_ids = get_all_seq_ids(seq_group_metadata_list) - # Only keep sequence IDs that exist in self._seq_ids - seq_ids = [seq_id for seq_id in seq_ids if seq_id in self._seq_ids] - if seq_ids != self._seq_ids: - # Batch contents changed - prune removed sequences. - index = [self._seq_ids.index(seq_id) for seq_id in seq_ids] - self.hidden_states = self.hidden_states[index] - if self.second_last_token_hidden_states is not None: - self.second_last_token_hidden_states = self\ - .second_last_token_hidden_states[index] - self._seq_ids = seq_ids - - def expand_with_bonus_tokens( - self, seq_with_bonus_token_in_last_step: set) -> None: - """Expand hidden states for sequences with bonus tokens. This is in - alignment with `MultiStepWorker._expand_execute_model_request`.""" - if self.second_last_token_hidden_states is None \ - or not seq_with_bonus_token_in_last_step: - return - - index = [] - for seq_id in self._seq_ids: - i = self._seq_ids.index(seq_id) - if seq_id in seq_with_bonus_token_in_last_step: - index.append(i + len(self._seq_ids)) - index.append(i) - - self.hidden_states = torch.cat( - [self.hidden_states, self.second_last_token_hidden_states])[index] - - class ExecuteModelRequest( msgspec.Struct, array_like=True, # type: ignore[call-arg] omit_defaults=True): # type: ignore[call-arg] - """The model execution request, containing CPU metadata only. The LLM - engine should create an instance of this class for each request batch.""" - # The sequence group metadata list. 
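# [Editor's sketch] The row-gather pattern HiddenStates.prune() above relies on:
# surviving sequences are selected by integer index so the hidden_states rows
# stay aligned with the pruned seq_id list.
import torch

_seq_ids = [7, 8, 9]
_hidden = torch.arange(6, dtype=torch.float32).reshape(3, 2)  # one row per seq
_keep = [9, 7]  # sequences still present in the batch
_index = [_seq_ids.index(s) for s in _keep]
_hidden = _hidden[_index]
assert _hidden.tolist() == [[4.0, 5.0], [0.0, 1.0]]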
- seq_group_metadata_list: list[Union[SequenceGroupMetadata, - SequenceGroupMetadataDelta]] - # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: list[tuple[int, - int]] = msgspec.field(default_factory=list) - # Blocks to swap out. List of GPU -> CPU block number. - blocks_to_swap_out: list[tuple[int, - int]] = msgspec.field(default_factory=list) - # Blocks to copy. Source to dest block. - blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list) - # Virtual engine ID for pipeline parallel. - virtual_engine: int = 0 - # The number of slots for lookahead decoding. - num_lookahead_slots: int = 0 - # The number of requests in the running queue. - running_queue_size: int = 0 - # Optional hidden states from prior step. - previous_hidden_states: Optional[HiddenStates] = None - # The number of forward steps to run. - num_steps: int = 1 - # Finished request ids since last step. - finished_requests_ids: list[str] = msgspec.field(default_factory=list) - # The last sampled token ids for multi step decoding. - last_sampled_token_ids: Optional[torch.Tensor] = None - # Async callback - async_callback: Optional[Callable] = None - - @property - def is_last_step(self) -> bool: - # TODO(will) make this be able to handle batches with variable number of - # steps - assert len(self.seq_group_metadata_list) > 0 - first_seq_group = self.seq_group_metadata_list[0] - assert first_seq_group.state is not None - return first_seq_group.state.remaining_steps == 1 - - @property - def current_step(self) -> int: - # TODO(will) make this be able to handle batches with variable number of - # steps - assert len(self.seq_group_metadata_list) > 0 - state = self.seq_group_metadata_list[0].state - assert state is not None - return state.current_step - - def clone( - self, seq_group_metadata_list: list[Union[SequenceGroupMetadata, - SequenceGroupMetadataDelta]] - ) -> "ExecuteModelRequest": - """Clone the request with a new sequence group metadata list.""" - return ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=self.blocks_to_swap_in.copy(), - blocks_to_swap_out=self.blocks_to_swap_out.copy(), - blocks_to_copy=self.blocks_to_copy.copy(), - virtual_engine=self.virtual_engine, - num_lookahead_slots=self.num_lookahead_slots, - running_queue_size=self.running_queue_size, - previous_hidden_states=self.previous_hidden_states, - num_steps=self.num_steps, - finished_requests_ids=self.finished_requests_ids, - last_sampled_token_ids=self.last_sampled_token_ids.clone() - if self.last_sampled_token_ids is not None else None, - async_callback=self.async_callback) - - -@dataclass -class SequenceGroupBase: - group_id: str # the original request id before splitting - - assembled_seq_group: Optional[SequenceGroup] = None - - # seq id to a unique index inside this group - seq_id_to_index: dict[str, int] = field(default_factory=dict) - - # seq ids to be finished - to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict) - - # seq id to finished sequences - finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict) - - streaming: bool = False - - output_produced: bool = False - - @staticmethod - def add_request(request_id: str, engine, params, *args, **kwargs): - """When we are ready to add a request with request_id and params - into the engine, we can split the request into multiple requests. - """ - raise NotImplementedError - - def finish_seq(self, seq: SequenceGroup): - """The sequence `seq` finishes, we should record the information. 
- """ - del self.to_be_finished[seq.request_id] - self.finished_reqs[seq.request_id] = seq - - def maybe_assemble_group( - self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: - """Assemble the sequence group, for producing the final - output, or adding request in the engine again. - """ - raise NotImplementedError - - -class ParallelSampleSequenceGroup(SequenceGroupBase): - - @staticmethod - def add_request(request_id: str, engine, params, **kwargs): - original_params = params - group = ParallelSampleSequenceGroup(request_id) - seqs = [] - for i in range(original_params.n): - request_id_i = f"{request_id}_parallel_sample_{i}" - group.seq_id_to_index[request_id_i] = i - params = original_params.clone() - params.n = 1 - if params.seed is not None: - params.seed += i - seq_group = engine._add_processed_request( - request_id_i, - params=params, - **kwargs, - ) # type: ignore - assert seq_group is not None - engine.seq_id_to_seq_group[request_id_i] = group - group.to_be_finished[request_id_i] = seq_group - seqs.append(seq_group.seqs[0]) - - # for parallel sampling, the `assembled_seq_group` is always - # available, since we have all the sequences ready, and they - # will not change. - group.assembled_seq_group = SequenceGroup( - request_id=request_id, - seqs=seqs, - arrival_time=seq_group.arrival_time, - sampling_params=original_params, - lora_request=seq_group.lora_request, - pooling_params=seq_group.pooling_params, - pooled_data=seq_group.pooled_data, - encoder_seq=seq_group.encoder_seq, - trace_headers=seq_group.trace_headers, - priority=seq_group.priority, - ) - - group.streaming = params.output_kind == RequestOutputKind.DELTA - group.output_produced = False - - def maybe_assemble_group( - self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: - - # in the streaming mode, we will return the assembled sequence - # for the first remaining sequence, and then return None for the - # rest of sequences - if self.streaming: - first_remaining_id = next(iter(self.to_be_finished)) - if seq_group.request_id == first_remaining_id: - return self.assembled_seq_group - return None - - # in the non-streaming mode, we will return the assembled sequence - # when the last sequences finishes, and then return None for the - # rest of the time - if (len(self.to_be_finished) == 1 - and seq_group.request_id in self.to_be_finished - and seq_group.is_finished()): - assert self.assembled_seq_group is not None - params = self.assembled_seq_group.sampling_params - assert isinstance(params, SamplingParams) - if not self.output_produced: - self.output_produced = True - if params._real_n is not None: - # Get the top-n sequences. - n = params._real_n or params.n - seqs = self.assembled_seq_group.seqs - sorting_key = lambda seq: seq.get_cumulative_logprob() - sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) - top_n_seqs = sorted_seqs[:n] - self.assembled_seq_group.seqs = top_n_seqs - return self.assembled_seq_group - if self.output_produced: - return None - return None + # Placeholder. Remove. 
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index cafc43f6b7673..9eed466788661 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -463,15 +463,29 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
     return config
 
 
-def maybe_override_with_speculators_target_model(
+def maybe_override_with_speculators(
     model: str,
     tokenizer: str,
     trust_remote_code: bool,
     revision: Optional[str] = None,
+    vllm_speculative_config: Optional[dict[str, Any]] = None,
     **kwargs,
-) -> tuple[str, str]:
+) -> tuple[str, str, Optional[dict[str, Any]]]:
     """
-    If running a speculators config, override running model with target model
+    Resolve model configuration when speculators are detected.
+
+    Checks if the provided model is a speculators model and if so, extracts
+    the target model configuration and builds the speculative config.
+
+    Args:
+        model: Model name or path
+        tokenizer: Tokenizer name or path
+        trust_remote_code: Whether to trust remote code
+        revision: Model revision
+        vllm_speculative_config: Existing vLLM speculative config
+
+    Returns:
+        Tuple of (resolved_model, resolved_tokenizer, speculative_config)
     """
     is_gguf = check_gguf_file(model)
     if is_gguf:
@@ -487,11 +501,27 @@ def maybe_override_with_speculators_target_model(
         token=_get_hf_token(),
         **kwargs,
     )
-    spec_config = config_dict.get("speculators_config", None)
-    # Return the target model
-    if spec_config is not None:
-        model = tokenizer = spec_config["verifier"]["name_or_path"]
-    return model, tokenizer
+    speculators_config = config_dict.get("speculators_config")
+
+    if speculators_config is None:
+        # No speculators config found, return original values
+        return model, tokenizer, vllm_speculative_config
+
+    # Speculators format detected - process overrides
+    from vllm.transformers_utils.configs.speculators.base import (
+        SpeculatorsConfig)
+
+    vllm_speculative_config = SpeculatorsConfig.extract_vllm_speculative_config(
+        config_dict=config_dict)
+
+    # Set the draft model to the speculators model
+    vllm_speculative_config["model"] = model
+
+    # Override model and tokenizer with the verifier model from config
+    verifier_model = speculators_config["verifier"]["name_or_path"]
+    model = tokenizer = verifier_model
+
+    return model, tokenizer, vllm_speculative_config
 
 
 def get_config(
@@ -524,10 +554,10 @@ def get_config(
             else:
                 raise ValueError(
                     "Could not detect config format for no config file found. "
-                    "With config_format 'auto', ensure your model has either"
-                    "config.json (HF format) or params.json (Mistral format)."
-                    "Otherwise please specify your_custom_config_format"
-                    "in engine args for customized config parser")
+                    "With config_format 'auto', ensure your model has either "
+                    "config.json (HF format) or params.json (Mistral format). "
+                    "Otherwise please specify your_custom_config_format "
+                    "in engine args for customized config parser.")
 
     except Exception as e:
         error_message = (
diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py
index d5ca2c7b4751a..3f50638f16b53 100644
--- a/vllm/transformers_utils/configs/jais.py
+++ b/vllm/transformers_utils/configs/jais.py
@@ -74,8 +74,7 @@ class JAISConfig(PretrainedConfig):
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values
             attentions (not used by all models).
-        scale_attn_by_inverse_layer_idx
-            (`bool`, *optional*, defaults to `False`):
+        scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
             Whether to additionally scale attention weights by
             `1 / layer_idx + 1`.
         reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
diff --git a/vllm/transformers_utils/configs/speculators/base.py b/vllm/transformers_utils/configs/speculators/base.py
index d7c16e180c709..53128b4eecb03 100644
--- a/vllm/transformers_utils/configs/speculators/base.py
+++ b/vllm/transformers_utils/configs/speculators/base.py
@@ -24,6 +24,12 @@ class SpeculatorsConfig(PretrainedConfig):
         config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
                                              **kwargs)
 
+        vllm_config = cls.extract_vllm_speculative_config(config_dict)
+        return cls(**vllm_config)
+
+    @classmethod
+    def extract_vllm_speculative_config(
+            cls, config_dict: dict[str, Any]) -> dict[str, Any]:
         speculators_model_type = config_dict.get("speculators_model_type")
         if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
             raise ValueError(
@@ -34,11 +40,12 @@ class SpeculatorsConfig(PretrainedConfig):
         # TODO: @dsikka - use speculators pydantic model to validate
         cls.validate_speculators_config(config_dict=config_dict)
         # Convert from speculators config -> format that can be ingested by vLLM
-        vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
+        vllm_config = cls.build_vllm_speculative_config(
+            config_dict=config_dict)
         # Apply anything specific to the supported algorithm
         algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
         algo_updater(config_dict=config_dict, vllm_config=vllm_config)
-        return cls(**vllm_config)
+        return vllm_config
 
     @classmethod
     def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
@@ -60,32 +67,45 @@ class SpeculatorsConfig(PretrainedConfig):
                 "'transformer_layer_config' must be a dictionary if provided")
 
     @classmethod
-    def convert_speculators_to_vllm(
+    def build_vllm_speculative_config(
             cls, config_dict: dict[str, Any]) -> dict[str, Any]:
         """
-        Convert speculators config format to vLLM format.
-
-        This method handles the translation of field names and structure
-        between speculators and vLLM formats.
-
-        Returns:
-            Dictionary with vLLM-compatible configuration
-        """
-        # Currently we only support one proposal method
-        spec_config = config_dict["speculators_config"]
-        first_method = spec_config.get("proposal_methods")[0]
-        num_lookahead_tokens = first_method.get("speculative_tokens")
+        Build vLLM-compatible speculative configuration from speculators format.
 
-        if num_lookahead_tokens is None:
+        This method extracts and transforms speculative configuration from the
+        speculators format into the structure expected by vLLM.
+
+        Args:
+            config_dict: Configuration dictionary in speculators format
+
+        Returns:
+            Dictionary with vLLM-compatible speculative configuration
+        """
+        # Extract speculators configuration
+        spec_config = config_dict["speculators_config"]
+
+        # Currently we only support one proposal method
+        proposal_methods = spec_config.get("proposal_methods")
+        if not proposal_methods:
+            raise ValueError("No proposal methods found in speculators config")
+
+        first_method = proposal_methods[0]
+        num_speculative_tokens = first_method.get("speculative_tokens")
+
+        if num_speculative_tokens is None:
             raise ValueError(
                 "Missing 'speculative_tokens' in proposal method. 
" f"Got: {first_method}") - # Build base vLLM config + # Build base vLLM speculative configuration vllm_config = { "method": config_dict.get("speculators_model_type"), - "num_lookahead_tokens": num_lookahead_tokens, + "num_speculative_tokens": num_speculative_tokens, "target_model": spec_config.get("verifier")["name_or_path"] } - vllm_config.update(config_dict["transformer_layer_config"]) + + # Merge transformer layer configuration if present + transformer_config = config_dict.get("transformer_layer_config", {}) + vllm_config.update(transformer_config) + return vllm_config diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py deleted file mode 100644 index e2d2846a28073..0000000000000 --- a/vllm/transformers_utils/detokenizer.py +++ /dev/null @@ -1,162 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -from vllm.logprobs import Logprob -from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence, - SequenceGroup) - -from .detokenizer_utils import (convert_prompt_ids_to_tokens, - detokenize_incrementally) -from .tokenizer import AnyTokenizer - - -class Detokenizer: - """Provides methods to decode the output of a model into text.""" - - def __init__(self, tokenizer: AnyTokenizer): - self.tokenizer = tokenizer - - def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, - prompt_logprobs: list[Optional[dict[ - int, Logprob]]], - position_offset: int) -> None: - """Decodes the logprobs for the prompt of a sequence group. - - Args: - seq_group: The sequence group to decode. - prompt_logprobs: The logprobs to decode. - position_offset: Offset of the first index of the logprobs - relative to the start of the sequence (for chunked prefill). - - Returns: - The prompt logprobs with the decoded tokens. - """ - prms = seq_group.sampling_params - assert prms is not None - - # We can pick any sequence for the prompt. - seq = seq_group.get_seqs()[0] - # Only prompt, without the generated token. - all_token_ids = seq.get_token_ids() - prompt_token_ids = all_token_ids[:-1] - prefix_offset = 0 - read_offset = 0 - next_iter_prefix_offset = 0 - next_iter_read_offset = 0 - next_iter_tokens: list[str] = [] - prev_tokens = None - - for token_position_in_logprob, prompt_logprobs_for_token in enumerate( - prompt_logprobs): - - # Absolute token position equals the index in the logprobs - # list plus the offset of the entire logprobs list relative - # to the start of the sequence. - token_position = token_position_in_logprob + position_offset - if not prompt_logprobs_for_token: - continue - for token_id, sample_logprob in prompt_logprobs_for_token.items(): - if (sample_logprob.decoded_token is None - and token_id != VLLM_INVALID_TOKEN_ID): - prompt_token_ids_with_token = ( - prompt_token_ids[:token_position] + [token_id]) - (new_tokens, new_text, new_prefix_offset, - new_read_offset) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=prompt_token_ids_with_token, - prev_tokens=prev_tokens, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms. - spaces_between_special_tokens, - ) - - sample_logprob.decoded_token = new_text - - # Use the offsets & prev tokens corresponding to - # real tokens to ensure detokenization is consistent - # actual with prompt. 
- if token_id == all_token_ids[token_position]: - next_iter_prefix_offset = new_prefix_offset - next_iter_read_offset = new_read_offset - next_iter_tokens = new_tokens - - # Advance to the next token position. - prefix_offset = next_iter_prefix_offset - read_offset = next_iter_read_offset - if prev_tokens is None: - prev_tokens = next_iter_tokens.copy() - else: - prev_tokens.extend(next_iter_tokens) - - def decode_sequence_inplace(self, seq: Sequence, - prms: SamplingParams) -> int: - """Decodes the new token for a sequence. In-place operation. - - Args: - seq: The sequence to decode. - prms: The sampling parameters used to generate the sequence. - - Returns: - The number of characters added to the output text. - """ - all_input_ids = seq.get_token_ids() - token_id_generated_this_iteration = all_input_ids[-1] - - # Convert prompt token IDs to tokens if necessary. - # Do it here so that we don't have to repeat this - # computation for each logprob. - if seq.tokens is None: - (seq.tokens, seq.prefix_offset, - seq.read_offset) = convert_prompt_ids_to_tokens( - tokenizer=self.tokenizer, - prompt_ids=all_input_ids[:-1], - skip_special_tokens=prms.skip_special_tokens, - ) - - (new_tokens, new_decoded_token_text, prefix_offset, - read_offset) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=all_input_ids, - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms.spaces_between_special_tokens, - ) - - # Decode logprobs - logprobs = seq.output_logprobs[-1] - if logprobs: - previous_tokens = all_input_ids[:-1] - for token_id, sample_logprob in logprobs.items(): - # If the token was generated this iteration, - # use the provided text. - if token_id == token_id_generated_this_iteration: - sample_logprob.decoded_token = new_decoded_token_text - continue - - if (sample_logprob.decoded_token is None - and token_id != VLLM_INVALID_TOKEN_ID): - all_input_ids_with_logprob = previous_tokens + [token_id] - (_, new_text, _, _) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=all_input_ids_with_logprob, - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms. - spaces_between_special_tokens, - ) - sample_logprob.decoded_token = new_text - - seq.tokens.extend(new_tokens) - seq.prefix_offset = prefix_offset - seq.read_offset = read_offset - seq.output_text += new_decoded_token_text - - return len(new_decoded_token_text) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index fd1c0af31269c..968bba664f0a9 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3216,7 +3216,7 @@ def cprofile_context(save_file: Optional[str] = None): Args: save_file: path to save the profile result. "1" or - None will result in printing to stdout. + None will result in printing to stdout. """ import cProfile @@ -3273,7 +3273,7 @@ def check_use_alibi(model_config: ModelConfig) -> bool: and getattr(cfg.attn_config, "alibi", False))))) -def sha256(input) -> bytes: +def sha256(input: Any) -> bytes: """Hash any picklable Python object using SHA-256. The input is serialized using pickle before hashing, which allows @@ -3290,7 +3290,7 @@ def sha256(input) -> bytes: return hashlib.sha256(input_bytes).digest() -def sha256_cbor(input) -> bytes: +def sha256_cbor(input: Any) -> bytes: """ Hash objects using CBOR serialization and SHA-256. 
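The sha256/sha256_cbor changes above only add type annotations; the documented behavior is unchanged: pickle the object, then hash the serialized bytes. A self-contained sketch of that pattern, equivalent in spirit to the annotated helper but not imported from vLLM:

import hashlib
import pickle
from typing import Any


def sha256(obj: Any) -> bytes:
    # Serialize with pickle, then hash the serialized bytes with SHA-256.
    data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    return hashlib.sha256(data).digest()


# Equal objects produce equal digests within a single interpreter:
assert sha256(("prefix", 1, 2)) == sha256(("prefix", 1, 2))

Pickle output is not guaranteed to be stable across Python versions or protocol choices, which is presumably why the module also keeps the CBOR-based sha256_cbor variant alongside it.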
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index 662d3984554ad..c3358bfa74e91 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -9,7 +9,7 @@ import torch
 import torch._dynamo.decorators
 import torch.nn.functional as F
 from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
-                                               _score_mod_signature,
+                                               _score_mod_signature, and_masks,
                                                create_block_mask,
                                                flex_attention)
 
@@ -292,6 +292,7 @@ class FlexAttentionMetadata:
     q_block_size: int = 16
     kv_block_size: int = 16
     transformed_score_mod: Optional[_score_mod_signature] = None
+    sliding_window: Optional[int] = None
 
     def _convert_physical_to_logical(
         self,
@@ -380,6 +381,53 @@ class FlexAttentionMetadata:
 
         return final_mask_mod
 
+    def get_sliding_window_mask_mod(self) -> _mask_mod_signature:
+        """Creates the sliding window mask_mod function for FlexAttention.
+
+        Note that the sliding window mask here is bidirectional; it has to
+        be combined with the bidirectional/causal base mask for the
+        encoder/decoder case.
+        """
+
+        if self.sliding_window is None:
+            raise ValueError(
+                "sliding_window must be set for sliding window attention")
+
+        def sliding_window_mask_mod(b: torch.Tensor, h: torch.Tensor,
+                                    q_idx: torch.Tensor,
+                                    kv_idx: torch.Tensor):
+            return torch.abs(q_idx - kv_idx) < self.sliding_window
+
+        def final_mask_mod(
+            b: torch.Tensor,
+            h: torch.Tensor,
+            q_idx: torch.Tensor,
+            physical_kv_idx: torch.Tensor,
+        ) -> torch.Tensor:
+            (is_valid, logical_q_idx,
+             logical_kv_idx) = self._convert_physical_to_logical(
+                 self.doc_ids, q_idx, physical_kv_idx)
+            return torch.where(
+                is_valid,
+                sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx),
+                False,
+            )
+
+        return final_mask_mod if self.causal else sliding_window_mask_mod
+
+    def get_mask_mod(self):
+        # Stage 1: initialize the base mask_mod
+        # (causal mask for decoder or bidirectional mask for encoder).
+        if self.causal:
+            mask_mod = self.get_causal_mask_mod()
+        else:
+            mask_mod = self.get_bidirectional_mask_mod()
+        # Stage 2: at forward time, AND in any extra mask_mod needed for
+        # special attention to create the combined mask_mod.
+        if self.sliding_window is not None:
+            # Add sliding window mask for sliding window attention
+            sliding_window_mask_mod = self.get_sliding_window_mask_mod()
+            mask_mod = and_masks(mask_mod, sliding_window_mask_mod)
+        return mask_mod
+
     def get_transformed_score_mod(self) -> Optional[_score_mod_signature]:
         """Creates the transformed score_mod function for FlexAttention.
@@ -472,12 +520,9 @@ class FlexAttentionMetadata: return BlockMask.from_kv_blocks(**block_mask_kwargs) def build_block_mask(self) -> BlockMask: - if self.causal: - mask_mod = self.get_causal_mask_mod() - kv_len = self.total_cache_tokens - else: - mask_mod = self.get_bidirectional_mask_mod() - kv_len = self.num_actual_tokens + mask_mod = self.get_mask_mod() + kv_len = (self.total_cache_tokens + if self.causal else self.num_actual_tokens) return create_block_mask_compiled( mask_mod, None, @@ -498,11 +543,7 @@ class FlexAttentionMetadata: self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc) self.num_blocks = self.total_cache_tokens // self.block_size - if self.causal: - self.mask_mod = self.get_causal_mask_mod() - else: - self.mask_mod = self.get_bidirectional_mask_mod() - + self.mask_mod = self.get_mask_mod() self.transformed_score_mod = self.get_transformed_score_mod() if self.direct_build and self.causal: @@ -607,7 +648,7 @@ class FlexAttentionMetadataBuilder( class FlexAttentionImpl(AttentionImpl): - sliding_window: Optional[tuple[int, int]] + sliding_window: Optional[int] alibi_slopes: Optional[torch.Tensor] logits_soft_cap: Optional[float] @@ -641,11 +682,9 @@ class FlexAttentionImpl(AttentionImpl): "FlexAttention does not support alibi slopes yet.") else: self.alibi_slopes = None - if sliding_window is not None: - raise NotImplementedError( - "FlexAttention does not support sliding window yet.") - else: - self.sliding_window = (-1, -1) + + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap if self.logits_soft_cap is not None: @@ -712,6 +751,21 @@ class FlexAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens + if attn_metadata.sliding_window != self.sliding_window: + attn_metadata.sliding_window = self.sliding_window + if attn_metadata.direct_build: + # TODO: Support skipping the computation of sliding window + # in direct block mask building code path. + logger.warning_once( + "Using direct block mask building with sliding window, " + "which is suboptimal now. Performance may be degraded.") + # update mask mod in attention metadata + attn_metadata.mask_mod = attn_metadata.get_mask_mod() + attn_metadata.block_mask = ( + attn_metadata._build_block_mask_direct()) + else: + attn_metadata.block_mask = attn_metadata.build_block_mask() + if not attn_metadata.causal: assert self.attn_type == AttentionType.ENCODER_ONLY diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 3ccd00121f8ed..47a41322c423c 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1229,8 +1229,8 @@ def get_kv_cache_configs(vllm_config: VllmConfig, Args: vllm_config: The global VllmConfig kv_cache_specs: List of dict[layer_name, KVCacheSpec] for each worker. - available_memory: Memory available for KV cache in bytes for each - worker. + available_memory: Memory available for KV cache in bytes for each + worker. Returns: The generated KVCacheConfigs for each worker. 
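The FlexAttention change above composes its block mask in two stages: a causal or bidirectional base mask_mod, then and_masks with a sliding-window mask_mod when a window is configured. A standalone sketch of that composition using the public FlexAttention helpers; the window size and sequence lengths below are arbitrary example values:

import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask

SLIDING_WINDOW = 128  # example value; vLLM reads this from the layer config


def causal_mask(b, h, q_idx, kv_idx):
    # Base decoder mask: a query attends only to positions at or before it.
    return q_idx >= kv_idx


def sliding_window_mask(b, h, q_idx, kv_idx):
    # Bidirectional window, as in the diff; AND-ed with the causal base
    # mask it reduces to the usual one-sided sliding window.
    return torch.abs(q_idx - kv_idx) < SLIDING_WINDOW


combined = and_masks(causal_mask, sliding_window_mask)
block_mask = create_block_mask(combined, B=None, H=None,
                               Q_LEN=1024, KV_LEN=1024, device="cpu")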
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 73165c7e4c0ad..757baecea9ce0 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -601,11 +601,7 @@ class AsyncLLM(EngineClient): async def is_tracing_enabled(self) -> bool: return self.observability_config.otlp_traces_endpoint is not None - async def do_log_stats( - self, - scheduler_outputs=None, - model_output=None, - ) -> None: + async def do_log_stats(self) -> None: if self.logger_manager: self.logger_manager.log() diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 8aa36d6a439c1..0f993a74c8103 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -9,7 +9,6 @@ from tokenizers import Tokenizer from tokenizers.decoders import DecodeStream from transformers import PreTrainedTokenizerFast -from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.transformers_utils.detokenizer_utils import ( AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) @@ -129,7 +128,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): # 2) Evaluate stop strings. stop_string = None if self.stop and len(self.output_token_ids) > self.min_tokens: - stop = StopChecker.check_stop_strings( + stop = check_stop_strings( output_text=self.output_text, new_char_count=len(self.output_text) - stop_check_offset, stop=self.stop, @@ -309,3 +308,42 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer): self.read_offset = read_offset return decoded_text + + +def check_stop_strings( + output_text: str, + new_char_count: int, + stop: list[str], + include_in_output: bool, +) -> Optional[tuple[str, int]]: + """Check if any stop strings are matched and truncate sequence + output text accordingly. + + Returns tuple (stop_string, offset) if matched or else None. + + Where stop_string is the matched stop string and offset is the + length to which output_text should be truncated, or -1 for no + truncation. + """ + if not new_char_count or not stop: + return None + + for stop_str in stop: + stop_string_len = len(stop_str) + # Avoid searching already-searched text. + stop_index = output_text.find(stop_str, + 1 - new_char_count - stop_string_len) + if stop_index == -1: + continue + + if include_in_output: + # Truncate to end of stop string. + stop_index += stop_string_len + if stop_index >= len(output_text): + # No truncation required. + return stop_str, -1 + + # Truncate the output text to either the beginning + # or end of the stop string. 
+ return stop_str, stop_index + return None diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 3aa373f12b609..2aa732f34bcc8 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import multiprocessing +import os import pickle import queue import signal @@ -19,6 +20,7 @@ from threading import Thread from typing import Any, Callable, Optional, Union, cast import cloudpickle +import torch import vllm.envs as envs from vllm.config import VllmConfig @@ -28,14 +30,12 @@ from vllm.distributed.device_communicators.shm_broadcast import (Handle, MessageQueue) from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, get_pp_group, get_tp_group) -from vllm.executor.multiproc_worker_utils import ( - set_multiprocessing_worker_envs) from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import worker_receiver_cache_from_config -from vllm.utils import (decorate_logs, get_distributed_init_method, - get_loopback_ip, get_mp_context, get_open_port, - set_process_title) +from vllm.utils import (_maybe_force_spawn, decorate_logs, + get_distributed_init_method, get_loopback_ip, + get_mp_context, get_open_port, set_process_title) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.executor.utils import get_and_update_mm_cache @@ -67,8 +67,8 @@ class MultiprocExecutor(Executor): f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" f"_parallel_size ({pp_parallel_size}). ") - # Set multiprocessing envs that are common to V0 and V1 - set_multiprocessing_worker_envs(self.parallel_config) + # Set multiprocessing envs + set_multiprocessing_worker_envs() # Multiprocessing-based executor does not support multi-node setting. # Since it only works for single node, we can use the loopback address @@ -698,3 +698,29 @@ class WorkerProc: process_name += f"_EP{ep_rank}" set_process_title(name=process_name) decorate_logs(process_name) + + +def set_multiprocessing_worker_envs(): + """ Set up environment variables that should be used when there are workers + in a multiprocessing environment. This should be called by the parent + process before worker processes are created""" + + _maybe_force_spawn() + + # Configure thread parallelism if OMP_NUM_THREADS isn't set + # + # Helps to avoid CPU contention. The default of spawning a thread per + # core combined with multiprocessing for each GPU can have a negative + # impact on performance. The contention is amplified when running in a + # container where CPU limits can cause throttling. + default_omp_num_threads = 1 + if "OMP_NUM_THREADS" not in os.environ and ( + current_parallelism := + torch.get_num_threads()) > default_omp_num_threads: + logger.warning( + "Reducing Torch parallelism from %d threads to %d to avoid " + "unnecessary CPU contention. 
Set OMP_NUM_THREADS in the " + "external environment to tune this value as needed.", + current_parallelism, default_omp_num_threads) + os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) + torch.set_num_threads(default_omp_num_threads) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 3d5e59addfcfa..ced5c7a970388 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -351,17 +351,17 @@ def generate_uniform_probs( without a seed. Args: - num_tokens : int + num_tokens: int Total number of tokens. - num_draft_tokens : List[List[int]] + num_draft_tokens: List[List[int]] Number of draft tokens per request. - generators : Optional[Dict[int, torch.Generator]] + generators: Optional[Dict[int, torch.Generator]] A dictionary mapping indices in the batch to `torch.Generator` objects. - device : torch.device + device: torch.device The device on which to allocate the tensor. Returns: - uniform_rand : torch.Tensor + uniform_rand: torch.Tensor A tensor of shape `(num_tokens, )` containing uniform random values in the range [0, 1). """ diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 2a178ddf48777..5dacf60886966 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -239,7 +239,7 @@ class EagleProposer: else: last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states, None) + logits = self.model.compute_logits(sample_hidden_states) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: @@ -367,8 +367,7 @@ class EagleProposer: else: last_hidden_states, hidden_states = ret_hidden_states hidden_states = hidden_states[:batch_size] - logits = self.model.compute_logits(last_hidden_states[:batch_size], - None) + logits = self.model.compute_logits(last_hidden_states[:batch_size]) draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) @@ -678,9 +677,7 @@ class EagleProposer: # Get the output logits for the draft tokens. logits = self.model.compute_logits( draft_last_hidden_states.reshape(batch_size * level_num_drafts, - -1), - None, - ) + -1)) # Sample a draft token for each child at the next tree level. num_children = self.child_drafts_per_level[level + 1] diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 3e90179e78d99..70b29c05c2a50 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -41,7 +41,7 @@ class MedusaProposer: ) -> list[list[int]]: # Generate blocks and compute logits blocks = self.model(target_hidden_states) - logits = self.model.compute_logits(blocks, None) + logits = self.model.compute_logits(blocks) # Get draft tokens and transpose the result # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 233df8f1b0e9b..b0cd0f4133079 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1479,7 +1479,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): Args: scheduler_output: The scheduler output containing scheduled encoder - inputs. + inputs. 
         Returns:
             A tuple of (mm_kwargs, req_ids_pos) where:
@@ -2240,7 +2240,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     return output
 
                 sample_hidden_states = hidden_states[logits_indices]
-                logits = self.model.compute_logits(sample_hidden_states, None)
+                logits = self.model.compute_logits(sample_hidden_states)
             else:
                 # Rare case.
                 assert not self.is_pooling_model
@@ -2258,8 +2258,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     logits = None
                 else:
                     sample_hidden_states = hidden_states[logits_indices]
-                    logits = self.model.compute_logits(sample_hidden_states,
-                                                       None)
+                    logits = self.model.compute_logits(sample_hidden_states)
 
             model_output_broadcast_data = {}
             if logits is not None:
@@ -2706,7 +2705,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_idx = self.input_batch.req_id_to_index[req_id]
             offset = self.query_start_loc.np[req_idx].item()
             prompt_hidden_states = hidden_states[offset:offset + num_logits]
-            logits = self.model.compute_logits(prompt_hidden_states, None)
+            logits = self.model.compute_logits(prompt_hidden_states)
 
             # Get the "target" tokens for each index. For prompt at index i,
             # the token at prompt index i+1 is the "sampled" token we want
@@ -3105,7 +3104,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # To avoid breaking the sampler, we use a random tensor here instead.
             hidden_states = torch.rand_like(hidden_states)
 
-        logits = self.model.compute_logits(hidden_states, None)
+        logits = self.model.compute_logits(hidden_states)
 
         num_reqs = logits.size(0)
         dummy_tensors = lambda v: torch.full(
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 48070c1e3e7cb..dd11b1dcbe94c 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -1692,7 +1692,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def compute_logits(self,
                        sample_hidden_states: torch.Tensor) -> torch.Tensor:
-        return self.model.compute_logits(sample_hidden_states, None)
+        return self.model.compute_logits(sample_hidden_states)
 
     # TODO: Under SPMD mode, sample_from_logits has correctness issue.
     # Re-enable the torch.compile once the issue is fixed in torchxla.
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 021d18b2500f0..af922f9979d1a 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -205,7 +205,8 @@ def gather_mm_placeholders(
     """
     Reconstructs the embeddings from the placeholder tokens.
 
-    This is the operation of [scatter_mm_placeholders][].
+    This is the inverse operation of [`scatter_mm_placeholders`]
+    [vllm.v1.worker.utils.scatter_mm_placeholders].
""" if is_embed is None: return placeholders diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py deleted file mode 100644 index 530907012f704..0000000000000 --- a/vllm/worker/cache_engine.py +++ /dev/null @@ -1,145 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""CacheEngine class for managing the KV cache.""" -from typing import List - -import torch - -from vllm.attention import get_attn_backend -from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig -from vllm.logger import init_logger -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, - get_dtype_size, is_pin_memory_available) - -logger = init_logger(__name__) - - -class CacheEngine: - """Manages the KV cache. - - This class is responsible for initializing and managing the GPU and CPU KV - caches. It also provides methods for performing KV cache operations, such - as swapping and copying. - """ - - def __init__( - self, - cache_config: CacheConfig, - model_config: ModelConfig, - parallel_config: ParallelConfig, - device_config: DeviceConfig, - ) -> None: - self.cache_config = cache_config - self.model_config = model_config - self.parallel_config = parallel_config - self.device_config = device_config - - self.head_size = model_config.get_head_size() - # Models like Jamba, have mixed typed layers, E.g Mamba - self.num_attention_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - - self.block_size = cache_config.block_size - self.num_gpu_blocks = cache_config.num_gpu_blocks - if self.num_gpu_blocks: - self.num_gpu_blocks //= parallel_config.pipeline_parallel_size - self.num_cpu_blocks = cache_config.num_cpu_blocks - if self.num_cpu_blocks: - self.num_cpu_blocks //= parallel_config.pipeline_parallel_size - - if cache_config.cache_dtype == "auto": - self.dtype = model_config.dtype - else: - self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - - # Get attention backend. - self.attn_backend = get_attn_backend(self.head_size, - model_config.dtype, - cache_config.cache_dtype, - self.block_size, - model_config.is_attention_free, - use_mla=model_config.use_mla) - - # Initialize the cache. - self.gpu_cache = self._allocate_kv_cache( - self.num_gpu_blocks, self.device_config.device_type) - self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu") - - def _allocate_kv_cache( - self, - num_blocks: int, - device: str, - ) -> List[torch.Tensor]: - """Allocates KV cache on the specified device.""" - kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_kv_heads, self.head_size) - pin_memory = is_pin_memory_available() if device == "cpu" else False - kv_cache: List[torch.Tensor] = [] - try: - kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order( - ) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape))) - - # The allocation respects the backend-defined stride order to ensure - # the semantic remains consistent for each backend. We first obtain the - # generic kv cache shape and then permute it according to the stride - # order which could result in a non-contiguous tensor. 
- kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i] - for i in kv_cache_stride_order) - - for _ in range(self.num_attention_layers): - # null block in CpuGpuBlockAllocator requires at least that - # block to be zeroed-out. - # We zero-out everything for simplicity. - layer_kv_cache = torch.zeros( - kv_cache_allocation_shape, - dtype=self.dtype, - pin_memory=pin_memory, - device=device).permute(*kv_cache_stride_order) - - # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases - # when entry_shape is higher than 1D - kv_cache.append(layer_kv_cache) - return kv_cache - - def swap_in(self, src_to_dst: torch.Tensor) -> None: - for i in range(self.num_attention_layers): - self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], - src_to_dst) - - def swap_out(self, src_to_dst: torch.Tensor) -> None: - for i in range(self.num_attention_layers): - self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], - src_to_dst) - - def copy(self, src_to_dsts: torch.Tensor) -> None: - self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) - - @staticmethod - def get_cache_block_size( - cache_config: CacheConfig, - model_config: ModelConfig, - parallel_config: ParallelConfig, - ) -> int: - head_size = model_config.get_head_size() - num_heads = model_config.get_num_kv_heads(parallel_config) - num_attention_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - - if cache_config.cache_dtype == "auto": - dtype = model_config.dtype - else: - dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - - key_cache_entry = num_heads * head_size - - # For MLA there is no value cache, since the latent vector - # is joint keys and values. - value_cache_entry = key_cache_entry if not model_config.use_mla else 0 - total = num_attention_layers * cache_config.block_size * \ - (key_cache_entry + value_cache_entry) - - dtype_size = get_dtype_size(dtype) - return dtype_size * total diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py deleted file mode 100644 index f662f5a85eff6..0000000000000 --- a/vllm/worker/model_runner.py +++ /dev/null @@ -1,2023 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -import gc -import inspect -import itertools -import time -import weakref -from contextlib import contextmanager -from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, - Tuple, Type, TypeVar, Union) - -import numpy as np -import torch -import torch.distributed -import torch.nn as nn -from tqdm.auto import tqdm - -from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.attention.backends.abstract import AttentionState -from vllm.attention.backends.utils import CommonAttentionState -from vllm.compilation.counter import compilation_counter -from vllm.config import CompilationLevel, VllmConfig -from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import broadcast_tensor_dict, get_pp_group -from vllm.distributed.kv_transfer import get_kv_transfer_group -from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, - graph_capture) -from vllm.forward_context import get_forward_context, set_forward_context -from vllm.inputs import INPUT_REGISTRY, InputRegistry -from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.lora.worker_manager import 
LRUCacheWorkerLoRAManager -from vllm.model_executor import SamplingMetadata, SamplingMetadataCache -from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput, - get_sampler) -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.model_executor.models import (supports_lora, supports_mrope, - supports_multimodal) -from vllm.model_executor.models.utils import set_cpu_offload_max_bytes -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalKwargs, MultiModalPlaceholderMap, - MultiModalRegistry) -from vllm.sampling_params import SamplingParams -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache, - async_tensor_h2d, flatten_2d_lists, - is_pin_memory_available, supports_dynamo, - weak_ref_tensor) -from vllm.worker.model_runner_base import ( - InputProcessingError, ModelRunnerBase, ModelRunnerInputBase, - ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict, - _init_attn_metadata_from_tensor_dict, - _init_sampling_metadata_from_tensor_dict) - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - -logger = init_logger(__name__) - -LORA_WARMUP_RANK = 8 - -_NUM_WARMUP_ITERS = 2 - -TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") - -# For now, bump up cache limits for recompilations during CUDA graph warmups. -torch._dynamo.config.cache_size_limit = 128 -torch._dynamo.config.accumulated_cache_size_limit = 128 - - -@dataclass(frozen=True) -class ModelInputForGPU(ModelRunnerInputBase): - """ - This base class contains metadata needed for the base model forward pass - but not metadata for possible additional steps, e.g., sampling. Model - runners that run additional steps should subclass this method to add - additional fields. 
- """ - input_tokens: Optional[torch.Tensor] = None - inputs_embeds: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - seq_lens: Optional[List[int]] = None - query_lens: Optional[List[int]] = None - lora_mapping: Optional["LoRAMapping"] = None - lora_requests: Optional[Set[LoRARequest]] = None - attn_metadata: Optional["AttentionMetadata"] = None - multi_modal_kwargs: Optional[BatchedTensorInputs] = None - request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None - finished_requests_ids: Optional[List[str]] = None - virtual_engine: int = 0 - async_callback: Optional[Callable] = None - scheduler_outputs: Optional[SchedulerOutputs] = None - previous_hidden_states: Optional[torch.Tensor] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "inputs_embeds": self.inputs_embeds, - "input_positions": self.input_positions, - "lora_requests": self.lora_requests, - "lora_mapping": self.lora_mapping, - "multi_modal_kwargs": self.multi_modal_kwargs, - "virtual_engine": self.virtual_engine, - "request_ids_to_seq_ids": self.request_ids_to_seq_ids, - "finished_requests_ids": self.finished_requests_ids, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls: Type[TModelInputForGPU], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> TModelInputForGPU: - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - # Exclude `async_callback` to be able to pickle this object - def __getstate__(self): - state = self.__dict__.copy() - del state["async_callback"] - return state - - # TODO: What happens when we depickle this object? - # How can we update this callback to properly pass it to the engine? - def __setstate__(self, state): - self.__dict__.update(state) - self.__dict__.update({'async_callback': None}) - - -@dataclass(frozen=True) -class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): - """ - Used by the ModelRunner. - """ - sampling_metadata: Optional["SamplingMetadata"] = None - # Used for speculative decoding. We do not broadcast it because it is only - # used by the driver worker. 
- is_prompt: Optional[bool] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "inputs_embeds": self.inputs_embeds, - "input_positions": self.input_positions, - "lora_requests": self.lora_requests, - "lora_mapping": self.lora_mapping, - "multi_modal_kwargs": self.multi_modal_kwargs, - "virtual_engine": self.virtual_engine, - "request_ids_to_seq_ids": self.request_ids_to_seq_ids, - "finished_requests_ids": self.finished_requests_ids, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "ModelInputForGPUWithSamplingMetadata": - tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - -class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): - """Build ModelInputForGPU from SequenceGroupMetadata.""" - - # Note: ideally we would be using a dataclass(kw_only=True) - # here, so that this can be subclassed easily, - # but kw_only is not supported in python<3.10. - class InterDataForSeqGroup: - """Intermediate data for the current sequence group.""" - - def simple_reinit(self): - self.input_tokens[0].clear() # type: ignore - self.inputs_embeds = None # type: ignore - self.input_positions[0].clear() # type: ignore - self.mrope_input_positions = None # type: ignore - self.seq_lens[0] = 0 # type: ignore - self.orig_seq_lens[0] = 0 # type: ignore - self.prompt_lens[0] = 0 # type: ignore - self.query_lens[0] = 0 # type: ignore - self.context_lens[0] = 0 # type: ignore - self.curr_sliding_window_blocks[0] = 0 # type: ignore - self.lora_index_mapping.clear() # type: ignore - self.lora_prompt_mapping.clear() # type: ignore - self.lora_requests.clear() # type: ignore - - def __init__( - self, - *, - # From sequence group metadata. - request_id: str, - seq_ids: List[int], - is_prompt: bool, - block_tables: Optional[Dict[int, List[int]]], - computed_block_nums: List[int], - n_seqs: int = 0, - - # Input tokens and positions. - input_tokens: Optional[List[List[int]]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - input_positions: Optional[List[List[int]]] = None, - mrope_input_positions: Optional[List[List[List[int]]]] = None, - - # The sequence length (may be capped to the sliding window). - seq_lens: Optional[List[int]] = None, - # The original sequence length (before applying sliding window). - # This is used to compute slot mapping. - orig_seq_lens: Optional[List[int]] = None, - # This is used in the dual-chunk flash attention backend. - prompt_lens: Optional[List[int]] = None, - # The query length. - query_lens: Optional[List[int]] = None, - # The number of tokens that are already computed. - context_lens: Optional[List[int]] = None, - # The current sliding window block. - curr_sliding_window_blocks: Optional[List[int]] = None, - - # LoRA inputs. - lora_index_mapping: Optional[List[List[int]]] = None, - lora_prompt_mapping: Optional[List[List[int]]] = None, - lora_requests: Optional[Set[LoRARequest]] = None, - - # Multi-modal inputs. 
- multi_modal_kwargs: Optional[MultiModalKwargs] = None, - multi_modal_placeholder_maps: Optional[Dict[ - str, MultiModalPlaceholderMap]] = None, - - # Whether the prefix cache is hit (prefill only). - prefix_cache_hit: bool = False, - reinit: bool = False, - reinit_use_defaults: bool = False, - encoder_seq_len: int = 0, - ): - if reinit: - assert len(self.seq_ids) == len(seq_ids) # type: ignore - for i, seq_id in enumerate(seq_ids): - self.seq_ids[i] = seq_id # type: ignore - else: - self.seq_ids = seq_ids - - self.request_id = request_id - self.is_prompt = is_prompt - self.block_tables = block_tables - self.computed_block_nums = computed_block_nums - self.n_seqs = n_seqs - self.encoder_seq_len = encoder_seq_len - - if reinit: - if len(self.seq_ids) == 1 and reinit_use_defaults: - self.simple_reinit() - else: - if input_tokens: - self.input_tokens = input_tokens - else: - for seq_id in range(len(self.seq_ids)): - self.input_tokens[seq_id].clear() - - self.inputs_embeds = inputs_embeds - - if input_positions: - self.input_positions = input_positions - else: - for seq_id in range(len(self.seq_ids)): - self.input_positions[seq_id].clear() - - self.mrope_input_positions = None - - if seq_lens: - self.seq_lens = seq_lens - else: - for seq_id in range(len(self.seq_ids)): - self.seq_lens[seq_id] = 0 - - if orig_seq_lens: - self.orig_seq_lens = orig_seq_lens - else: - for seq_id in range(len(self.seq_ids)): - self.orig_seq_lens[seq_id] = 0 - - if prompt_lens: - self.prompt_lens = prompt_lens - else: - for seq_id in range(len(self.seq_ids)): - self.prompt_lens[seq_id] = 0 - - if query_lens: - self.query_lens = query_lens - else: - for seq_id in range(len(self.seq_ids)): - self.query_lens[seq_id] = 0 - - if context_lens: - self.context_lens = context_lens - else: - for seq_id in range(len(self.seq_ids)): - self.context_lens[seq_id] = 0 - - if curr_sliding_window_blocks: - self.curr_sliding_window_blocks = \ - curr_sliding_window_blocks - else: - for seq_id in range(len(self.seq_ids)): - self.curr_sliding_window_blocks[seq_id] = 0 - - if lora_index_mapping: - self.lora_index_mapping = lora_index_mapping - else: - self.lora_index_mapping.clear() - - if lora_prompt_mapping: - self.lora_prompt_mapping = lora_prompt_mapping - else: - self.lora_prompt_mapping.clear() - - if lora_requests: - self.lora_requests = lora_requests - else: - self.lora_requests.clear() - - else: - self.input_tokens = input_tokens or [] - self.inputs_embeds = inputs_embeds - self.input_positions = input_positions or [] - self.mrope_input_positions = mrope_input_positions or None - self.seq_lens = seq_lens or [] - self.orig_seq_lens = orig_seq_lens or [] - self.prompt_lens = prompt_lens or [] - self.query_lens = query_lens or [] - self.context_lens = context_lens or [] - self.curr_sliding_window_blocks = \ - curr_sliding_window_blocks or [] - - self.lora_index_mapping = lora_index_mapping or [] - self.lora_prompt_mapping = lora_prompt_mapping or [] - self.lora_requests = lora_requests or set() - - self.multi_modal_kwargs = multi_modal_kwargs - self.multi_modal_placeholder_maps = multi_modal_placeholder_maps - self.prefix_cache_hit = prefix_cache_hit - - self.n_seqs = len(self.seq_ids) - - if not reinit: - self.__post_init__() - - def __post_init__(self): - self.n_seqs = len(self.seq_ids) - - self.input_tokens = [[] for _ in range(self.n_seqs)] - self.input_positions = [[] for _ in range(self.n_seqs)] - self.mrope_input_positions = None - self.seq_lens = [0] * self.n_seqs - self.orig_seq_lens = [0] * self.n_seqs - self.prompt_lens 
= [0] * self.n_seqs - self.query_lens = [0] * self.n_seqs - self.context_lens = [0] * self.n_seqs - self.curr_sliding_window_blocks = [0] * self.n_seqs - - self.lora_index_mapping = [] - self.lora_prompt_mapping = [] - - def __repr__(self) -> str: - return (f"InterDataForSeqGroup(" - f"request_id={self.request_id}, " - f"seq_ids={self.seq_ids}, " - f"is_prompt={self.is_prompt}, " - f"block_tables={self.block_tables}, " - f"computed_block_nums={self.computed_block_nums}, " - f"n_seqs={self.n_seqs}, " - f"input_tokens={self.input_tokens}, " - f"inputs_embeds.shape=" - f"{getattr(self.inputs_embeds, 'shape', None)}, " - f"input_positions={self.input_positions}, " - f"mrope_input_positions={self.mrope_input_positions}, " - f"seq_lens={self.seq_lens}, " - f"orig_seq_lens={self.orig_seq_lens}, " - f"query_lens={self.query_lens}, " - f"context_lens={self.context_lens}, " - f"multi_modal_kwargs={self.multi_modal_kwargs}") - - def gen_inter_data_builder(self, num_seqs: int): - return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup( - request_id="", - seq_ids=[0] * num_seqs, - is_prompt=True, - block_tables=None, - computed_block_nums=[]) - - def init_cached_inter_data(self, *args, **kwargs): - assert len(args) == 0 - assert "seq_ids" in kwargs - seq_ids = kwargs["seq_ids"] - num_seqs = len(seq_ids) - - # The inter-data cache is per model_runner - inter_data_cache = self.runner.inter_data_cache - if num_seqs not in inter_data_cache: - inter_data_cache[num_seqs] = PyObjectCache( - self.gen_inter_data_builder(num_seqs)) - - obj = inter_data_cache[num_seqs].get_object() - obj.__init__(*args, **kwargs) - return obj - - def reset_cached_inter_data(self): - for cache in self.runner.inter_data_cache.values(): - cache.reset() - - def __init__(self, - runner: "GPUModelRunnerBase", - finished_requests_ids: Optional[List[str]] = None): - super().__init__() - # Compute functions for each sequence in a sequence group. - # WARNING: The order of the functions matters! - self.per_seq_compute_fns = [ - self._compute_lens, - self._compute_for_prefix_cache_hit, - self._compute_for_sliding_window, - self._compute_lora_input, - ] - # Compute functions for each sequence group. - # WARNING: The order of the functions matters! - self.per_seq_group_compute_fns = [ - self._compute_multi_modal_input, - ] - - self.runner = runner - self.model_input_cls = self.runner._model_input_cls - self.attn_backend = self.runner.attn_backend - self.scheduler_config = self.runner.scheduler_config - self.sliding_window = self.runner.sliding_window - self.block_size = self.runner.block_size - self.enable_lora = self.runner.lora_config is not None - - # Attention metadata inputs. - if self.attn_backend is not None: - # spec decode (e.g. Medusa) does not have atten backend - self.attn_metadata_builder = self.attn_backend.get_builder_cls()( - weakref.proxy(self)) - - # Engine/Model configurations. - self.chunked_prefill_enabled = ( - self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled) - if self.sliding_window is not None: - self.sliding_window_blocks = ( - self.sliding_window + self.block_size - 1) // self.block_size - self.block_aligned_sliding_window = \ - self.sliding_window_blocks * self.block_size - - def prepare(self, - finished_requests_ids: Optional[List[str]] = None) -> None: - self.finished_requests_ids = finished_requests_ids - - # if the current batch is decode-only. - # will be set to False if there is any non-decode request. 
- self.decode_only = True - - # Intermediate data (data in CPU before going to GPU) for - # the current sequence group. - self.inter_data_list: List[ - ModelInputForGPUBuilder.InterDataForSeqGroup] = [] - - self.attn_metadata_builder.prepare() - - def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): - """Compute context length, sequence length and tokens - for the given sequence data. - """ - seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]] - token_chunk_size = seq_group_metadata.token_chunk_size - - # Compute context length (the number of tokens that are - # already computed) and sequence length (total number of tokens). - - seq_len = seq_data.get_len() - if inter_data.is_prompt: - context_len = seq_data.get_num_computed_tokens() - seq_len = min(seq_len, context_len + token_chunk_size) - elif self.runner.model_config.is_encoder_decoder: - context_len = seq_len - 1 - else: - context_len = seq_data.get_num_computed_tokens() - - # Compute tokens. - if seq_data.prompt_embeds is None: - tokens = seq_data.get_token_ids()[context_len:seq_len] - prompt_embeds = None - else: - tokens = [0] * (seq_len - context_len) - prompt_embeds = seq_data.get_token_embeddings( - )[context_len:seq_len] - - inter_data.seq_lens[seq_idx] = seq_len - inter_data.orig_seq_lens[seq_idx] = seq_len - inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len() - inter_data.context_lens[seq_idx] = context_len - inter_data.input_tokens[seq_idx].extend(tokens) - inter_data.inputs_embeds = prompt_embeds - inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) - inter_data.query_lens[seq_idx] = seq_len - context_len - - if seq_data.mrope_position_delta is not None: - if inter_data.mrope_input_positions is None: - inter_data.mrope_input_positions = [None] * inter_data.n_seqs - - inter_data.mrope_input_positions[ - seq_idx] = MRotaryEmbedding.get_next_input_positions( - seq_data.mrope_position_delta, - context_len, - seq_len, - ) - - def _compute_for_prefix_cache_hit( - self, inter_data: InterDataForSeqGroup, seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): - """Check if hit prefix cache (i.e., some blocks are already computed). - If hit, update input tokens and positions to only compute the - remaining blocks. - """ - computed_block_nums = inter_data.computed_block_nums - - # Note that prefix caching does not support sliding window. - prefix_cache_hit = (computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None - and inter_data.is_prompt) - inter_data.prefix_cache_hit = prefix_cache_hit - - if not prefix_cache_hit: - return - - assert computed_block_nums is not None - # The cache hit prompt tokens in this sequence. Note that - # this may be larger than the sequence length if chunked - # prefill is enabled. - prefix_cache_len = len(computed_block_nums) * self.block_size - seq_group_metadata.seq_data[inter_data.seq_ids[ - seq_idx]].update_num_cached_tokens(prefix_cache_len) - - # The number of so far computed prompt tokens in this sequence. - context_len = inter_data.context_lens[seq_idx] - # The total number of prompt tokens in this sequence. - # When chunked prefill is enabled, this is the token number of - # computed chunks + current chunk. - seq_len = inter_data.seq_lens[seq_idx] - if prefix_cache_len <= context_len: - # We already passed the cache hit region, - # so do normal computation. - pass - elif context_len < prefix_cache_len < seq_len: - # Partial hit. 
-
-    def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
-                                    seq_idx: int,
-                                    seq_group_metadata: SequenceGroupMetadata):
-        """Update seq_len and curr_sliding_window_block for the given
-        sequence data (only required for decoding) if sliding window is
-        enabled.
-        """
-        curr_sliding_window_block = 0
-        sliding_seq_len = inter_data.seq_lens[seq_idx]
-        if not inter_data.is_prompt and self.sliding_window is not None:
-            # TODO(sang): This is a hack to make sliding window work with
-            # paged attn. We can remove it if we make the paged attn kernel
-            # properly handle sliding window attn.
-            curr_sliding_window_block = self.sliding_window_blocks
-            # Number of elements in the last block.
-            suff_len = inter_data.seq_lens[seq_idx] % self.block_size
-            sliding_seq_len = min(inter_data.seq_lens[seq_idx],
-                                  self.block_aligned_sliding_window + suff_len)
-            if suff_len > 0:
-                curr_sliding_window_block += 1
-
-        inter_data.curr_sliding_window_blocks[
-            seq_idx] = curr_sliding_window_block
-        inter_data.seq_lens[seq_idx] = sliding_seq_len
-
-    def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
-                            seq_idx: int,
-                            seq_group_metadata: SequenceGroupMetadata):
-        """If LoRA is enabled, compute the LoRA index and prompt mapping."""
-        if not self.enable_lora:
-            return
-
-        lora_id = seq_group_metadata.lora_int_id
-        if lora_id > 0:
-            inter_data.lora_requests.add(seq_group_metadata.lora_request)
-        query_len = inter_data.query_lens[seq_idx]
-        inter_data.lora_index_mapping.append([lora_id] * query_len)
-        sampling_params = seq_group_metadata.sampling_params
-        if sampling_params and sampling_params.prompt_logprobs is not None:
-            inter_data.lora_prompt_mapping.append([lora_id] * query_len)
-        elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
-            inter_data.lora_prompt_mapping.append([lora_id])
-        else:
-            inter_data.lora_prompt_mapping.append([])
-
-    def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
-                                   seq_group_metadata: SequenceGroupMetadata):
-        """If multi-modal data is given, add it to the input."""
-        # NOTE: mm_kwargs only includes the subset of multi-modal items that
-        # intersect with the current prefill positions.
-        positions = inter_data.input_positions[0]
-        mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
-            seq_group_metadata,
-            range(positions[0], positions[0] + len(positions)))
-
-        # M-RoPE requires mrope_positions even for plain text; return early
-        # when mm_kwargs is empty only if inter_data.is_prompt is False.
- if not mm_kwargs and not inter_data.is_prompt: - return - - inter_data.multi_modal_kwargs = mm_kwargs - inter_data.multi_modal_placeholder_maps = placeholder_maps - - # special processing for mrope position deltas. - if self.runner.model_config.uses_mrope: - image_grid_thw = mm_kwargs.get("image_grid_thw", None) - video_grid_thw = mm_kwargs.get("video_grid_thw", None) - audio_feature_lengths = mm_kwargs.get("audio_feature_lengths", - None) - - second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) - use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) - hf_config = self.runner.model_config.hf_config - - inter_data.mrope_input_positions = [None] * inter_data.n_seqs - for seq_idx in range(inter_data.n_seqs): - seq_data = seq_group_metadata.seq_data[ - inter_data.seq_ids[seq_idx]] - token_ids = seq_data.get_token_ids() - - if supports_mrope(self.runner.model): - mrope_input_positions, mrope_position_delta = \ - self.runner.model.get_mrope_input_positions( - token_ids, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=inter_data.context_lens[seq_idx], - seq_len=inter_data.seq_lens[seq_idx], - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - mrope_input_positions = mrope_input_positions.tolist() - else: - mrope_input_positions, mrope_position_delta = \ - MRotaryEmbedding.get_input_positions( - token_ids, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=inter_data.context_lens[seq_idx], - seq_len=inter_data.seq_lens[seq_idx], - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - - seq_data.mrope_position_delta = mrope_position_delta - inter_data.mrope_input_positions[ - seq_idx] = mrope_input_positions - - def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): - """Add a sequence group to the builder.""" - seq_ids = seq_group_metadata.seq_data.keys() - n_seqs = len(seq_ids) - is_prompt = seq_group_metadata.is_prompt - - if is_prompt: - assert n_seqs == 1 - self.decode_only = False - - encoder_seq_len = 0 - - if self.runner.model_config.is_encoder_decoder: - encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() - - inter_data = self.init_cached_inter_data( - request_id=seq_group_metadata.request_id, - seq_ids=seq_ids, - is_prompt=is_prompt, - block_tables=seq_group_metadata.block_tables, - computed_block_nums=seq_group_metadata.computed_block_nums, - reinit=True, - reinit_use_defaults=True, - encoder_seq_len=encoder_seq_len) - - self.inter_data_list.append(inter_data) - - for seq_idx in range(n_seqs): - for per_seq_fn in self.per_seq_compute_fns: - per_seq_fn(inter_data, seq_idx, seq_group_metadata) - for per_seq_group_fn in self.per_seq_group_compute_fns: - per_seq_group_fn(inter_data, seq_group_metadata) - - def _use_captured_graph(self, - batch_size: int, - decode_only: bool, - max_decode_seq_len: int, - max_encoder_seq_len: int = 0) -> bool: - return (decode_only and not self.runner.model_config.enforce_eager - and max_decode_seq_len <= self.runner.max_seq_len_to_capture - and max_encoder_seq_len <= self.runner.max_seq_len_to_capture - and batch_size <= self.runner.max_batchsize_to_capture) - - def _get_cuda_graph_pad_size(self, - num_seqs: int, - max_decode_seq_len: int, - max_encoder_seq_len: int = 0) -> int: - """ - Determine the number of padding sequences required for running 
in
-        CUDA graph mode. Returns -1 if CUDA graphs cannot be used.
-
-        In the multi-step + chunked-prefill case, only the first step
-        has prefills (if any). The rest of the steps are guaranteed to be all
-        decodes. In this case, we set up the padding as if all the sequences
-        are decodes so we may run all steps except the first step in CUDA graph
-        mode.
-
-        Args:
-            num_seqs (int): Number of sequences scheduled to run.
-            max_decode_seq_len (int): Greatest of all the decode sequence
-                lengths. Used only in checking the viability of using
-                CUDA graphs.
-            max_encoder_seq_len (int, optional): Greatest of all the encoder
-                sequence lengths. Defaults to 0. Used only in checking the
-                viability of using CUDA graphs.
-        Returns:
-            int: Returns the determined number of padding sequences. If
-                CUDA graphs are not viable, returns -1.
-        """
-        decode_only = self.decode_only
-        if not decode_only:
-            # Early exit so we can treat num_seqs as the batch_size below.
-            return -1
-
-        # batch_size outside of this function refers to the number of input
-        # tokens being scheduled. This conflation of num_seqs as batch_size
-        # is valid as this is a decode-only case.
-        batch_size = num_seqs
-        if not self._use_captured_graph(batch_size, decode_only,
-                                        max_decode_seq_len,
-                                        max_encoder_seq_len):
-            return -1
-
-        graph_batch_size = self.runner.vllm_config.pad_for_cudagraph(
-            batch_size)
-        assert graph_batch_size >= batch_size
-        return graph_batch_size - batch_size
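For intuition, `pad_for_cudagraph` rounds a decode batch up to the smallest captured batch size, and the method above returns the difference. A simplified stand-in (the real helper lives on the vLLM config; the capture sizes here are hypothetical):

    # Illustrative sketch only; not the actual vLLM implementation.
    cudagraph_capture_sizes = [1, 2, 4, 8, 16, 32]

    def pad_for_cudagraph(batch_size: int) -> int:
        # Smallest captured size that can hold the batch.
        return next(s for s in cudagraph_capture_sizes if s >= batch_size)

    assert pad_for_cudagraph(5) == 8   # cuda_graph_pad_size == 8 - 5 == 3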
-
-    def build(self) -> ModelInputForGPU:
-        """Finalize the builder's intermediate data and
-        create on-device tensors.
-        """
-        # Combine and flatten intermediate data.
-        input_tokens = list[int]()
-        inputs_embeds_list = list[torch.Tensor]()
-        for inter_data in self.inter_data_list:
-            for cur_input_tokens in inter_data.input_tokens:
-                input_tokens.extend(cur_input_tokens)
-            if inter_data.inputs_embeds is not None:
-                inputs_embeds_list.append(
-                    inter_data.inputs_embeds.to(
-                        dtype=self.runner.model_config.dtype,
-                        device=self.runner.device))
-        inputs_embeds: Optional[torch.Tensor]
-        if len(inputs_embeds_list) == 0:
-            inputs_embeds = None
-        else:
-            inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
-                dtype=self.runner.model_config.dtype,
-                device=self.runner.device)
-            assert len(inputs_embeds) == len(input_tokens)
-
-        if not input_tokens and inputs_embeds is None:
-            # This may happen when all prefill requests hit
-            # prefix caching and there is no decode request.
-            return self.model_input_cls()
-
-        mrope_input_positions: Optional[List[List[int]]] = None
-        if any(inter_data.mrope_input_positions is not None
-               for inter_data in self.inter_data_list):
-            mrope_input_positions = [[] for _ in range(3)]
-            for idx in range(3):
-                for inter_data in self.inter_data_list:
-                    msections = inter_data.mrope_input_positions
-                    if msections is None:
-                        for _seq_input_positions in inter_data.input_positions:
-                            mrope_input_positions[idx].extend(
-                                _seq_input_positions)
-                    else:
-                        for _seq_mrope_input_positions in msections:
-                            mrope_input_positions[idx].extend(
-                                _seq_mrope_input_positions[idx])
-            input_positions = None
-        else:
-            input_positions = []
-            for inter_data in self.inter_data_list:
-                for cur_input_positions in inter_data.input_positions:
-                    input_positions.extend(cur_input_positions)
-
-        seq_lens = []
-        query_lens = []
-        max_decode_seq_len = 0
-        max_encoder_seq_len = 0
-        for inter_data in self.inter_data_list:
-            seq_lens.extend(inter_data.seq_lens)
-            query_lens.extend(inter_data.query_lens)
-            if not inter_data.is_prompt:
-                max_decode_seq_len = max(max_decode_seq_len,
-                                         max(inter_data.seq_lens))
-                if self.runner.model_config.is_encoder_decoder:
-                    max_encoder_seq_len = max(max_encoder_seq_len,
-                                              inter_data.encoder_seq_len)
-
-        # Mapping from request IDs to sequence IDs. Used for Jamba models
-        # that manage the cache by themselves.
-        request_ids_to_seq_ids = {
-            data.request_id: data.seq_ids
-            for data in self.inter_data_list
-        }
-
-        cuda_graph_pad_size = self._get_cuda_graph_pad_size(
-            num_seqs=len(seq_lens),
-            max_decode_seq_len=max_decode_seq_len,
-            max_encoder_seq_len=max_encoder_seq_len)
-
-        batch_size = len(input_tokens)
-        if cuda_graph_pad_size != -1:
-            # If cuda graph can be used, pad tensors accordingly.
-            # See `capture_model` API for more details.
-            # vLLM uses cuda graph only for decoding requests.
-            batch_size += cuda_graph_pad_size
-
-        # Tokens and positions.
-        if cuda_graph_pad_size:
-            input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
-        assert self.runner.device is not None
-        input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
-                                               self.runner.device,
-                                               self.runner.pin_memory)
-
-        if mrope_input_positions is not None:
-            for idx in range(3):
-                mrope_input_positions[idx].extend(
-                    itertools.repeat(0, cuda_graph_pad_size))
-            input_positions_tensor = async_tensor_h2d(mrope_input_positions,
-                                                      torch.long,
-                                                      self.runner.device,
-                                                      self.runner.pin_memory)
-        else:
-            input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
-            input_positions_tensor = async_tensor_h2d(input_positions,
-                                                      torch.long,
-                                                      self.runner.device,
-                                                      self.runner.pin_memory)
-        # Sequence and query lengths.
-        if cuda_graph_pad_size:
-            seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
-
-        # Attention metadata.
-        attn_metadata = self.attn_metadata_builder.build(
-            seq_lens, query_lens, cuda_graph_pad_size, batch_size)
-
-        # LoRA data.
- lora_requests = set() - lora_mapping = None - if self.enable_lora: - lora_requests = set(r for data in self.inter_data_list - for r in data.lora_requests) - lora_index_mapping = flatten_2d_lists([ - flatten_2d_lists(inter_data.lora_index_mapping) - for inter_data in self.inter_data_list - ]) - if cuda_graph_pad_size: - lora_index_mapping.extend( - itertools.repeat(0, cuda_graph_pad_size)) - lora_prompt_mapping = flatten_2d_lists([ - flatten_2d_lists(inter_data.lora_prompt_mapping) - for inter_data in self.inter_data_list - ]) - - lora_mapping = LoRAMapping( - **dict(index_mapping=lora_index_mapping, - prompt_mapping=lora_prompt_mapping, - is_prefill=not self.decode_only)) - - # Multi-modal data. - multi_modal_kwargs_list = [ - data.multi_modal_kwargs for data in self.inter_data_list - if data.multi_modal_kwargs is not None - ] - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - return self.model_input_cls( - input_tokens=input_tokens_tensor, - inputs_embeds=inputs_embeds, - input_positions=input_positions_tensor, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_mapping=lora_mapping, - lora_requests=lora_requests, - multi_modal_kwargs=multi_modal_kwargs, - request_ids_to_seq_ids=request_ids_to_seq_ids, - finished_requests_ids=self.finished_requests_ids) - - -class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): - """ - Helper class for shared methods between GPU model runners. - """ - _model_input_cls: Type[TModelInputForGPU] - _builder_cls: Type[ModelInputForGPUBuilder] - builder: ModelInputForGPUBuilder - - def __init__( - self, - vllm_config: VllmConfig, - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - return_hidden_states: bool = False, - input_registry: InputRegistry = INPUT_REGISTRY, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - ): - - ModelRunnerBase.__init__(self, vllm_config) - model_config = self.model_config - cache_config = self.cache_config - - self.is_driver_worker = is_driver_worker - self.return_hidden_states = return_hidden_states - - self.device = self.device_config.device - self.pin_memory = is_pin_memory_available() - - self.kv_cache_dtype = kv_cache_dtype - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture - self.max_batchsize_to_capture = \ - self.vllm_config.compilation_config.max_capture_size - - # - self.graph_runners: List[Dict[Tuple[int, bool], CUDAGraphRunner]] = [ - {} for _ in range(self.parallel_config.pipeline_parallel_size) - ] - self.graph_memory_pool: Optional[Tuple[ - int, int]] = None # Set during graph capture. - - self.has_inner_state = model_config.has_inner_state - - self.in_profile_run = False - - # When using CUDA graph, the input block tables must be padded to - # max_seq_len_to_capture. However, creating the block table in - # Python can be expensive. To optimize this, we cache the block table - # in numpy and only copy the actual input content at every iteration. - # The shape of the cached block table will be - # (max batch size to capture, max seq len to capture / block size). 
- self.graph_block_tables = np.zeros( - (self.max_batchsize_to_capture, self.get_max_block_per_batch()), - dtype=np.int32) - - self.cross_layer_shared_graph_block_tables = np.zeros( - (self.max_batchsize_to_capture, self.get_max_block_per_batch()), - dtype=np.int32) - - # Attention-free but stateful models like Mamba need a placeholder attn - # backend, as the attention metadata is needed to manage internal state. - # However we must bypass attention selection altogether for some models - # used for speculative decoding to avoid a divide-by-zero in - # model_config.get_head_size() - num_attn_heads = self.model_config.get_num_attention_heads( - self.parallel_config) - needs_attn_backend = (num_attn_heads != 0 - or self.model_config.is_attention_free) - - self.attn_backend = get_attn_backend( - self.model_config.get_head_size(), - self.model_config.dtype, - self.kv_cache_dtype, - self.block_size, - self.model_config.is_attention_free, - use_mla=self.model_config.use_mla, - ) if needs_attn_backend else None - if self.attn_backend: - self.attn_state = self.attn_backend.get_state_cls()( - weakref.proxy(self)) - else: - self.attn_state = CommonAttentionState(weakref.proxy(self)) - - # Multi-modal data support - self.input_registry = input_registry - self.mm_registry = mm_registry - - # Lazy initialization - self.model: nn.Module # Set after load_model - # Set after load_model. - self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None - self.sampler = get_sampler() - - set_cpu_offload_max_bytes( - int(self.cache_config.cpu_offload_gb * 1024**3)) - - # Used to cache python objects - self.inter_data_cache: Dict[int, PyObjectCache] = {} - - # Using the PythonizationCache in Pipeline-Parallel clobbers the - # SequenceGroupToSample object. In Pipeline-Parallel, we have - # more than 1 Scheduler, resulting in a potential back-to-back - # prepare_model_inputs() call. This clobbers the cached - # SequenceGroupToSample objects, as we reset the cache during - # every prepare_model_inputs() call. - self.sampling_metadata_cache: SamplingMetadataCache = \ - SamplingMetadataCache() \ - if self.parallel_config.pipeline_parallel_size == 1 else None - - if hasattr(self, "_builder_cls"): - # multi-step model runner does not have `_builder_cls` - self.builder = self._builder_cls(weakref.proxy(self)) - - def load_model(self) -> None: - logger.info("Starting to load model %s...", self.model_config.model) - with DeviceMemoryProfiler(self.device) as m: - time_before_load = time.perf_counter() - self.model = get_model(vllm_config=self.vllm_config) - if self.lora_config: - assert supports_lora( - self.model - ), f"{self.model.__class__.__name__} does not support LoRA yet." 
-
-                if supports_multimodal(self.model):
-                    logger.warning(
-                        "Regarding multimodal models, vLLM currently "
-                        "only supports adding LoRA to the language model.")
-
-                self.lora_manager = LRUCacheWorkerLoRAManager(
-                    self.vllm_config,
-                    self.device,
-                    self.model.embedding_modules,
-                    self.model.embedding_padding_modules,
-                )
-
-                self.model = self.lora_manager.create_lora_manager(self.model)
-            time_after_load = time.perf_counter()
-
-        self.model_memory_usage = m.consumed_memory
-        logger.info("Model loading took %.4f GiB and %.6f seconds",
-                    self.model_memory_usage / GiB_bytes,
-                    time_after_load - time_before_load)
-
-        if self.vllm_config.compilation_config.level ==\
-            CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
-            backend = self.vllm_config.compilation_config.init_backend(
-                self.vllm_config)
-            compilation_counter.dynamo_as_is_count += 1
-            self.model = torch.compile(self.model,
-                                       fullgraph=True,
-                                       backend=backend)
-
-    def get_model(self) -> nn.Module:
-        return self.model
-
-    def save_sharded_state(
-        self,
-        path: str,
-        pattern: Optional[str] = None,
-        max_size: Optional[int] = None,
-    ) -> None:
-        from vllm.model_executor.model_loader import ShardedStateLoader
-        ShardedStateLoader.save_model(
-            self.model,
-            path,
-            pattern=pattern,
-            max_size=max_size,
-        )
-
-    def save_tensorized_model(
-        self,
-        tensorizer_config: TensorizerConfig,
-    ) -> None:
-        from vllm.model_executor.model_loader import TensorizerLoader
-        TensorizerLoader.save_model(
-            self.model,
-            tensorizer_config=tensorizer_config,
-            model_config=self.model_config,
-        )
-
-    def get_max_block_per_batch(self) -> int:
-        block_size = self.block_size
-        return (self.max_seq_len_to_capture + block_size - 1) // block_size
-
-    def _prepare_model_input_tensors(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        finished_requests_ids: Optional[List[str]] = None
-    ) -> TModelInputForGPU:
-        """Helper method to prepare the model input based on a given sequence
-        group. Prepares metadata needed for the base model forward pass but not
-        metadata for possible additional steps, e.g., sampling.
-
-        The API assumes seq_group_metadata_list is sorted by prefill -> decode.
-
-        The result tensors and data structure also batches input in prefill
-        -> decode order. For example,
-
-        - input_tokens[:num_prefill_tokens] contains prefill tokens.
-        - input_tokens[num_prefill_tokens:] contains decode tokens.
-
-        If cuda graph is required, this API automatically pads inputs.
-        """
-        self.builder.prepare(finished_requests_ids)
-        for seq_group_metadata in seq_group_metadata_list:
-            try:
-                self.builder.add_seq_group(seq_group_metadata)
-            except Exception as e:
-                # Raise an exception that tracks the ID of the bad request
-                raise InputProcessingError(seq_group_metadata.request_id,
-                                           str(e)) from e
-
-        self.builder.reset_cached_inter_data()
-
-        return self.builder.build()  # type: ignore
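As a concrete illustration of the prefill -> decode layout promised by the docstring above (an illustrative sketch with made-up sizes, not from the deleted file):

    # Two prefill requests of 5 and 3 tokens, then two decode requests of
    # 1 token each; hypothetical numbers for illustration only.
    query_lens = [5, 3, 1, 1]
    num_prefill_tokens = 5 + 3
    input_tokens = list(range(sum(query_lens)))  # flattened token ids
    prefill_tokens = input_tokens[:num_prefill_tokens]  # first 8 entries
    decode_tokens = input_tokens[num_prefill_tokens:]   # last 2 entries
    assert len(prefill_tokens) == 8 and len(decode_tokens) == 2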
- """ - self.builder.prepare(finished_requests_ids) - for seq_group_metadata in seq_group_metadata_list: - try: - self.builder.add_seq_group(seq_group_metadata) - except Exception as e: - # Raise an exception that tracks the ID of the bad request - raise InputProcessingError(seq_group_metadata.request_id, - str(e)) from e - - self.builder.reset_cached_inter_data() - - return self.builder.build() # type: ignore - - @contextmanager - def set_in_profile_run(self): - self.in_profile_run = True - try: - yield - finally: - self.in_profile_run = False - - @torch.inference_mode() - def profile_run(self) -> None: - max_num_batched_tokens = \ - self.scheduler_config.max_num_batched_tokens - max_num_seqs = self.scheduler_config.max_num_seqs - self._dummy_run(max_num_batched_tokens, max_num_seqs) - - def _add_dummy_loras(self, num_loras: int) -> list[LoRARequest]: - assert num_loras > 0 - assert self.lora_manager is not None - - dummy_lora_requests: list[LoRARequest] = [] - with self.lora_manager.dummy_lora_cache(): - for idx in range(num_loras): - lora_id = idx + 1 - dummy_lora_request = LoRARequest( - lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_path="/not/a/real/path", - ) - self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=LORA_WARMUP_RANK) - dummy_lora_requests.append(dummy_lora_request) - return dummy_lora_requests - - def _remove_dummy_loras(self): - # Remove dummy loras. - assert self.lora_manager is not None - self.remove_all_loras() - - def _dummy_run(self, - max_num_batched_tokens: int, - max_num_seqs: int = 1) -> None: - with self.set_in_profile_run(): - # Enable top-k sampling to reflect the accurate memory usage. - sampling_params = \ - SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) - - # This represents the maximum number of different requests - # that will have unique loras, and therefore the max amount of - # memory consumption. Create dummy lora request copies from the - # lora request passed in, which contains a lora from the lora - # warmup path. - dummy_lora_requests: List[LoRARequest] = [] - dummy_lora_requests_per_seq: List[LoRARequest] = [] - if self.lora_config: - dummy_lora_requests = self._add_dummy_loras( - self.lora_config.max_loras) - assert len(dummy_lora_requests) == self.lora_config.max_loras - dummy_lora_requests_per_seq = [ - dummy_lora_requests[idx % len(dummy_lora_requests)] - for idx in range(max_num_seqs) - ] - - # Profile memory usage with max_num_sequences sequences and the - # total number of tokens equal to max_num_batched_tokens. - seqs: List[SequenceGroupMetadata] = [] - # Additional GPU memory may be needed for multi-modal encoding, - # which needs to be accounted for when calculating the GPU blocks - # for vLLM blocker manager. - # To exercise the worst scenario for GPU memory consumption, - # the number of seqs (batch_size) is chosen to maximize the number - # of images processed. - - max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( - self.model_config) - if max_mm_tokens > 0: - max_num_seqs_orig = max_num_seqs - max_num_seqs = min(max_num_seqs, - max_num_batched_tokens // max_mm_tokens) - if max_num_seqs < 1: - expr = (f"min({max_num_seqs_orig}, " - f"{max_num_batched_tokens} // {max_mm_tokens})") - logger.warning( - "Computed max_num_seqs (%s) to be less than 1. 
" - "Setting it to the minimum value of 1.", expr) - max_num_seqs = 1 - - batch_size = 0 - for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) - batch_size += seq_len - - dummy_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry) - - seq = SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=True, - seq_data={group_id: dummy_data.seq_data}, - sampling_params=sampling_params, - block_tables=None, - lora_request=dummy_lora_requests_per_seq[group_id] - if dummy_lora_requests_per_seq else None, - multi_modal_data=dummy_data.multi_modal_data, - multi_modal_placeholders=dummy_data. - multi_modal_placeholders, - ) - seqs.append(seq) - - # Run the model with the dummy inputs. - num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - # it is important to create tensors inside the loop, rather than - # multiplying the list, to avoid Dynamo from treating them as - # tensor aliasing. - kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) - for _ in range(num_layers) - ] - finished_requests_ids = [seq.request_id for seq in seqs] - model_input = self.prepare_model_input( - seqs, finished_requests_ids=finished_requests_ids) - intermediate_tensors = None - if not get_pp_group().is_first_rank: - intermediate_tensors = \ - self.model.make_empty_intermediate_tensors( - batch_size=batch_size, - dtype=self.model_config.dtype, - device=self.device) - - # Disable KV Scale Calculation for dummy data during profile run - if model_input.attn_metadata is not None: - model_input.attn_metadata.enable_kv_scales_calculation = False - - self.execute_model(model_input, kv_caches, intermediate_tensors) - torch.cuda.synchronize() - if self.lora_config: - self._remove_dummy_loras() - - return - - def remove_all_loras(self): - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.remove_all_adapters() - - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - - def add_lora(self, lora_request: LoRARequest) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_adapter(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_adapter(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.pin_adapter(lora_id) - - def list_loras(self) -> Set[int]: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.list_adapters() - - @torch.inference_mode() - def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> int: - """Cuda graph capture a model and return cudagraph memory - consumption in bytes. - - Note that CUDA graph's performance gain is negligible if number - of batched tokens are larger than 200. 
-
-    @torch.inference_mode()
-    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> int:
-        """CUDA-graph capture a model and return the cudagraph memory
-        consumption in bytes.
-
-        Note that CUDA graph's performance gain is negligible if the number
-        of batched tokens is larger than 200. And since CUDA graph
-        requires fixed-size tensors, supporting large/variable batch
-        sizes requires high GPU memory overhead. Thus, vLLM only captures
-        decoding requests. Mixed batches (chunked prefill + decoding) or
-        prefill requests are not captured.
-
-        Since it is used for decoding only, it assumes there's only 1 token
-        per sequence in the batch.
-        """
-        assert not self.model_config.enforce_eager
-        logger.info("Capturing cudagraphs for decoding. This may lead to "
-                    "unexpected consequences if the model is not static. To "
-                    "run the model in eager mode, set 'enforce_eager=True' or "
-                    "use '--enforce-eager' in the CLI. "
-                    "If an out-of-memory error occurs during cudagraph "
-                    "capture, consider decreasing `gpu_memory_utilization` or "
-                    "switching to eager mode. You can also reduce the "
-                    "`max_num_seqs` as needed to decrease memory usage.")
-        start_time = time.perf_counter()
-        start_free_gpu_memory = torch.cuda.mem_get_info()[0]
-
-        # Prepare dummy inputs. These will be reused for all batch sizes.
-        max_batch_size = self.max_batchsize_to_capture
-        input_tokens = torch.zeros(max_batch_size,
-                                   dtype=torch.long,
-                                   device=self.device)
-        input_positions = torch.zeros(max_batch_size,
-                                      dtype=torch.long,
-                                      device=self.device)
-        inputs_embeds = torch.zeros(
-            (max_batch_size, self.model_config.get_hidden_size()),
-            dtype=self.model_config.dtype,
-            device=self.device)
-        if self.model_config.uses_mrope:
-            input_positions = torch.tile(input_positions,
-                                         (3, 1)).cuda(device=self.device)
-        # Prepare dummy previous_hidden_states only if needed by the model.
-        # This is used by draft models such as EAGLE.
-        previous_hidden_states = None
-        if "previous_hidden_states" in inspect.signature(
-                self.model.forward).parameters:
-            previous_hidden_states = torch.empty(
-                [max_batch_size,
-                 self.model_config.get_hidden_size()],
-                dtype=self.model_config.dtype,
-                device=self.device)
-
-        intermediate_inputs = None
-        if not get_pp_group().is_first_rank:
-            intermediate_inputs = self.model.make_empty_intermediate_tensors(
-                batch_size=max_batch_size,
-                dtype=self.model_config.dtype,
-                device=self.device)
-
-        dummy_lora_id: Optional[int] = None
-        dummy_lora_request: Optional[LoRARequest] = None
-        if self.lora_config:
-            # The goal is to capture the LoRA kernels in cuda graphs.
-            # For this purpose, a single dummy lora is sufficient.
-            dummy_lora_requests = self._add_dummy_loras(num_loras=1)
-            assert len(dummy_lora_requests) == 1
-            dummy_lora_request = dummy_lora_requests[0]
-            dummy_lora_id = dummy_lora_request.lora_int_id
-
-        with self.attn_state.graph_capture(max_batch_size), graph_capture(
-                self.device) as graph_capture_context:
-            # NOTE: Capturing the largest batch size first may help reduce the
-            # memory usage of CUDA graph.
-            for virtual_engine in range(
-                    self.parallel_config.pipeline_parallel_size):
-                # We need to iterate not only over batch sizes, but also over
-                # whether to use inputs_embeds, hence we use the cartesian
-                # product.
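For intuition, the cartesian product mentioned above simply pairs every capture size with each embeddings mode; a sketch with made-up sizes (not from the deleted file):

    import itertools

    capture_sizes = [8, 4, 2, 1]        # hypothetical, largest first
    use_embeds = (True, False)          # only when prompt embeds are enabled
    cases = list(itertools.product(capture_sizes, use_embeds))
    assert cases[0] == (8, True) and len(cases) == 8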
- cudagraph_capture_sizes = self.vllm_config.compilation_config\ - .cudagraph_capture_sizes - cudagraph_inputs_embeds = (( - True, False) if self.model_config.enable_prompt_embeds else - (False, )) - compilation_cases = itertools.product( - cudagraph_capture_sizes, - cudagraph_inputs_embeds, - ) - # Only rank 0 should print progress bar during capture - if get_tensor_model_parallel_rank() == 0: - compilation_cases = tqdm( - list(compilation_cases), - disable=not self.load_config.use_tqdm_on_load, - desc="Capturing CUDA graph shapes") - for batch_size, use_inputs_embeds in compilation_cases: - attn_metadata = ( - self.attn_state.graph_capture_get_metadata_for_batch( - batch_size, - is_encoder_decoder_model=self.model_config. - is_encoder_decoder)) - # Disable KV Scale Calculation for graph capture - attn_metadata.enable_kv_scales_calculation = False - if self.lora_config: - lora_mapping = LoRAMapping( - **dict(index_mapping=[dummy_lora_id] * batch_size, - prompt_mapping=[dummy_lora_id] * batch_size, - is_prefill=False)) - self.set_active_loras(set([dummy_lora_request]), - lora_mapping) - - graph_runner = CUDAGraphRunner( - self.model, self.attn_backend.get_name(), - self.attn_state.graph_clone(batch_size), - self.model_config.is_encoder_decoder) - - capture_inputs = { - "input_ids": - input_tokens[:batch_size], - "inputs_embeds": - inputs_embeds[:batch_size] - if use_inputs_embeds else None, - "positions": - input_positions[..., :batch_size], - "intermediate_inputs": - intermediate_inputs[:batch_size] - if intermediate_inputs is not None else None, - "kv_caches": - kv_caches[virtual_engine], - "attn_metadata": - attn_metadata, - "memory_pool": - self.graph_memory_pool, - "stream": - graph_capture_context.stream - } - if previous_hidden_states is not None: - capture_inputs[ - "previous_hidden_states"] = previous_hidden_states[: - batch_size] - - if self.has_inner_state: - # Only used by Mamba-based models CUDA graph atm (Jamba) - capture_inputs.update({ - "seqlen_agnostic_capture_inputs": - self.model.get_seqlen_agnostic_capture_inputs( - batch_size) - }) - if self.model_config.is_encoder_decoder: - # add the additional inputs to capture for - # encoder-decoder models. - self._update_inputs_to_capture_for_enc_dec_model( - capture_inputs) - - with set_forward_context(attn_metadata, self.vllm_config, - virtual_engine): - graph_runner.capture(**capture_inputs) - self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[virtual_engine][( - batch_size, use_inputs_embeds)] = graph_runner - - if self.lora_config: - self._remove_dummy_loras() - - end_time = time.perf_counter() - end_free_gpu_memory = torch.cuda.mem_get_info()[0] - elapsed_time = end_time - start_time - cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory - # This usually takes < 10 seconds. - logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, cuda_graph_size / GiB_bytes) - return cuda_graph_size - - def _update_inputs_to_capture_for_enc_dec_model(self, - capture_inputs: Dict[str, - Any]): - """ - Updates the set of input tensors needed for CUDA graph capture in an - encoder-decoder model. - - This method modifies the provided `capture_inputs` dictionary by - adding tensors specific to encoder-decoder specific models that - need to be captured for CUDA Graph replay. - """ - # During the decode phase encoder_input_ids and encoder_positions are - # unset. Do the same thing for graph capture. 
- capture_inputs["encoder_input_ids"] = torch.tensor([], - dtype=torch.long, - device=self.device) - capture_inputs["encoder_positions"] = torch.tensor([], - dtype=torch.long, - device=self.device) - - @property - def vocab_size(self) -> int: - return self.model_config.get_vocab_size() - - -class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): - """ - GPU model runner with sampling step. - """ - _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( - ModelInputForGPUWithSamplingMetadata) - _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder - - def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, Any], - ) -> ModelInputForGPUWithSamplingMetadata: - model_input = \ - ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - ) - return model_input - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None, - ) -> ModelInputForGPUWithSamplingMetadata: - """Prepare the model input based on a given sequence group, including - metadata for the sampling step. - - The API assumes seq_group_metadata_list is sorted by prefill -> decode. - - The result tensors and data structure also batches input in prefill - -> decode order. For example, - - - input_tokens[:num_prefill_tokens] contains prefill tokens. - - input_tokens[num_prefill_tokens:] contains decode tokens. - - If cuda graph is required, this API automatically pads inputs. - """ - model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_requests_ids) - if get_pp_group().is_last_rank: - # Sampling metadata is only required for the final pp group - generators = self.get_generators(finished_requests_ids) - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, model_input.seq_lens, - model_input.query_lens, self.device, self.pin_memory, - generators, self.sampling_metadata_cache) - else: - sampling_metadata = None - is_prompt = (seq_group_metadata_list[0].is_prompt - if seq_group_metadata_list else None) - return dataclasses.replace(model_input, - sampling_metadata=sampling_metadata, - is_prompt=is_prompt, - virtual_engine=virtual_engine) - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForGPUWithSamplingMetadata, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - **kwargs, - ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: - if num_steps > 1: - raise ValueError("num_steps > 1 is not supported in ModelRunner") - - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - - self.attn_state.begin_forward(model_input) - - # Currently cuda graph is only supported by the decode phase. - assert model_input.attn_metadata is not None - prefill_meta = model_input.attn_metadata.prefill_metadata - decode_meta = model_input.attn_metadata.decode_metadata - # TODO(andoorve): We can remove this once all - # virtual engines share the same kv cache. 
- virtual_engine = model_input.virtual_engine - previous_hidden_states = kwargs.get("previous_hidden_states") - if prefill_meta is None and decode_meta.use_cuda_graph: - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - use_inputs_embeds = model_input.inputs_embeds is not None - model_executable = self.graph_runners[virtual_engine][( - graph_batch_size, use_inputs_embeds)] - if previous_hidden_states is not None: - previous_hidden_states = torch.cat([ - previous_hidden_states, - torch.empty([ - graph_batch_size - previous_hidden_states.shape[0], - *previous_hidden_states.shape[1:] - ], - dtype=previous_hidden_states.dtype, - device=previous_hidden_states.device) - ]) - else: - model_executable = self.model - - # Receive KV cache in distributed KV cache transfer setting - # In disagg prefill setting, it will also recv hidden states and bypass - # model forwarding - # In KV cache database setting, it will change the model input so that - # we can skip prefilling on tokens that successfully received KV caches - # NOTE: The receive operation is blocking - bypass_model_exec = False - if self.need_recv_kv(model_input, kv_caches): - hidden_or_intermediate_states, bypass_model_exec, model_input = \ - get_kv_transfer_group().recv_kv_caches_and_hidden_states( - # model is used to know which layer the current worker - # is working on, so that we can receive KV for only those - # layers. - model_executable, - model_input, - kv_caches=kv_caches - ) - - multi_modal_kwargs = model_input.multi_modal_kwargs or {} - seqlen_agnostic_kwargs = { - "finished_requests_ids": model_input.finished_requests_ids, - "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_inner_state else {} - model_kwargs = {} - if previous_hidden_states is not None: - model_kwargs["previous_hidden_states"] = previous_hidden_states - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_start = torch.cuda.Event(enable_timing=True) - model_forward_end = torch.cuda.Event(enable_timing=True) - model_forward_start.record() - - if not bypass_model_exec: - with set_forward_context(model_input.attn_metadata, - self.vllm_config, virtual_engine): - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - inputs_embeds=model_input.inputs_embeds, - positions=model_input.input_positions, - intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs( - multi_modal_kwargs, - device=self.device, - ), - **seqlen_agnostic_kwargs, - **model_kwargs, - ) - - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_end.record() - - # Sending KV cache in distributed KV cache transfer setting - # NOTE: the send operation is non-blocking - if self.need_send_kv(model_input, kv_caches): - get_kv_transfer_group().send_kv_caches_and_hidden_states( - # model_executable is used to know which layer the current - # worker is working on, so that we can send KV for only those - # layers. - model_executable, - model_input, - kv_caches, - hidden_or_intermediate_states, - ) - - # Compute the logits in the last pipeline stage. 
- if not get_pp_group().is_last_rank: - if (self.is_driver_worker - and hidden_or_intermediate_states is not None - and isinstance(hidden_or_intermediate_states, - IntermediateTensors) - and self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - hidden_or_intermediate_states.tensors["model_forward_time"] = ( - torch.tensor(model_forward_time + orig_model_forward_time)) - return hidden_or_intermediate_states - - logits = self.model.compute_logits(hidden_or_intermediate_states, - model_input.sampling_metadata) - - if self.is_driver_worker: - if model_input.async_callback is not None: - model_input.async_callback() - - # Sample the next token. - assert isinstance(self.sampler, Sampler) - orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor - if model_input.inputs_embeds is not None: - self.sampler.include_gpu_probs_tensor = True - - output: SamplerOutput = self.sampler( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time - and output is not None): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - # If there are multiple workers, we are still tracking the - # latency from the start time of the driver worker to the end - # time of the driver worker. The model forward time will then - # end up covering the communication time as well. 
- output.model_forward_time = (orig_model_forward_time + - model_forward_time) - - if model_input.inputs_embeds is not None: - if self.is_driver_worker: - sampled_token_ids = [] - valid_outputs = [] - for sequence_group_output in output.outputs: - if len(sequence_group_output.samples) == 0: - continue - assert len(sequence_group_output.samples) == 1 - valid_outputs.append(sequence_group_output) - sampled_token_ids.append( - sequence_group_output.samples[0].output_token) - sampled_token_ids = torch.tensor(sampled_token_ids).to( - self.device) - sampled_token_ids = broadcast_tensor_dict( - {"sampled_token_ids": - sampled_token_ids})["sampled_token_ids"] - else: - sampled_token_ids = broadcast_tensor_dict( - )["sampled_token_ids"] - if len(sampled_token_ids) > 0: - sampled_token_embeds = \ - self.model.get_input_embeddings(sampled_token_ids) - if self.is_driver_worker: - self.sampler.include_gpu_probs_tensor = \ - orig_include_gpu_probs - for i, sequence_group_output in enumerate(valid_outputs): - sequence_group_output.samples[0].output_embed = \ - sampled_token_embeds[i] - - if not self.is_driver_worker: - return [] - - if self.return_hidden_states: - # we only need to pass hidden states of most recent token - assert model_input.sampling_metadata is not None - indices = model_input.sampling_metadata.selected_token_indices - if model_input.is_prompt: - hidden_states = hidden_or_intermediate_states.index_select( - 0, indices) - output.prefill_hidden_states = hidden_or_intermediate_states - elif decode_meta.use_cuda_graph: - hidden_states = hidden_or_intermediate_states[:len(indices)] - else: - hidden_states = hidden_or_intermediate_states - - output.hidden_states = hidden_states - - return [output] - - def need_recv_kv(self, model_input, kv_caches) -> bool: - """Check if we need to receive kv-cache from the other worker. - We need to receive KV when - 1. current vLLM instance is KV cache consumer/decode vLLM instance - 2. this batch is not a profiling run - 3. this batch is a prefill run - - Args: - model_input: input to the model executable - kv_caches: vLLM's paged memory - """ - - if self.vllm_config.kv_transfer_config is None: - return False - - prefill_meta = model_input.attn_metadata.prefill_metadata - - # check if the current run is profiling - is_profile_run = (kv_caches[0].numel() == 0) - # check if the current run is prefill - is_prefill_run = prefill_meta is not None - - return self.vllm_config.kv_transfer_config.is_kv_consumer and ( - not is_profile_run) and is_prefill_run - - def need_send_kv(self, model_input, kv_caches) -> bool: - """Check if we need to send kv-cache to the other worker. - We need to send KV when - 1. current vLLM instance is KV cache producer/prefill vLLM instance - 2. this batch is not a profiling run - 3. 
this batch is a prefill run - - Args: - model_input: input to the model executable - kv_caches: vLLM's paged memory - """ - - if self.vllm_config.kv_transfer_config is None: - return False - - prefill_meta = model_input.attn_metadata.prefill_metadata - - # check if the current run is profiling - is_profile_run = (kv_caches[0].numel() == 0) - # check if the current run is prefill - is_prefill_run = prefill_meta is not None - - return self.vllm_config.kv_transfer_config.is_kv_producer and ( - not is_profile_run) and is_prefill_run - - -# NOTE: this is nn.Module so the profiler can properly capture/group -# kernels calls made within the graph -class CUDAGraphRunner(nn.Module): - - def __init__(self, model: nn.Module, backend_name: str, - attn_state: AttentionState, is_encoder_decoder_model: bool): - super().__init__() - self.model = model - self.backend_name = backend_name - self.attn_state = attn_state - - self.input_buffers: Dict[str, torch.Tensor] = {} - self.output_buffers: Dict[str, torch.Tensor] = {} - - self._graph: Optional[torch.cuda.CUDAGraph] = None - self._is_encoder_decoder_model = is_encoder_decoder_model - - @property - def graph(self): - assert self._graph is not None - return self._graph - - def capture( - self, - input_ids: torch.Tensor, - inputs_embeds: Optional[torch.Tensor], - positions: torch.Tensor, - intermediate_inputs: Optional[IntermediateTensors], - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - memory_pool: Optional[Tuple[int, int]], - stream: torch.cuda.Stream, - **kwargs, - ): - assert self._graph is None - # Run the model a few times without capturing the graph. - # This is to make sure that the captured graph does not include the - # kernel launches for initial benchmarking (e.g., Triton autotune). - # Note one iteration is not enough for torch.compile - for _ in range(_NUM_WARMUP_ITERS): - self.model( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - positions=positions, - intermediate_tensors=intermediate_inputs, - **kwargs, - ) - # Wait for the warm up operations to finish before proceeding with - # Graph Capture. - torch.cuda.synchronize() - # Capture the graph. - self._graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): - output_hidden_or_intermediate_states = self.model( - input_ids=input_ids, - **({ - "inputs_embeds": inputs_embeds, - } if inputs_embeds is not None else {}), - positions=positions, - intermediate_tensors=intermediate_inputs, - **kwargs, - ) - - if isinstance(output_hidden_or_intermediate_states, torch.Tensor): - hidden_or_intermediate_states = weak_ref_tensor( - output_hidden_or_intermediate_states) - elif isinstance(output_hidden_or_intermediate_states, - IntermediateTensors): - hidden_or_intermediate_states = IntermediateTensors( - tensors={ - key: weak_ref_tensor(value) - for key, value in - output_hidden_or_intermediate_states.tensors.items() - }) - - del output_hidden_or_intermediate_states - # make sure `output_hidden_or_intermediate_states` is deleted - # in the graph's memory pool - gc.collect() - torch.cuda.synchronize() - - # Save the input and output buffers. 
-        self.input_buffers = {
-            "input_ids":
-            input_ids,
-            **({
-                "inputs_embeds": inputs_embeds,
-            } if inputs_embeds is not None else {}),
-            "positions":
-            positions,
-            "kv_caches":
-            kv_caches,
-            **self.attn_state.get_graph_input_buffers(
-                attn_metadata, self._is_encoder_decoder_model),
-            **kwargs,
-        }
-        if intermediate_inputs is not None:
-            self.input_buffers.update(intermediate_inputs.tensors)
-        if get_pp_group().is_last_rank:
-            self.output_buffers = {
-                "hidden_states": hidden_or_intermediate_states
-            }
-        else:
-            self.output_buffers = hidden_or_intermediate_states
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        inputs_embeds: Optional[torch.Tensor],
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors],
-        **kwargs,
-    ) -> torch.Tensor:
-        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
-
-        # Copy the input tensors to the input buffers.
-        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
-        if positions is not None:
-            # In some cases (e.g. MLA), the positions in the metadata are
-            # reused but truncated to the original size, so the shape is not
-            # padded and we need to copy only the valid part.
-            self.input_buffers["positions"][:positions.shape[0]].copy_(
-                positions, non_blocking=True)
-        if inputs_embeds is not None:
-            self.input_buffers["inputs_embeds"][:inputs_embeds.shape[0]].copy_(
-                inputs_embeds, non_blocking=True)
-
-        if self.backend_name != "NO_ATTENTION":
-            self.input_buffers["slot_mapping"].copy_(
-                attn_metadata.slot_mapping, non_blocking=True)
-
-        self.attn_state.prepare_graph_input_buffers(
-            self.input_buffers, attn_metadata, self._is_encoder_decoder_model)
-
-        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
-            self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
-                                                      **kwargs)
-
-        if "previous_hidden_states" in self.input_buffers:
-            self.input_buffers["previous_hidden_states"].copy_(
-                kwargs["previous_hidden_states"], non_blocking=True)
-
-        if intermediate_tensors is not None:
-            for key in intermediate_tensors.tensors:
-                if key != "model_execute_time" and key != "model_forward_time":
-                    self.input_buffers[key].copy_(intermediate_tensors[key],
-                                                  non_blocking=True)
-        if self._is_encoder_decoder_model:
-            self.input_buffers["encoder_input_ids"].copy_(
-                kwargs['encoder_input_ids'], non_blocking=True)
-            self.input_buffers["encoder_positions"].copy_(
-                kwargs['encoder_positions'], non_blocking=True)
-
-        # Run the graph.
-        self.graph.replay()
-        # Return the output tensor.
- if get_pp_group().is_last_rank: - return self.output_buffers["hidden_states"] - - return self.output_buffers diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py deleted file mode 100644 index 1008b743619a4..0000000000000 --- a/vllm/worker/model_runner_base.py +++ /dev/null @@ -1,307 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -from abc import ABC, abstractmethod -from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, - TypeVar) - -import torch -import torch.nn as nn - -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.models.interfaces import supports_transcription -from vllm.model_executor.models.interfaces_base import is_text_generation_model -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.tasks import GenerationTask, SupportedTask - -if TYPE_CHECKING: - from vllm.attention import AttentionMetadata - from vllm.attention.backends.abstract import AttentionBackend - from vllm.model_executor import SamplingMetadata - -logger = init_logger(__name__) - -T = TypeVar('T', bound="BroadcastableModelInput") - - -def _add_attn_metadata_broadcastable_dict( - tensor_dict: Dict[str, Any], - attn_metadata: Optional["AttentionMetadata"]) -> None: - """ - Helper method to update tensor_dict with broadcastable - AttentionMetadata fields. - """ - if attn_metadata is not None: - tensor_dict.update(attn_metadata.asdict_zerocopy()) - - -def _init_attn_metadata_from_tensor_dict( - attn_backend: "AttentionBackend", - tensor_dict: Dict[str, Any], -) -> Dict[str, Any]: - """ - Helper method to initialize AttentionMetadata based on an - AttentionBackend and broadcastable AttentionMetadata fields. - """ - # Extract the fields used to create AttentionMetadata. - valid_attn_kwargs = {} - for field in dataclasses.fields(attn_backend.get_metadata_cls()): - if field.name in tensor_dict: - if field.name == "input_positions": - valid_attn_kwargs[field.name] = tensor_dict[field.name] - else: - valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) - - attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) - tensor_dict["attn_metadata"] = attn_metadata - return tensor_dict - - -def _init_sampling_metadata_from_tensor_dict( # type: ignore - tensor_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Helper method to initialize SamplingMetadata based on broadcastable - SamplingMetadata fields. - """ - from vllm.model_executor import SamplingMetadata - - selected_token_indices = tensor_dict.pop("selected_token_indices", None) - # An empty SamplingMetadata to signal that the worker should skip - # sampling. - if selected_token_indices is not None: - tensor_dict["sampling_metadata"] = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) - return tensor_dict - - -def _add_sampling_metadata_broadcastable_dict( - tensor_dict: Dict[str, Any], - sampling_metadata: Optional["SamplingMetadata"]) -> None: - """ - Helper method to update tensor_dict with broadcastable - SamplingMetadata fields. 
- """ - if sampling_metadata is not None: - tensor_dict["selected_token_indices"] = ( - sampling_metadata.selected_token_indices) - - -def _init_frozen_model_input_from_tensor_dict( - frozen_model_input_cls: Type["ModelRunnerInputBase"], - tensor_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Helper method to initialize a frozen ModelInput based on broadcastable - """ - valid_tensor_kwargs = {} - for field in dataclasses.fields(frozen_model_input_cls): - val = tensor_dict.pop(field.name, None) - if val is not None: - valid_tensor_kwargs[field.name] = val - - frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs) - tensor_dict["frozen_model_input"] = frozen_model_input - return tensor_dict - - -class BroadcastableModelInput(ABC): - - @abstractmethod - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - """ - Extract broadcastable fields. Override for fields that require some - custom deserialization. - """ - raise NotImplementedError - - @classmethod - @abstractmethod - def from_broadcasted_tensor_dict( - cls: Type[T], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> T: - """ - Pop fields from the given tensor_dict and populate a new instance of - BroadcastableModelInput. - """ - raise NotImplementedError - - -@dataclasses.dataclass(frozen=True) -class ModelRunnerInputBase(BroadcastableModelInput): - """Local inputs to each worker's model runner. May contain - device-specific data. Different worker backends may have different methods - of converting from the global ExecuteModelRequest produced by the LLM - engine to the worker-local ModelRunnerInputBase objects. - - Model runners that support multi-GPU execution should define a - ModelRunnerInputBase subclass, add their required fields, and specify how to - serialize/deserialize a ModelInput for broadcast between workers. - """ - pass - - -class ModelRunnerInputBuilderBase(ABC, Generic[T]): - """A builder to create ModelRunnerInputBase objects. - """ - - @abstractmethod - def prepare(self, - finished_requests_ids: Optional[List[str]] = None) -> None: - raise NotImplementedError - - @abstractmethod - def add_seq_group(self, seq_group_metadata): - """TBA""" - raise NotImplementedError - - @abstractmethod - def build(self, *args, **kwargs) -> T: - """Build metadata with on-device tensors.""" - raise NotImplementedError - - -class ModelRunnerBase(ABC, Generic[T]): - """ - Model runner interface that abstracts a particular hardware and/or type of - model. Model execution may communicate data with model runners in other - processes, but it should not include control plane metadata communication. - - Each ModelRunnerBase subclass should define a corresponding - ModelRunnerInputBase subclass. 
-    """
-
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-    ) -> None:
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
-        self.cache_config = vllm_config.cache_config
-        self.lora_config = vllm_config.lora_config
-        self.load_config = vllm_config.load_config
-        self.parallel_config = vllm_config.parallel_config
-        self.scheduler_config = vllm_config.scheduler_config
-        self.device_config = vllm_config.device_config
-        self.speculative_config = vllm_config.speculative_config
-        self.observability_config = vllm_config.observability_config
-
-        # Map of request_id -> generator used for seeded random sampling
-        self.generators: Dict[str, torch.Generator] = {}
-
-    @abstractmethod
-    def make_model_input_from_broadcasted_tensor_dict(
-        self,
-        tensor_dict: Dict[str, Any],
-    ) -> T:
-        """
-        Make an instance of a ModelRunnerInputBase from the broadcasted tensor
-        dict.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def prepare_model_input(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        virtual_engine: int = 0,
-        finished_requests_ids: Optional[List[str]] = None,
-    ) -> T:
-        """
-        Prepare the inputs to ModelRunnerBase.execute_model from an execution
-        request. This method may move data to the worker's local device. It is
-        not allowed to communicate with other workers or devices.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_model(self) -> nn.Module:
-        raise NotImplementedError
-
-    def get_supported_generation_tasks(self) -> list[GenerationTask]:
-        model = self.get_model()
-        supported_tasks = list[GenerationTask]()
-
-        if is_text_generation_model(model):
-            supported_tasks.append("generate")
-
-        if supports_transcription(model):
-            if model.supports_transcription_only:
-                return ["transcription"]
-
-            supported_tasks.append("transcription")
-
-        return supported_tasks
-
-    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
-        tasks = list[SupportedTask]()
-
-        if self.model_config.runner_type == "generate":
-            tasks.extend(self.get_supported_generation_tasks())
-
-        return tuple(tasks)
-
-    def execute_model(
-        self,
-        model_input: T,
-        kv_caches: Optional[List[torch.Tensor]],
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        num_steps: int = 1,
-        **kwargs,
-    ) -> Optional[List[SamplerOutput]]:
-        """
-        Execute the model on the given input.
-        """
-        raise NotImplementedError
-
-    def get_generators(self, finished_request_ids: Optional[List[str]] = None):
-        """
-        Return the dict of per-request generators used for seeded random
-        sampling.
-        """
-
-        # Clean up generators from completed requests
-        if finished_request_ids:
-            for request_id in finished_request_ids:
-                self.generators.pop(request_id, None)
-
-        return self.generators
-
-
-class ModelRunnerWrapperBase:
-    """
-    The whole point of this class is to lazily initialize the model_runner.
-    """
-
-    def __init__(
-        self,
-        model_runner: ModelRunnerBase,
-    ) -> None:
-        self.model_runner: ModelRunnerBase = model_runner
-
-    def __getattr__(self, attr):
-        return getattr(self.model_runner, attr)
-
-
-class InputProcessingError(Exception):
-    """This exception is raised when an error occurs while preparing the
-    inputs for a single sequence group.
-    This allows the engine to gracefully handle errors with a single sequence
-    group without having to fail the entire batch.
- """ - - def __init__(self, request_id, message): - """request_id is the id of the offending sequence group""" - self.request_id = request_id - self.message = message - super().__init__(self.message) - - def __str__(self): - return "Failed to prepare inputs for sequence group with request id: " \ - f"{self.request_id}, Error: {self.message}" diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py deleted file mode 100644 index 12047bc390737..0000000000000 --- a/vllm/worker/worker.py +++ /dev/null @@ -1,666 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A GPU worker class.""" -import gc -import os -from contextlib import nullcontext -from typing import Dict, List, Optional, Set, Tuple, Type, Union - -import torch -import torch.distributed - -import vllm.envs as envs -from vllm.attention.layer import Attention -from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.device_allocator.cumem import CuMemAllocator -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce) -from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor import set_random_seed -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.platforms import current_platform -from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, - SequenceGroupMetadata, SequenceGroupMetadataDelta) -from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache, - memory_profiling) -from vllm.worker.cache_engine import CacheEngine -from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, - WorkerInput) - -logger = init_logger(__name__) - - -class Worker(LocalOrDistributedWorkerBase): - """A worker class that executes (a partition of) the model on a GPU. - - Each worker is associated with a single GPU. The worker is responsible for - maintaining the KV cache and executing the model on the GPU. In case of - distributed inference, each worker is assigned a partition of the model. 
- """ - - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False, - model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, - ) -> None: - WorkerBase.__init__(self, vllm_config) - self.parallel_config.rank = rank - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - self.is_driver_worker = is_driver_worker - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - # Return hidden states from target model if the draft model is an - # mlp_speculator - speculative_config = self.speculative_config - model_config = self.model_config - speculative_args = {} if speculative_config is None \ - or (speculative_config.draft_model_config.hf_config.model_type == - model_config.hf_config.model_type) \ - or (speculative_config.draft_model_config.hf_config.model_type - not in ("medusa", - "mlp_speculator", - "eagle", - "deepseek_mtp", - "glm4_moe_mtp", - "mimo_mtp", - "ernie_mtp", - "qwen3_next_mtp")) \ - else {"return_hidden_states": True} - - self.model_runner: GPUModelRunnerBase = ModelRunner( - vllm_config=self.vllm_config, - kv_cache_dtype=self.cache_config.cache_dtype, - is_driver_worker=is_driver_worker, - **speculative_args, - ) - if model_runner_cls is not None: - self.model_runner = model_runner_cls(self.model_runner) - - # Uninitialized cache engine. Will be initialized by - # initialize_cache. - self.cache_engine: List[CacheEngine] - self.gpu_cache: Optional[List[List[torch.Tensor]]] = None - self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} - - # Buffers saved before sleep - self._sleep_saved_buffers: Dict[str, torch.Tensor] = {} - - # Torch profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - with_stack=True, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) - else: - self.profiler = None - - def start_profile(self): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.start() - - def stop_profile(self): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.stop() - # only print profiler results on rank 0 - if self.local_rank == 0: - print(self.profiler.key_averages().table( - sort_by="self_cuda_time_total")) - - def sleep(self, level: int = 1) -> None: - free_bytes_before_sleep = torch.cuda.mem_get_info()[0] - - # Save the buffers before level 2 sleep - if level == 2: - model = self.model_runner.model - self._sleep_saved_buffers = { - name: buffer.cpu().clone() - for name, buffer in model.named_buffers() - } - - allocator = CuMemAllocator.get_instance() - allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) - free_bytes_after_sleep, total = torch.cuda.mem_get_info() - freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep - used_bytes = total - free_bytes_after_sleep - assert freed_bytes >= 0, "Memory usage increased after sleeping." 
-        logger.info(
-            "Sleep mode freed %.2f GiB memory, "
-            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
-            used_bytes / GiB_bytes)
-
-    def wake_up(self, tags: Optional[list[str]] = None) -> None:
-        allocator = CuMemAllocator.get_instance()
-        allocator.wake_up(tags=tags)
-
-        # Restore the buffers after level 2 sleep
-        if len(self._sleep_saved_buffers):
-            model = self.model_runner.model
-            for name, buffer in model.named_buffers():
-                if name in self._sleep_saved_buffers:
-                    buffer.data.copy_(self._sleep_saved_buffers[name].data)
-            self._sleep_saved_buffers = {}
-
-    def init_device(self) -> None:
-        if self.device_config.device.type == "cuda":
-            # torch.distributed.all_reduce does not free the input tensor until
-            # the synchronization point. This causes the memory usage to grow
-            # as the number of all_reduce calls increases. This env var disables
-            # this behavior.
-            # Related issue:
-            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
-            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-            # This env var set by Ray causes exceptions with graph building.
-            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-            self.device = torch.device(f"cuda:{self.local_rank}")
-            torch.cuda.set_device(self.device)
-
-            _check_if_gpu_supports_dtype(self.model_config.dtype)
-            gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
-            self.baseline_snapshot = MemorySnapshot()
-        else:
-            raise RuntimeError(
-                f"Unsupported device type: {self.device_config.device}")
-        # Initialize the distributed environment.
-        init_worker_distributed_environment(self.vllm_config, self.rank,
-                                            self.distributed_init_method,
-                                            self.local_rank)
-        # Set random seed.
-        set_random_seed(self.model_config.seed)
-
-    def load_model(self):
-        if self.vllm_config.model_config.enable_sleep_mode:
-            allocator = CuMemAllocator.get_instance()
-            assert allocator.get_current_usage() == 0, (
-                "Sleep mode can only be "
-                "used for one instance per process.")
-            context = allocator.use_memory_pool(tag="weights")
-        else:
-            context = nullcontext()
-        with context:
-            self.model_runner.load_model()
-
-    def save_sharded_state(
-        self,
-        path: str,
-        pattern: Optional[str] = None,
-        max_size: Optional[int] = None,
-    ) -> None:
-        self.model_runner.save_sharded_state(
-            path,
-            pattern=pattern,
-            max_size=max_size,
-        )
-
-    def save_tensorized_model(
-        self,
-        tensorizer_config: TensorizerConfig,
-    ) -> None:
-        self.model_runner.save_tensorized_model(
-            tensorizer_config=tensorizer_config, )
-
-    @torch.inference_mode()
-    def determine_available_kv_cache_memory(self,
-                                            total_gpu_memory: int) -> float:
-        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
-            # still need a profile run which compiles the model for
-            # max_num_batched_tokens
-            self.model_runner.profile_run()
-
-            GiB = lambda b: b / GiB_bytes
-            msg = (
-                f"Initial free memory "
-                f"{GiB(self.baseline_snapshot.free_memory):.2f} "
-                f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for "
-                "KV Cache as specified by kv_cache_memory_bytes config and "
-                "skipped memory profiling. This does not respect the "
-                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
-                "config when you want manual control of KV cache memory "
-                "size. 
If OOM'ed, check the difference of initial free " - "memory between the current run and the previous run " - "where kv_cache_memory_bytes is suggested and update it " - "correspondingly.") - logger.info(msg) - return self.cache_config.kv_cache_memory_bytes - - # Execute a forward pass with dummy inputs to profile the memory usage - # of the model. - with memory_profiling( - self.baseline_snapshot, - weights_memory=self.model_runner.model_memory_usage) as result: - self.model_runner.profile_run() - - self.non_torch_memory = result.non_torch_increase - self.peak_activation_memory = result.torch_peak_increase - - self._assert_memory_footprint_increased_during_profiling() - - self.requested_memory = total_gpu_memory * \ - self.cache_config.gpu_memory_utilization - - self.available_kv_cache_memory = (self.requested_memory - - result.non_kv_cache_memory) - - msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n" - "the current vLLM instance can use " - "total_gpu_memory " - f"({(total_gpu_memory / GiB_bytes):.2f}GiB)" - " x gpu_memory_utilization " - f"({self.cache_config.gpu_memory_utilization:.2f})" - f" = {(self.requested_memory / GiB_bytes):.2f}GiB\n" - "model weights take " - f"{(result.weights_memory / GiB_bytes):.2f}GiB;" - " non_torch_memory takes " - f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;" - " PyTorch activation peak memory takes " - f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;" - " the rest of the memory reserved for KV Cache is " - f"{(self.available_kv_cache_memory / GiB_bytes):.2f}GiB.") - - logger.info(msg) - return self.available_kv_cache_memory - - @torch.inference_mode() - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Profiles the peak memory usage of the model to determine how many - KV blocks may be allocated without OOMs. - - The engine will first conduct a profiling of the existing memory usage. - Then, it calculates the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - - Tip: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - """ - # Profile the memory usage of the model and get the maximum number of - # cache blocks that can be allocated with the remaining free memory. - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - - free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info() - available_kv_cache_memory = self.determine_available_kv_cache_memory( - total_gpu_memory) - - # Calculate the number of blocks that can be allocated with the - # profiled peak memory. - cache_block_size = self.get_cache_block_size_bytes() - if cache_block_size == 0: - num_gpu_blocks = 0 - num_cpu_blocks = 0 - else: - num_gpu_blocks = int(available_kv_cache_memory // cache_block_size) - num_cpu_blocks = int(self.cache_config.swap_space_bytes // - cache_block_size) - num_gpu_blocks = max(num_gpu_blocks, 0) - num_cpu_blocks = max(num_cpu_blocks, 0) - - # Final cleanup - gc.collect() - - return num_gpu_blocks, num_cpu_blocks - - def _assert_memory_footprint_increased_during_profiling(self): - # NOTE(woosuk): Here we assume that the other processes using the same - # GPU did not change their memory usage during the profiling. - free_gpu_memory, total = torch.cuda.mem_get_info() - cuda_memory = total - free_gpu_memory - assert self.baseline_snapshot.cuda_memory < cuda_memory, ( - "Error in memory profiling. " - f"Initial used memory {self.baseline_snapshot.cuda_memory}, " - f"currently used memory {cuda_memory}. 
" - f"This happens when the GPU memory was " - "not properly cleaned up before initializing the vLLM instance.") - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Allocate GPU and CPU KV cache with the specified number of blocks. - - This also warms up the model, which may record CUDA graphs. - """ - raise_if_cache_size_invalid( - num_gpu_blocks, self.cache_config.block_size, - self.cache_config.is_attention_free, - self.model_config.max_model_len, - self.parallel_config.pipeline_parallel_size) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - if self.vllm_config.model_config.enable_sleep_mode: - allocator = CuMemAllocator.get_instance() - context = allocator.use_memory_pool(tag="kv_cache") - else: - context = nullcontext() - with context: - self._init_cache_engine() - self._warm_up_model() - - def _init_cache_engine(self): - assert self.cache_config.num_gpu_blocks is not None - self.cache_engine = [ - CacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config) - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - self.gpu_cache = [ - self.cache_engine[ve].gpu_cache - for ve in range(self.parallel_config.pipeline_parallel_size) - ] - - # Layer pairings for cross-layer KV sharing. - # If an Attention layer `layer_name` is in the keys of this dict, it - # means this layer will perform attention using the keys and values - # from the KV cache of `shared_kv_cache_layers[layer_name]`. - shared_kv_cache_layers: dict[str, str] = {} - - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) - - for layer_name, attn_module in attn_layers.items(): - if (kv_tgt_layer := - attn_module.kv_sharing_target_layer_name) is not None: - # The layer doesn't need its own KV cache and will use that of - # the target layer. We skip creating a KVCacheSpec for it, so - # that KV cache management logic will act as this layer does - # not exist, and doesn't allocate KV cache for the layer. This - # enables the memory saving of cross-layer kv sharing, allowing - # a given amount of memory to accommodate longer context lengths - # or enable more requests to be processed simultaneously. - shared_kv_cache_layers[layer_name] = kv_tgt_layer - - bind_kv_cache(self.compilation_config.static_forward_context, - self.gpu_cache, shared_kv_cache_layers) - - def _warm_up_model(self) -> None: - # warm up sizes that are not in cudagraph capture sizes, - # but users still want to compile for better performance, - # e.g. for the max-num-batched token size in chunked prefill. - warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() - if not self.model_config.enforce_eager: - warmup_sizes = [ - x for x in warmup_sizes if x not in - self.vllm_config.compilation_config.cudagraph_capture_sizes - ] - for size in sorted(warmup_sizes, reverse=True): - logger.info("Compile and warming up model for size %d", size) - self.model_runner._dummy_run(size) - - cuda_graph_memory_bytes = 0 - if not self.model_config.enforce_eager: - cuda_graph_memory_bytes = self.model_runner.capture_model( - self.gpu_cache) - - if (self.cache_config.kv_cache_memory_bytes is None - and hasattr(self, "peak_activation_memory")): - # Suggests optimal kv cache memory size if we rely on - # memory_profiling to guess the kv cache memory size which - # provides peak_activation_memory and a few other memory - # consumption. 
`memory_profiling` does not consider
-            # CUDAGraph memory size and may not utilize all gpu memory.
-            # Users may want fine-grained control to specify kv cache
-            # memory size.
-            GiB = lambda b: round(b / GiB_bytes, 2)
-            non_kv_cache_memory = (self.model_runner.model_memory_usage +
-                                   self.peak_activation_memory +
-                                   self.non_torch_memory +
-                                   cuda_graph_memory_bytes)
-
-            # empirically observed that the memory profiling may
-            # slightly underestimate the memory consumption.
-            # So leave a small buffer (=150MiB) to avoid OOM.
-            redundancy_buffer_memory = 150 * (1 << 20)
-            kv_cache_memory_bytes_to_gpu_limit = (
-                self.baseline_snapshot.free_memory - non_kv_cache_memory -
-                redundancy_buffer_memory)
-            kv_cache_memory_bytes_to_requested_limit = (
-                int(self.requested_memory) - non_kv_cache_memory -
-                redundancy_buffer_memory)
-
-            msg = (
-                f"Free memory on device "
-                f"({GiB(self.baseline_snapshot.free_memory)}/"
-                f"{GiB(self.baseline_snapshot.total_memory)} GiB) on startup. "
-                f"Desired GPU memory utilization is "
-                f"({self.cache_config.gpu_memory_utilization}, "
-                f"{GiB(self.requested_memory)} GiB). "
-                f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
-                f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
-                f"for peak activation, {GiB(self.non_torch_memory)} GiB "
-                f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
-                f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
-                f"config with `--kv-cache-memory="
-                f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
-                f"requested memory, or `--kv-cache-memory="
-                f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
-                f"utilize gpu memory. Current kv cache memory in use is "
-                f"{int(self.available_kv_cache_memory)} bytes.")
-            logger.info(msg)
-
-        # Reset the seed to ensure that the random state is not affected by
-        # the model initialization and profiling.
-        set_random_seed(self.model_config.seed)
-
-    @property
-    def do_metadata_broadcast(self) -> bool:
-        return self.parallel_config.tensor_parallel_size > 1
-
-    @property
-    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
-        return self.gpu_cache
-
-    @torch.inference_mode()
-    def prepare_worker_input(
-            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
-        virtual_engine = execute_model_req.virtual_engine
-        num_steps = execute_model_req.num_steps
-        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
-        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
-        # They contain parameters to launch cudaMemcpyAsync.
-        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
-                                         device="cpu",
-                                         dtype=torch.int64).view(-1, 2)
-        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
-                                          device="cpu",
-                                          dtype=torch.int64).view(-1, 2)
-        # `blocks_to_copy` is a gpu tensor. The src and tgt of
-        # blocks to copy are on the same device, and `blocks_to_copy`
-        # can be used directly within cuda kernels.
-        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
-                                      device=self.device,
-                                      dtype=torch.int64).view(-1, 2)
-
-        return WorkerInput(
-            num_seq_groups=num_seq_groups,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-            virtual_engine=virtual_engine,
-            num_steps=num_steps,
-        )
-
-    @torch.inference_mode()
-    def execute_worker(self, worker_input: WorkerInput) -> None:
-        virtual_engine = worker_input.virtual_engine
-        # Issue cache operations. 
-        if (worker_input.blocks_to_swap_in is not None
-                and worker_input.blocks_to_swap_in.numel() > 0):
-            self.cache_engine[virtual_engine].swap_in(
-                worker_input.blocks_to_swap_in)
-        if (worker_input.blocks_to_swap_out is not None
-                and worker_input.blocks_to_swap_out.numel() > 0):
-            self.cache_engine[virtual_engine].swap_out(
-                worker_input.blocks_to_swap_out)
-        if (worker_input.blocks_to_copy is not None
-                and worker_input.blocks_to_copy.numel() > 0):
-            self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
-
-    def _get_cached_seq_group_metadata(
-            self,
-            seq_group_metadata_list: List[Union[SequenceGroupMetadata,
-                                                SequenceGroupMetadataDelta]],
-            finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
-        """Return a list of cached Sequence Group Metadata after updating their
-        state.
-
-        It is used because the scheduler only sends deltas to workers to
-        reduce the data payload size. The function also cleans up the cache
-        based on the given `finished_request_ids`.
-        """
-        new_seq_group_metadata_list = []
-        for metadata_or_delta in seq_group_metadata_list:
-            request_id = metadata_or_delta.request_id
-            if request_id not in self._seq_group_metadata_cache:
-                # The first prefill.
-                assert isinstance(metadata_or_delta, SequenceGroupMetadata)
-                self._seq_group_metadata_cache[request_id] = metadata_or_delta
-            else:
-                # The first prefill is already cached.
-                if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
-                    self._seq_group_metadata_cache[request_id].apply_delta(
-                        metadata_or_delta)
-                else:
-                    # If a metadata snapshot is sent again, the request was
-                    # preempted. Reset the cache because we need to start
-                    # from scratch.
-                    assert isinstance(metadata_or_delta, SequenceGroupMetadata)
-                    self._seq_group_metadata_cache[
-                        request_id] = metadata_or_delta
-
-            new_seq_group_metadata_list.append(
-                self._seq_group_metadata_cache[request_id])
-
-        # Clean up finished ids
-        for finished_id in finished_request_ids:
-            del self._seq_group_metadata_cache[finished_id]
-
-        return new_seq_group_metadata_list
-
-    def _execute_model_spmd(
-        self,
-        execute_model_req: ExecuteModelRequest,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Optional[List[SamplerOutput]]:
-        if execute_model_req is not None:
-            new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
-                execute_model_req.seq_group_metadata_list,
-                execute_model_req.finished_requests_ids)
-
-            execute_model_req.seq_group_metadata_list = (
-                new_seq_group_metadata_list)
-        output = super()._execute_model_spmd(execute_model_req,
-                                             intermediate_tensors)
-        return output
-
-    def add_lora(self, lora_request: LoRARequest) -> bool:
-        return self.model_runner.add_lora(lora_request)
-
-    def remove_lora(self, lora_id: int) -> bool:
-        return self.model_runner.remove_lora(lora_id)
-
-    def pin_lora(self, lora_id: int) -> bool:
-        return self.model_runner.pin_lora(lora_id)
-
-    def list_loras(self) -> Set[int]:
-        return self.model_runner.list_loras()
-
-    @property
-    def max_model_len(self) -> int:
-        return self.model_config.max_model_len
-
-    @property
-    def vocab_size(self) -> int:
-        return self.model_runner.vocab_size
-
-    def get_cache_block_size_bytes(self) -> int:
-        """Get the size of a KV cache block in bytes. 
- """ - return CacheEngine.get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) - - -def init_worker_distributed_environment( - vllm_config: VllmConfig, - rank: int, - distributed_init_method: Optional[str] = None, - local_rank: int = -1, -) -> None: - """Initialize the distributed environment.""" - parallel_config = vllm_config.parallel_config - set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) - - init_distributed_environment(parallel_config.world_size, rank, - distributed_init_method, local_rank, - current_platform.dist_backend) - ensure_model_parallel_initialized( - parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.decode_context_parallel_size) - - ensure_kv_transfer_initialized(vllm_config) - - -def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): - # Check if the GPU supports the dtype. - if torch_dtype == torch.bfloat16: # noqa: SIM102 - if not current_platform.has_device_capability(80): - capability = current_platform.get_device_capability() - gpu_name = current_platform.get_device_name() - - if capability is None: - compute_str = "does not have a compute capability" - else: - version_str = capability.as_version_str() - compute_str = f"has compute capability {version_str}" - - raise ValueError( - "Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " - "You can use float16 instead by explicitly setting the " - "`dtype` flag in CLI, for example: --dtype=half.") - - -def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, - max_model_len, pipeline_parallel_size) -> None: - if is_attention_free and num_gpu_blocks != 0: - raise ValueError("No memory should be allocated for the cache blocks " - f"for an attention-free model, but {num_gpu_blocks} " - "blocks are allocated.") - if not is_attention_free and num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size) - if not is_attention_free and max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). 
Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index d0a56f6ff4637..20fabef4f19b9 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,31 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import dataclasses import os -import time -from abc import abstractmethod -from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type, - TypeVar, Union) +from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, + Union) import cloudpickle -import torch import torch.nn as nn -from vllm.config import (ObservabilityConfig, VllmConfig, - set_current_vllm_config) -from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group +from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.sequence import ExecuteModelRequest from vllm.utils import (enable_trace_function_call_for_thread, resolve_obj_by_qualname, run_method, update_environment_variables, warn_for_unimplemented_methods) -from vllm.worker.model_runner_base import (BroadcastableModelInput, - ModelRunnerBase, - ModelRunnerInputBase) +from vllm.v1.outputs import SamplerOutput logger = init_logger(__name__) @@ -141,356 +132,6 @@ class WorkerBase: return -class DelegateWorkerBase(WorkerBase): - """ - A class that delegates all methods to another WorkerBase instance. This is - useful for creating a WorkerBase that wraps another WorkerBase instance, - e.g. speculative decoding. - """ - worker: WorkerBase - - def __init__( - self, - *args, - **kwargs, - ) -> None: - vllm_config: VllmConfig = kwargs.get("vllm_config") - cls = resolve_obj_by_qualname(vllm_config.parallel_config.worker_cls) - self.worker = cls(*args, **kwargs) - - def init_device(self) -> None: - self.worker.init_device() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - return self.worker.determine_num_available_blocks() - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - - def load_model(self) -> None: - """Load model onto target device.""" - self.worker.load_model() - - def get_model(self) -> nn.Module: - return self.worker.get_model() - - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[List[SamplerOutput]]: - return self.worker.execute_model(execute_model_req) - - def get_cache_block_size_bytes(self) -> int: - return self.worker.get_cache_block_size_bytes() - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.worker.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.worker.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - return self.worker.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.worker.list_loras() - - def __getattr__(self, attr): - return getattr(self.worker, attr) - - -class LoRANotSupportedWorkerBase(WorkerBase): - """Partial implementation of WorkerBase that raises exceptions when LoRA - methods are invoked. 
- """ - - def add_lora(self, lora_request: LoRARequest) -> bool: - raise ValueError(f"{type(self)} does not support LoRA") - - def remove_lora(self, lora_id: int) -> bool: - raise ValueError(f"{type(self)} does not support LoRA") - - def pin_lora(self, lora_id: int) -> bool: - raise ValueError(f"{type(self)} does not support LoRA") - - def list_loras(self) -> Set[int]: - raise ValueError(f"{type(self)} does not support LoRA") - - -@dataclasses.dataclass(frozen=True) -class WorkerInput: - """Local inputs to each worker. May contain device-specific data. These - fields should be broadcastable to other workers. - """ - - num_seq_groups: Optional[int] = None - blocks_to_swap_in: Optional[torch.Tensor] = None - blocks_to_swap_out: Optional[torch.Tensor] = None - blocks_to_copy: Optional[torch.Tensor] = None - virtual_engine: int = 0 - num_steps: int = 1 - - @classmethod - def from_broadcasted_tensor_dict( - cls: Type["WorkerInput"], - tensor_dict: Dict[str, Any], - ) -> "WorkerInput": - """ - Pop fields from the given tensor_dict and populate a new instance of - WorkerInput. - """ - return cls( - num_seq_groups=tensor_dict.pop("num_seq_groups"), - blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), - blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), - blocks_to_copy=tensor_dict.pop("blocks_to_copy"), - virtual_engine=tensor_dict["virtual_engine"], - num_steps=tensor_dict.pop("num_steps"), - ) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - """ - Extract broadcastable fields. - """ - tensor_dict = { - "num_seq_groups": self.num_seq_groups, - "blocks_to_swap_in": self.blocks_to_swap_in, - "blocks_to_swap_out": self.blocks_to_swap_out, - "blocks_to_copy": self.blocks_to_copy, - "virtual_engine": self.virtual_engine, - "num_steps": self.num_steps, - } - - return tensor_dict - - -class LocalOrDistributedWorkerBase(WorkerBase): - """ - Partial implementation of WorkerBase that has a default `execute_model` - definition to perform metadata transfer between workers when in distributed - mode. Subclasses of this interface should use model runners that inherit - from ModelRunnerBase, and should only need to implement worker-local logic. - If custom control plane logic is needed to transfer metadata, or if the - model runner cannot inherit from ModelRunnerBase, use WorkerBase instead. - """ - is_driver_worker: bool - model_runner: ModelRunnerBase - observability_config: Optional[ObservabilityConfig] = None - - @property - @abstractmethod - def do_metadata_broadcast(self) -> bool: - """ - Used by the default `execute_model` to check whether broadcast is - needed to transfer request inputs from the driver worker to other - workers in the TP group. If WorkerBase subclass only supports - single-worker execution, then this method should return False. - """ - raise NotImplementedError - - @property - @abstractmethod - def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: - """ - Gets the list of kv caches to pass to the worker's model runner. Each - element in the list is a kv cache corresponding to a particular virtual - engine (PP stream). Used by the default `execute_model`. If the worker's - model runner does not follow the ModelRunnerBase interface, then inherit - from WorkerBase instead. - """ - raise NotImplementedError - - @abstractmethod - def prepare_worker_input( - self, execute_model_req: ExecuteModelRequest) -> WorkerInput: - """ - Prepare the inputs to WorkerBase.execute_worker from an execution - request. 
This method may move data to the worker's local device. It is
-        not allowed to communicate with other workers or devices.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def execute_worker(self, worker_input: WorkerInput) -> None:
-        """
-        Process an execution request.
-        """
-        raise NotImplementedError
-
-    def _get_worker_input_from_broadcast(
-        self
-    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
-            str, torch.Tensor]]]:
-        """ Get the worker input from the broadcasted tensor dict. """
-        assert self.do_metadata_broadcast
-        assert not self.is_driver_worker
-        broadcast_data = broadcast_tensor_dict(src=0)
-        if not broadcast_data:
-            return None
-
-        worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
-        model_input = (
-            self.model_runner.make_model_input_from_broadcasted_tensor_dict(
-                broadcast_data))
-
-        kwargs = extract_previous_hidden_states(broadcast_data)
-
-        return model_input, worker_input, kwargs
-
-    def _get_driver_input_and_broadcast(
-        self, execute_model_req: ExecuteModelRequest
-    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
-        """ Get the driver input and broadcast it to other workers. """
-        assert self.is_driver_worker
-
-        worker_input: WorkerInput = self.prepare_worker_input(
-            execute_model_req=execute_model_req)
-        model_input: ModelRunnerInputBase = (
-            self.model_runner.prepare_model_input(
-                execute_model_req.seq_group_metadata_list,
-                execute_model_req.virtual_engine,
-                execute_model_req.finished_requests_ids))
-
-        kwargs = extract_previous_hidden_states(execute_model_req)
-
-        if self.do_metadata_broadcast:
-            broadcast_data = worker_input.as_broadcastable_tensor_dict()
-            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
-            broadcast_data.update(kwargs)
-            broadcast_tensor_dict(broadcast_data, src=0)
-
-        if execute_model_req.async_callback:
-            model_input = dataclasses.replace(  # type: ignore
-                model_input,
-                async_callback=execute_model_req.async_callback)
-
-        return model_input, worker_input, kwargs
-
-    def prepare_input(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
-            str, torch.Tensor]]]:
-        """
-        Prepare the inputs to ModelRunner and workers.
-        """
-        if self.is_driver_worker:
-            if execute_model_req is None:
-                if self.do_metadata_broadcast:
-                    # This signals that there are no more requests to process
-                    # for now. All workers are running an infinite loop with
-                    # broadcast_tensor_dict, and it stops the loop when the
-                    # driver broadcasts an empty input. Send an empty input to
-                    # notify all other workers to stop their execution loop.
-                    broadcast_tensor_dict({}, src=0)
-                return None
-            return self._get_driver_input_and_broadcast(execute_model_req)
-        else:
-            return self._get_worker_input_from_broadcast()
-
-    def get_model(self) -> nn.Module:
-        return self.model_runner.get_model()
-
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None,
-    ) -> Optional[List[SamplerOutput]]:
-        """Executes at least one model step on the given sequences, unless no
-        sequences are provided."""
-        start_time = time.perf_counter()
-
-        inputs = self.prepare_input(execute_model_req)
-        if inputs is None:
-            return None
-
-        model_input, worker_input, kwargs = inputs
-        num_steps = worker_input.num_steps
-
-        self.execute_worker(worker_input)
-
-        # If there is no input, we don't need to execute the model. 
- if worker_input.num_seq_groups == 0: - return [] - - intermediate_tensors = None - orig_model_execute_time = 0.0 - if not get_pp_group().is_first_rank: - intermediate_tensors = IntermediateTensors( - get_pp_group().recv_tensor_dict( - all_gather_group=get_tp_group())) - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time): - orig_model_execute_time = intermediate_tensors.tensors.get( - "model_execute_time", torch.tensor(0)).item() - - output = self.model_runner.execute_model( - model_input=model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - intermediate_tensors=intermediate_tensors, - num_steps=num_steps, - **kwargs, - ) - - model_execute_time = time.perf_counter() - start_time - if not get_pp_group().is_last_rank: - # output is IntermediateTensors - assert isinstance(output, IntermediateTensors) - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time): - output.tensors["model_execute_time"] = torch.tensor( - model_execute_time + orig_model_execute_time) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) - return [None] - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time - and output is not None): - for o in output: - o.model_execute_time = (orig_model_execute_time + - model_execute_time) - - # output is List[SamplerOutput] - return output - - def _execute_model_spmd( - self, - execute_model_req: ExecuteModelRequest, - intermediate_tensors: Optional[IntermediateTensors] = None - ) -> Optional[List[SamplerOutput]]: - """ - Execute model in Single Program Multiple Data (SPMD) fashion. - All workers take the same request, prepare the input and - execute the model. - """ - assert execute_model_req is not None, ( - "_execute_model_spmd() requires each worker to take in an " - "ExecuteModelRequest") - worker_input: WorkerInput = self.prepare_worker_input( - execute_model_req=execute_model_req) - model_input: ModelRunnerInputBase = ( - self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list)) - - self.execute_worker(worker_input) - - # If there is no input, we don't need to execute the model. - if worker_input.num_seq_groups == 0: - return [] - - kwargs = extract_previous_hidden_states(execute_model_req) - - return self.model_runner.execute_model( - model_input=model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - intermediate_tensors=intermediate_tensors, - **kwargs, - ) - - class WorkerWrapperBase: """ This class represents one process in an executor/engine. It is responsible @@ -636,23 +277,3 @@ class WorkerWrapperBase: def __getattr__(self, attr): return getattr(self.worker, attr) - - -def extract_previous_hidden_states( - data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \ - Dict[str, torch.Tensor]: - """If data contains previous_hidden_states, extract it. This returns a dict - which can be used directly as additional kwargs in any following - execute_model calls. This is used in draft models like EAGLE.""" - output = {} - - # When called from non-driver worker, data is dict but when called from - # driver worker, data is ExecuteModelRequest. 
- if isinstance(data, dict): - if "previous_hidden_states" in data: - output["previous_hidden_states"] = data["previous_hidden_states"] - elif data.previous_hidden_states is not None: - output["previous_hidden_states"] = data.previous_hidden_states\ - .hidden_states - - return output
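
The deleted `extract_previous_hidden_states` helper above accepted either the broadcast tensor dict (on non-driver workers) or an `ExecuteModelRequest` (on the driver) and normalized both into kwargs for `execute_model`, which EAGLE-style draft models relied on. For reviewers tracing the removed V0 path, here is a minimal standalone sketch of that contract; the `_Request` and `_HiddenStates` classes are hypothetical stand-ins, not vLLM types, and only model the attributes the helper touched:

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union


@dataclass
class _HiddenStates:
    # Hypothetical stand-in for the hidden-states container on the request.
    hidden_states: Any


@dataclass
class _Request:
    # Hypothetical stand-in for ExecuteModelRequest; only the attribute the
    # helper read is modeled here.
    previous_hidden_states: Optional[_HiddenStates] = None


def extract_previous_hidden_states(
        data: Union[_Request, Dict[str, Any]]) -> Dict[str, Any]:
    """Mirror of the removed helper: normalize both input shapes into
    kwargs consumed by execute_model."""
    output: Dict[str, Any] = {}
    if isinstance(data, dict):
        # Non-driver workers receive the already-broadcast tensor dict.
        if "previous_hidden_states" in data:
            output["previous_hidden_states"] = data["previous_hidden_states"]
    elif data.previous_hidden_states is not None:
        # The driver worker holds the full request object.
        output["previous_hidden_states"] = (
            data.previous_hidden_states.hidden_states)
    return output


# Both call sites yield the same kwargs shape:
assert extract_previous_hidden_states({"previous_hidden_states": 1}) == {
    "previous_hidden_states": 1
}
assert extract_previous_hidden_states(_Request(_HiddenStates(1))) == {
    "previous_hidden_states": 1
}
```

This V0-only normalization is dropped along with the rest of the V0 worker path removed in this diff.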