diff --git a/requirements-common.txt b/requirements-common.txt index b0e599a5e5af..b6bed8a73d8c 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -21,6 +21,7 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 typing_extensions >= 4.10 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 pyzmq +msgspec librosa # Required for audio processing soundfile # Required for audio processing gguf == 0.9.1 diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 7aed0d5e1fa6..7c62de9fa9e3 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -8,6 +8,7 @@ pytest tests/basic_correctness/test_preemption.py`. import pytest from prometheus_client import REGISTRY +import vllm.envs as envs from vllm import SamplingParams from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, ENABLE_ARTIFICIAL_PREEMPT) @@ -24,6 +25,13 @@ assert ENABLE_ARTIFICIAL_PREEMPT is True, ( "tests/basic_correctness/test_preemption.py`") +@pytest.fixture +def worker_use_ray() -> bool: + # When SPMD worker is used, use ray_use_worker=True + # to test delta input optimization works with preemption. + return envs.VLLM_USE_RAY_SPMD_WORKER + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [96]) @@ -36,6 +44,7 @@ def test_chunked_prefill_recompute( dtype: str, max_tokens: int, chunked_prefill_token_size: int, + worker_use_ray: bool, ) -> None: """Ensure that chunked prefill works with preemption.""" max_num_seqs = min(chunked_prefill_token_size, 256) @@ -54,6 +63,7 @@ def test_chunked_prefill_recompute( max_num_batched_tokens=max_num_batched_tokens, enable_chunked_prefill=enable_chunked_prefill, max_num_seqs=max_num_seqs, + worker_use_ray=worker_use_ray, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt @@ -79,6 +89,7 @@ def test_preemption( model: str, dtype: str, max_tokens: int, + worker_use_ray: bool, ) -> None: """By default, recompute preemption is enabled""" @@ -89,6 +100,7 @@ def test_preemption( model, dtype=dtype, disable_log_stats=False, + worker_use_ray=worker_use_ray, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt @@ -132,6 +144,7 @@ def test_swap( dtype: str, max_tokens: int, beam_width: int, + worker_use_ray: bool, ) -> None: """Use beam search enables swapping.""" example_prompts = example_prompts[:1] @@ -144,6 +157,7 @@ def test_swap( dtype=dtype, swap_space=10, disable_log_stats=False, + worker_use_ray=worker_use_ray, ) as vllm_model: vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, max_tokens) @@ -188,6 +202,7 @@ def test_swap_infeasible( dtype: str, max_tokens: int, beam_width: int, + worker_use_ray: bool, ) -> None: """Verify infeasible swap request will be ignored.""" BLOCK_SIZE = 16 @@ -204,6 +219,7 @@ def test_swap_infeasible( # decode blocks are not enough to finish. num_gpu_blocks_override=prefill_blocks + decode_blocks, max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE, + worker_use_ray=worker_use_ray, ) as vllm_model: sampling_params = SamplingParams(n=beam_width, use_beam_search=True, @@ -230,6 +246,7 @@ def test_preemption_infeasible( model: str, dtype: str, max_tokens: int, + worker_use_ray: bool, ) -> None: """Verify infeasible preemption request will be ignored.""" BLOCK_SIZE = 16 @@ -244,6 +261,7 @@ def test_preemption_infeasible( # ignored instead of hanging forever. num_gpu_blocks_override=prefill_blocks + decode_blocks // 2, max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE), + worker_use_ray=worker_use_ray, ) as vllm_model: sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py new file mode 100644 index 000000000000..d604e5250a3f --- /dev/null +++ b/tests/core/test_serialization.py @@ -0,0 +1,33 @@ +import msgspec + +from vllm.executor.msgspec_utils import decode_hook, encode_hook +from vllm.sequence import ExecuteModelRequest + +from ..spec_decode.utils import create_batch + + +def test_msgspec_serialization(): + num_lookahead_slots = 4 + seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=num_lookahead_slots, + running_queue_size=4) + + encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) + decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, + dec_hook=decode_hook) + req = decoder.decode(encoder.encode(execute_model_req)) + expected = execute_model_req.seq_group_metadata_list + actual = req.seq_group_metadata_list + assert (len(expected) == len(actual)) + expected = expected[0] + actual = actual[0] + + assert expected.block_tables == actual.block_tables + assert expected.is_prompt == actual.is_prompt + assert expected.request_id == actual.request_id + assert (expected.seq_data[0].prompt_token_ids == + actual.seq_data[0].prompt_token_ids) + assert (expected.seq_data[0].output_token_ids == + actual.seq_data[0].output_token_ids) diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 1de2ebab22db..e254686f269b 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -22,7 +22,8 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") @pytest.mark.skipif(cuda_device_count_stateless() < 2, reason="Need at least 2 GPUs to run the test.") @pytest.mark.parametrize( - "model, distributed_executor_backend, attention_backend, test_suite", [ + "model, distributed_executor_backend, attention_backend, " + "test_suite", [ ("facebook/opt-125m", "ray", "", "L4"), ("facebook/opt-125m", "mp", "", "L4"), ("meta-llama/Llama-2-7b-hf", "ray", "", "L4"), diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py index 10921a3852f8..262845f19822 100644 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -6,6 +6,8 @@ pytest test_chunked_prefill_distributed.py ``` """ +import os + import pytest from vllm.utils import cuda_device_count_stateless @@ -30,6 +32,11 @@ def test_models( model: str, distributed_executor_backend: str, ) -> None: + if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray": # noqa + assert distributed_executor_backend == "ray" + # test ray adag + os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1" + os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1" dtype = "half" max_tokens = 5 diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 74e7486e8012..820fb554888f 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -1,5 +1,6 @@ import itertools import random +from array import array from typing import Dict, List, Optional, Tuple from unittest.mock import Mock, patch @@ -10,7 +11,8 @@ from transformers import GenerationConfig, GenerationMixin from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, + SequenceData, SequenceGroupMetadata) from vllm.utils import Counter, is_pin_memory_available @@ -56,7 +58,9 @@ def _do_sample( SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, + seq_data={ + 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) + }, sampling_params=sampling_params, block_tables={0: [1]}, )) @@ -201,7 +205,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): def create_sequence_data(num_input=3, num_generated=0): seq_data = SequenceData( - random.choices(range(0, VOCAB_SIZE), k=num_input)) + array(VLLM_TOKEN_ID_ARRAY_TYPE, + random.choices(range(0, VOCAB_SIZE), k=num_input))) if num_generated > 0: seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE), k=num_generated) @@ -504,7 +509,9 @@ def test_sampler_mixed(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, + seq_data={ + 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) + }, sampling_params=sampling_params, block_tables={0: [1]}, )) @@ -600,7 +607,9 @@ def test_sampler_top_k_top_p(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, + seq_data={ + 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) + }, sampling_params=SamplingParams( temperature=1, top_k=top_k, @@ -650,7 +659,11 @@ def test_sampler_repetition_penalty_mixed(device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, + seq_data={ + 0: + SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, + [1, 2, 3])) + }, sampling_params=sampling_params[i], block_tables={0: [1]}, )) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 30eb99f868bf..60b36a33d907 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -1,3 +1,4 @@ +from array import array from itertools import count from typing import Callable, Dict, List, Optional from typing import Sequence as GenericSequence @@ -9,7 +10,8 @@ import torch from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.utils import set_random_seed from vllm.sampling_params import SamplingParams -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, + CompletionSequenceGroupOutput, Logprob, SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceOutput) from vllm.utils import get_distributed_init_method, get_ip, get_open_port @@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts( seq_data={ i: SequenceData( - prompt_token_ids=prompt_token_ids[:], - output_token_ids=cont_token_ids[:], + array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]), + _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, + cont_token_ids[:]), ), }, sampling_params=SamplingParams(temperature=0.0, ), diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 7d4af963e25c..1ce49a50688a 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -1,4 +1,5 @@ import random +from array import array from typing import Tuple from unittest.mock import patch @@ -8,7 +9,8 @@ import torch from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, + SequenceData, SequenceGroupMetadata) from vllm.utils import is_pin_memory_available @@ -69,7 +71,9 @@ def test_logits_processors(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, + seq_data={ + 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) + }, sampling_params=SamplingParams(temperature=0, logits_processors=[pick_ith]), block_tables={0: [1]}, diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 3136402518b9..1ae349e808e0 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,6 +1,9 @@ +from array import array + import pytest -from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput, +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, + CompletionSequenceGroupOutput, SamplerOutput, SequenceData, SequenceOutput) from .core.utils import create_dummy_prompt @@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs): def test_sequence_data_prefill(): - seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4]) + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4])) assert seq_data.get_num_uncomputed_tokens() == 4 assert seq_data.get_num_computed_tokens() == 0 # advance by 2 diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 8a2e9b81580f..32bff22f66a8 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -1,10 +1,12 @@ +from array import array from typing import List import pytest import torch from vllm.engine.arg_utils import EngineArgs -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, + SequenceData, SequenceGroupMetadata) from vllm.utils import is_cpu from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner @@ -125,10 +127,12 @@ def test_prepare_prompt( # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(list(range(seq_len))) + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, + range(seq_len))) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_lens.append(encoder_seq_len) - encoder_seq_data = SequenceData(list(range(encoder_seq_len))) + encoder_seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -319,10 +323,12 @@ def test_prepare_decode( # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(list(range(seq_len))) + seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_lens.append(encoder_seq_len) - encoder_seq_data = SequenceData(list(range(encoder_seq_len))) + encoder_seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 84502043cbd2..a20aa37bcc1e 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,3 +1,4 @@ +from array import array from typing import List import pytest @@ -7,7 +8,8 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, + SequenceData, SequenceGroupMetadata) from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @@ -46,7 +48,8 @@ def test_prepare_prompt(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(list(range(seq_len))) + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, + range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -163,7 +166,8 @@ def test_prepare_decode_cuda_graph(batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) - seq_data = SequenceData(list(range(context_len))) + seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))) seq_data.update_num_computed_tokens(context_len) # Append one token ID since prefill is finished. seq_data.append_token_id(1, 0) @@ -324,7 +328,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(list(range(seq_len))) + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, + range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -340,7 +345,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 - prompt_toks = list(range(context_len)) + prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)) seq_data = SequenceData(prompt_toks) seq_data.append_token_id(1, 0) seq_data.update_num_computed_tokens(context_len) diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py index f98adeba1c70..2bb17fdc0110 100644 --- a/vllm/adapter_commons/request.py +++ b/vllm/adapter_commons/request.py @@ -1,8 +1,6 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass -@dataclass class AdapterRequest(ABC): """ Base class for adapter requests. diff --git a/vllm/config.py b/vllm/config.py index beb77f2bd905..a5a9984a0114 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -770,8 +770,8 @@ class ParallelConfig: self.tokenizer_pool_config = tokenizer_pool_config self.ray_workers_use_nsight = ray_workers_use_nsight self.placement_group = placement_group - self.world_size = pipeline_parallel_size * self.tensor_parallel_size + if worker_use_ray: if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" @@ -867,6 +867,11 @@ class SchedulerConfig: swapping. However, when the sequence group has multiple sequences (e.g., beam search), recomputation is not currently supported. In such a case, we use swapping instead. + send_delta_data: Private API. If used, scheduler sends delta data to + workers instead of an entire data. It should be enabled only + when SPMD worker architecture is enabled. I.e., + VLLM_USE_RAY_SPMD_WORKER=1 + """ def __init__(self, @@ -879,7 +884,8 @@ class SchedulerConfig: enable_chunked_prefill: bool = False, embedding_mode: Optional[bool] = False, preemption_mode: Optional[str] = None, - num_scheduler_steps: int = 1) -> None: + num_scheduler_steps: int = 1, + send_delta_data: bool = False) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens else: @@ -909,6 +915,7 @@ class SchedulerConfig: self.embedding_mode = embedding_mode self.preemption_mode = preemption_mode self.num_scheduler_steps = num_scheduler_steps + self.send_delta_data = send_delta_data self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 287de6014967..802359d2283f 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -12,7 +12,8 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadata, SequenceStatus) + SequenceGroupMetadata, SequenceGroupMetadataDelta, + SequenceStatus) from vllm.utils import PyObjectCache logger = init_logger(__name__) @@ -363,8 +364,6 @@ class Scheduler: self.num_cumulative_preemption: int = 0 # Used to cache python objects - self._seq_group_metadata_cache: PyObjectCache = PyObjectCache( - seq_group_metadata_builder) self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache( scheduler_running_outputs_builder) self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache( @@ -1048,15 +1047,10 @@ class Scheduler: token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.maybe_set_first_scheduled_time(now) - seq_group_metadata = self._seq_group_metadata_cache.get_object() - seq_group_metadata.seq_data.clear() - seq_group_metadata.block_tables.clear() - # seq_id -> SequenceData - seq_data: Dict[int, SequenceData] = seq_group_metadata.seq_data + seq_data: Dict[int, SequenceData] = {} # seq_id -> physical block numbers - block_tables: Dict[int, - List[int]] = seq_group_metadata.block_tables + block_tables: Dict[int, List[int]] = {} if seq_group.is_encoder_decoder(): # Encoder associated with SequenceGroup @@ -1081,45 +1075,65 @@ class Scheduler: seq_group.get_seqs(status=SequenceStatus.RUNNING))) do_sample = True - if seq_group.is_prefill(): + is_prompt = seq_group.is_prefill() + # We should send the metadata to workers when the first prefill + # is sent. Subsequent requests could be chunked prefill or decode. + is_first_prefill = False + if is_prompt: seqs = seq_group.get_seqs() # Prefill has only 1 sequence. assert len(seqs) == 1 + num_computed_tokens = seqs[0].data.get_num_computed_tokens() + is_first_prefill = num_computed_tokens == 0 # In the next iteration, all prompt tokens are not computed. # It means the prefill is chunked, and we don't need sampling. # NOTE: We use get_len instead of get_prompt_len because when # a sequence is preempted, prefill includes previous generated # output tokens. - if (token_chunk_size + seqs[0].data.get_num_computed_tokens() < + if (token_chunk_size + num_computed_tokens < seqs[0].data.get_len()): do_sample = False # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. - is_prompt = seq_group.is_prefill() - - seq_group_metadata.__init__( - request_id=seq_group.request_id, - is_prompt=is_prompt, - seq_data=seq_data, - sampling_params=seq_group.sampling_params, - block_tables=block_tables, - do_sample=do_sample, - pooling_params=seq_group.pooling_params, - token_chunk_size=token_chunk_size, - lora_request=seq_group.lora_request, - computed_block_nums=common_computed_block_nums, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table, - state=seq_group.state, - # `multi_modal_data` will only be present for the 1st comm - # between engine and worker. - # the subsequent comms can still use delta, but - # `multi_modal_data` will be None. - multi_modal_data=seq_group.multi_modal_data - if scheduler_outputs.num_prefill_groups > 0 else None, - prompt_adapter_request=seq_group.prompt_adapter_request, - ) + if is_first_prefill or not self.scheduler_config.send_delta_data: + seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group.request_id, + is_prompt=is_prompt, + seq_data=seq_data, + sampling_params=seq_group.sampling_params, + block_tables=block_tables, + do_sample=do_sample, + pooling_params=seq_group.pooling_params, + token_chunk_size=token_chunk_size, + lora_request=seq_group.lora_request, + computed_block_nums=common_computed_block_nums, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, + state=seq_group.state, + # `multi_modal_data` will only be present for the 1st comm + # between engine and worker. + # the subsequent comms can still use delta, but + # `multi_modal_data` will be None. + multi_modal_data=seq_group.multi_modal_data + if scheduler_outputs.num_prefill_groups > 0 else None, + prompt_adapter_request=seq_group.prompt_adapter_request, + ) + else: + # When SPMD mode is enabled, we only send delta data except for + # the first request to reduce serialization cost. + seq_data_delta = {} + for id, data in seq_data.items(): + seq_data_delta[id] = data.get_delta_and_reset() + seq_group_metadata = SequenceGroupMetadataDelta( + seq_data_delta, + seq_group.request_id, + block_tables, + is_prompt, + do_sample=do_sample, + token_chunk_size=token_chunk_size, + computed_block_nums=common_computed_block_nums, + ) seq_group_metadata_list.append(seq_group_metadata) # Now that the batch has been created, we can assume all blocks in the @@ -1130,8 +1144,6 @@ class Scheduler: self.block_manager.mark_blocks_as_computed( scheduled_seq_group.seq_group) - self._seq_group_metadata_cache.reset() - scheduler_time = time.perf_counter() - scheduler_start_time # Add this to scheduler time to all the sequences that are currently # running. This will help estimate if the scheduler is a significant diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cd1aeb904ff3..8fca2cc04995 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type, Union) +import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, @@ -905,6 +906,8 @@ class EngineArgs: embedding_mode=model_config.embedding_mode, preemption_mode=self.preemption_mode, num_scheduler_steps=self.num_scheduler_steps, + send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER + and parallel_config.use_ray), ) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 021f4f248430..fcf45a38b942 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -224,7 +224,6 @@ class LLMEngine: cache_config.enable_prefix_caching, ) # TODO(woosuk): Print more configs in debug mode. - from vllm.plugins import load_general_plugins load_general_plugins() diff --git a/vllm/executor/msgspec_utils.py b/vllm/executor/msgspec_utils.py new file mode 100644 index 000000000000..c467115f124c --- /dev/null +++ b/vllm/executor/msgspec_utils.py @@ -0,0 +1,27 @@ +from array import array +from typing import Any, Type + +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE + + +def encode_hook(obj: Any) -> Any: + """Custom msgspec enc hook that supports array types. + + See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder + """ + if isinstance(obj, array): + assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, ( + f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " + f"Given array has a type code of {obj.typecode}.") + return obj.tobytes() + + +def decode_hook(type: Type, obj: Any) -> Any: + """Custom msgspec dec hook that supports array types. + + See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder + """ + if type is array: + deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE) + deserialized.frombytes(obj) + return deserialized diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index fa3646012dd6..3a08ab4dbfd4 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -4,9 +4,12 @@ from collections import defaultdict from itertools import islice, repeat from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +import msgspec + import vllm.envs as envs from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) +from vllm.executor.msgspec_utils import encode_hook from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.logger import init_logger from vllm.sequence import ExecuteModelRequest, SamplerOutput @@ -60,6 +63,10 @@ class RayGPUExecutor(DistributedGPUExecutor): # Create the parallel GPU workers. self._init_workers_ray(placement_group) + self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) + self.output_decoder = msgspec.msgpack.Decoder( + Optional[List[SamplerOutput]]) + def shutdown(self) -> None: if hasattr(self, "forward_dag") and self.forward_dag is not None: self.forward_dag.teardown() @@ -123,6 +130,7 @@ class RayGPUExecutor(DistributedGPUExecutor): ray_remote_kwargs) logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) + # Create the workers. driver_ip = get_ip() worker_wrapper_kwargs = self._get_worker_wrapper_args() @@ -304,8 +312,10 @@ class RayGPUExecutor(DistributedGPUExecutor): if self.forward_dag is None: self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) - outputs = ray.get(self.forward_dag.execute(execute_model_req)) - return outputs[0] + serialized_data = self.input_encoder.encode(execute_model_req) + outputs = ray.get(self.forward_dag.execute(serialized_data)) + output = self.output_decoder.decode(outputs[0]) + return output def _run_workers( self, @@ -475,9 +485,10 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): if self.forward_dag is None: self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) - dag_future = await self.forward_dag.execute_async(execute_model_req) + serialized_data = self.input_encoder.encode(execute_model_req) + dag_future = await self.forward_dag.execute_async(serialized_data) outputs = await dag_future - return outputs[0] + return self.output_decoder.decode(outputs[0]) async def _driver_execute_model_async( self, diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index ab283467d478..ffc94d07ed39 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -1,6 +1,9 @@ from typing import List, Optional, Tuple, Union +import msgspec + from vllm.config import ParallelConfig +from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors @@ -24,6 +27,10 @@ try: # that thread. self.compiled_dag_cuda_device_set = False + self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, + dec_hook=decode_hook) + self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) + def get_node_ip(self) -> str: return get_ip() @@ -33,16 +40,26 @@ try: return node_id, gpu_ids def execute_model_spmd( - self, req_or_tuple: Union[ExecuteModelRequest, - Tuple[ExecuteModelRequest, - IntermediateTensors]]): + self, req_or_tuple: Union[bytes, + Tuple[bytes, + Optional[IntermediateTensors]]] + ) -> bytes: """Execute model in SPMD fashion: used only when SPMD worker and compiled DAG are both enabled. Args: - req_or_tuple: The request to execute the model, or a tuple - containing the request and intermediate tensors. + req_or_tuple: A request or a tuple containing the + request and intermediate tensors. Intermediate tensors are + None unless if it is provided because it is > 0 pipeline + stage. The request is serialized by msgspec. """ + if isinstance(req_or_tuple, bytes): + serialized_req, intermediate_tensors = req_or_tuple, None + else: + serialized_req, intermediate_tensors = req_or_tuple + + execute_model_req = self.input_decoder.decode(serialized_req) + # TODO(swang): This is needed right now because Ray aDAG executes # on a background thread, so we need to reset torch's current # device. @@ -51,16 +68,14 @@ try: torch.cuda.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True - if isinstance(req_or_tuple, tuple): - execute_model_req, intermediate_tensors = req_or_tuple - else: - execute_model_req = req_or_tuple - intermediate_tensors = None - output = self.worker._execute_model_spmd(execute_model_req, intermediate_tensors) + # Pipeline model request and output to the next pipeline stage. if isinstance(output, IntermediateTensors): - return execute_model_req, output + output = serialized_req, output + else: + output = self.output_encoder.encode(output) + return output ray_import_err = None diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 28ce0ef86e79..deb66f0b0cb3 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,4 +1,5 @@ import functools +from array import array from collections import UserDict from dataclasses import dataclass from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol, @@ -21,6 +22,10 @@ logger = init_logger(__name__) C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig) +# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE. +# We cannot import it here because of circular dependencies. +VLLM_TOKEN_ID_ARRAY_TYPE = "l" + @dataclass(frozen=True) class InputContext: @@ -118,7 +123,8 @@ class InputRegistry: # Avoid circular import from vllm.sequence import SequenceData - dummy_seq_data = SequenceData([0] * seq_len) + dummy_seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len) dummy_multi_modal_data = None return dummy_seq_data, dummy_multi_modal_data diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 5d791424fbe6..d770da4f2407 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -1,12 +1,15 @@ import warnings -from dataclasses import dataclass, field from typing import Optional +import msgspec + from vllm.adapter_commons.request import AdapterRequest -@dataclass -class LoRARequest(AdapterRequest): +class LoRARequest( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] """ Request for a LoRA adapter. @@ -18,16 +21,17 @@ class LoRARequest(AdapterRequest): lora_int_id must be globally unique for a given adapter. This is currently not enforced in vLLM. """ + __metaclass__ = AdapterRequest lora_name: str lora_int_id: int lora_path: str = "" - lora_local_path: Optional[str] = field(default=None, repr=False) + lora_local_path: Optional[str] = msgspec.field(default=None) long_lora_max_len: Optional[int] = None __hash__ = AdapterRequest.__hash__ def __post_init__(self): - if 'lora_local_path' in self.__dict__: + if 'lora_local_path' in self.__struct_fields__: warnings.warn( "The 'lora_local_path' attribute is deprecated " "and will be removed in a future version. " diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index a6fd5f58b3cb..69e777152e3d 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -1,5 +1,6 @@ """Minimal implementation of BlipVisionModel intended to be only used within a vision language model.""" +from array import array from typing import Optional, Union import torch @@ -16,7 +17,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal.image import (cached_get_tokenizer, repeat_and_pad_image_tokens) -from vllm.sequence import SequenceData +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: @@ -53,8 +54,10 @@ def dummy_seq_data_for_blip( else: image_feature_size = image_feature_size_override - token_ids = [image_token_id] * image_feature_size - token_ids += [0] * (seq_len - image_feature_size) + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [image_token_id]) * image_feature_size + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - image_feature_size) return SequenceData(token_ids) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 386dfeb5bb1e..8cfd3c267256 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,3 +1,4 @@ +from array import array from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -17,7 +18,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.opt import OPTModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SamplerOutput, SequenceData) from .blip import (BlipVisionModel, dummy_image_for_blip, get_max_blip_image_tokens) @@ -427,8 +429,10 @@ def dummy_seq_data_for_blip2( else: image_feature_size = image_feature_size_override - token_ids = [image_token_id] * image_feature_size * num_images - token_ids += [0] * (seq_len - image_feature_size * num_images) + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [image_token_id]) * image_feature_size * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - image_feature_size * num_images) return SequenceData(token_ids) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 6776b93d126b..788d22db9d5a 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -1,3 +1,4 @@ +from array import array from functools import cached_property from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict) @@ -31,7 +32,8 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import (cached_get_tokenizer, repeat_and_pad_image_tokens) -from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SamplerOutput, SequenceData) from vllm.utils import print_warning_once from .interfaces import SupportsMultiModal @@ -70,8 +72,10 @@ def dummy_seq_data_for_chameleon( else: image_feature_size = image_feature_size_override - token_ids = [image_token_id] * image_feature_size * num_images - token_ids += [0] * (seq_len - image_feature_size * num_images) + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [image_token_id]) * image_feature_size * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - image_feature_size * num_images) return SequenceData(token_ids) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index fcd360ce8fd7..24eeefdfccf0 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,5 +1,6 @@ """Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" +from array import array from typing import Iterable, Optional, Tuple import torch @@ -17,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.image import (cached_get_tokenizer, repeat_and_pad_image_tokens) -from vllm.sequence import SequenceData +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: @@ -53,8 +54,10 @@ def dummy_seq_data_for_clip( else: image_feature_size = image_feature_size_override - token_ids = [image_token_id] * image_feature_size * num_images - token_ids += [0] * (seq_len - image_feature_size * num_images) + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [image_token_id]) * image_feature_size * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - image_feature_size * num_images) return SequenceData(token_ids) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index e8184e466c5b..2ef23819b69a 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -16,6 +16,7 @@ # limitations under the License. """ PyTorch Fuyu model.""" import math +from array import array from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict import torch @@ -37,7 +38,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.image import (cached_get_image_processor, cached_get_tokenizer) -from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SamplerOutput, SequenceData) from .interfaces import SupportsMultiModal from .utils import merge_multimodal_embeddings @@ -97,9 +99,12 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int): ncol, nrow = get_max_fuyu_image_feature_size() image_feature_size = get_max_fuyu_image_tokens(ctx) - image_token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow - token_ids = image_token_ids * num_images - token_ids += [0] * (seq_len - image_feature_size * num_images) + image_token_ids = ( + array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol + + array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - image_feature_size * num_images) return SequenceData(token_ids) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index ef2323398abd..729bd27c334d 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -23,6 +23,7 @@ """Inference-only MiniCPM-V model compatible with HuggingFace weights.""" import math import re +from array import array from functools import partial from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple, TypedDict, Union) @@ -55,7 +56,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import (cached_get_image_processor, cached_get_tokenizer) -from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SamplerOutput, SequenceData) from .idefics2_vision_model import Idefics2VisionTransformer @@ -408,7 +410,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext): def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): - token_ids = [0] * seq_len + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len return SequenceData(token_ids) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 4df8c0b54201..426af7fee954 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -2,6 +2,7 @@ within a vision language model.""" import math +from array import array from typing import Iterable, Optional, Tuple import torch @@ -25,7 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.image import (cached_get_tokenizer, repeat_and_pad_image_tokens) -from vllm.sequence import SequenceData +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: @@ -62,8 +63,10 @@ def dummy_seq_data_for_siglip( else: image_feature_size = image_feature_size_override - token_ids = [image_token_id] * image_feature_size * num_images - token_ids += [0] * (seq_len - image_feature_size * num_images) + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [image_token_id]) * image_feature_size + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - image_feature_size) return SequenceData(token_ids) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 94b4b1441682..a085779bc61a 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple import torch from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SequenceData, SequenceGroupMetadata +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, + SequenceGroupMetadata) from vllm.triton_utils.sample import get_num_triton_sampler_splits from vllm.utils import (PyObjectCache, async_tensor_h2d, is_pin_memory_available, make_tensor_with_pad, @@ -505,9 +506,11 @@ class SamplingTensors: and sampling_params.prompt_logprobs is not None): prefill_len = len(seq_group.prompt_logprob_indices) prompt_tokens.extend( - array('l') for _ in range(prefill_len)) + array(VLLM_TOKEN_ID_ARRAY_TYPE) + for _ in range(prefill_len)) output_tokens.extend( - array('l') for _ in range(prefill_len)) + array(VLLM_TOKEN_ID_ARRAY_TYPE) + for _ in range(prefill_len)) if seq_group.do_sample: for seq_id in seq_ids: seq_data = seq_group.seq_data[seq_id] diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 3b95d73ddc2c..7461fb51989c 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -1,15 +1,18 @@ from typing import Any, Optional +import msgspec -class PoolingParams: + +class PoolingParams( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] """Pooling parameters for pooling. Attributes: additional_data: Any additional data needed for pooling. """ - - def __init__(self, additional_data: Optional[Any] = None): - self.additional_data = additional_data + additional_data: Optional[Any] = None def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" diff --git a/vllm/prompt_adapter/request.py b/vllm/prompt_adapter/request.py index c0c98cf72bba..775dd11db071 100644 --- a/vllm/prompt_adapter/request.py +++ b/vllm/prompt_adapter/request.py @@ -1,13 +1,17 @@ -from dataclasses import dataclass +import msgspec from vllm.adapter_commons.request import AdapterRequest -@dataclass -class PromptAdapterRequest(AdapterRequest): +class PromptAdapterRequest( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + frozen=True): # type: ignore[call-arg] """ Request for a Prompt adapter. """ + __metaclass__ = AdapterRequest prompt_adapter_name: str prompt_adapter_id: int diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 04250c682cd2..7197b5139853 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -2,10 +2,10 @@ import copy from enum import IntEnum from functools import cached_property -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union +import msgspec import torch -from pydantic import Field from typing_extensions import Annotated from vllm.logger import init_logger @@ -33,7 +33,11 @@ first argument, and returns a modified tensor of logits to sample from.""" -class SamplingParams: +class SamplingParams( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): # type: ignore[call-arg] """Sampling parameters for text generation. Overall, we follow the sampling parameters from the OpenAI text completion @@ -112,87 +116,73 @@ class SamplingParams: (i.e., no truncation). """ - def __init__( - self, - n: int = 1, - best_of: Optional[int] = None, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - repetition_penalty: float = 1.0, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = -1, - min_p: float = 0.0, - seed: Optional[int] = None, - use_beam_search: bool = False, - length_penalty: float = 1.0, - early_stopping: Union[bool, str] = False, - stop: Optional[Union[str, List[str]]] = None, - stop_token_ids: Optional[List[int]] = None, - include_stop_str_in_output: bool = False, - ignore_eos: bool = False, - max_tokens: Optional[int] = 16, - min_tokens: int = 0, - logprobs: Optional[int] = None, - prompt_logprobs: Optional[int] = None, - detokenize: bool = True, - skip_special_tokens: bool = True, - spaces_between_special_tokens: bool = True, - logits_processors: Optional[List[LogitsProcessor]] = None, - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, - ) -> None: - self.n = n - self.best_of = best_of if best_of is not None else n - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty - self.repetition_penalty = repetition_penalty - if 0 < temperature < _MAX_TEMP: + n: int = 1 + best_of: Optional[int] = None + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + repetition_penalty: float = 1.0 + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 + min_p: float = 0.0 + seed: Optional[int] = None + use_beam_search: bool = False + length_penalty: float = 1.0 + early_stopping: Union[bool, str] = False + stop: Optional[Union[str, List[str]]] = None + stop_token_ids: Optional[List[int]] = None + ignore_eos: bool = False + max_tokens: Optional[int] = 16 + min_tokens: int = 0 + logprobs: Optional[int] = None + prompt_logprobs: Optional[int] = None + # NOTE: This parameter is only exposed at the engine level for now. + # It is not exposed in the OpenAI API server, as the OpenAI API does + # not support returning only a list of token IDs. + detokenize: bool = True + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + # Optional[List[LogitsProcessor]] type. We use Any here because + # Optional[List[LogitsProcessor]] type is not supported by msgspec. + logits_processors: Optional[Any] = None + include_stop_str_in_output: bool = False + truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None + + # The below fields are not supposed to be used as an input. + # They are set in post_init. + output_text_buffer_length: int = 0 + _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set) + + def __post_init__(self) -> None: + self.best_of = self.best_of or self.n + if 0 < self.temperature < _MAX_TEMP: logger.warning( "temperature %s is less than %s, which may cause numerical " "errors nan or inf in tensors. We have maxed it out to %s.", - temperature, _MAX_TEMP, _MAX_TEMP) - temperature = max(temperature, _MAX_TEMP) - self.temperature = temperature - self.top_p = top_p - self.top_k = top_k - self.min_p = min_p - if seed == -1: + self.temperature, _MAX_TEMP, _MAX_TEMP) + self.temperature = max(self.temperature, _MAX_TEMP) + if self.seed == -1: self.seed = None else: - self.seed = seed - self.use_beam_search = use_beam_search - self.length_penalty = length_penalty - self.early_stopping = early_stopping - if stop is None: + self.seed = self.seed + if self.stop is None: self.stop = [] - elif isinstance(stop, str): - self.stop = [stop] + elif isinstance(self.stop, str): + self.stop = [self.stop] else: - self.stop = list(stop) - if stop_token_ids is None: + self.stop = list(self.stop) + if self.stop_token_ids is None: self.stop_token_ids = [] else: - self.stop_token_ids = list(stop_token_ids) - self.ignore_eos = ignore_eos - self.max_tokens = max_tokens - self.min_tokens = min_tokens - self.logprobs = 1 if logprobs is True else logprobs - self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs - # NOTE: This parameter is only exposed at the engine level for now. - # It is not exposed in the OpenAI API server, as the OpenAI API does - # not support returning only a list of token IDs. - self.detokenize = detokenize - self.skip_special_tokens = skip_special_tokens - self.spaces_between_special_tokens = spaces_between_special_tokens - self.logits_processors = logits_processors - self.include_stop_str_in_output = include_stop_str_in_output - self.truncate_prompt_tokens = truncate_prompt_tokens + self.stop_token_ids = list(self.stop_token_ids) + self.logprobs = 1 if self.logprobs is True else self.logprobs + self.prompt_logprobs = (1 if self.prompt_logprobs is True else + self.prompt_logprobs) + # Number of characters to hold back for stop string evaluation # until sequence is finished. - if self.stop and not include_stop_str_in_output: + if self.stop and not self.include_stop_str_in_output: self.output_text_buffer_length = max(len(s) for s in self.stop) - 1 - else: - self.output_text_buffer_length = 0 self._verify_args() if self.use_beam_search: @@ -206,11 +196,12 @@ class SamplingParams: self.min_p = 0.0 self._verify_greedy_sampling() # eos_token_id is added to this by the engine - self.all_stop_token_ids = set(self.stop_token_ids) + self._all_stop_token_ids = set(self.stop_token_ids) def _verify_args(self) -> None: if self.n < 1: raise ValueError(f"n must be at least 1, got {self.n}.") + assert isinstance(self.best_of, int) if self.best_of < self.n: raise ValueError(f"best_of must be greater than or equal to n, " f"got n={self.n} and best_of={self.best_of}.") @@ -257,6 +248,7 @@ class SamplingParams: and self.truncate_prompt_tokens < 1): raise ValueError(f"truncate_prompt_tokens must be >= 1, " f"got {self.truncate_prompt_tokens}") + assert isinstance(self.stop, list) if any(not stop_str for stop_str in self.stop): raise ValueError("stop cannot contain an empty string.") if self.stop and not self.detokenize: @@ -290,6 +282,7 @@ class SamplingParams: "default value of 1.0 when not using beam search.") def _verify_greedy_sampling(self) -> None: + assert isinstance(self.best_of, int) if self.best_of > 1: raise ValueError("best_of must be 1 when using greedy sampling." f"Got {self.best_of}.") @@ -303,7 +296,7 @@ class SamplingParams: if model_eos_token_id is not None: # Add the eos token id into the sampling_params to support # min_tokens processing. - self.all_stop_token_ids.add(model_eos_token_id) + self._all_stop_token_ids.add(model_eos_token_id) # Update eos_token_id for generation if (eos_ids := generation_config.get("eos_token_id")) is not None: @@ -315,7 +308,7 @@ class SamplingParams: # purposes. eos_ids.discard(model_eos_token_id) if eos_ids: - self.all_stop_token_ids.update(eos_ids) + self._all_stop_token_ids.update(eos_ids) if not self.ignore_eos: eos_ids.update(self.stop_token_ids) self.stop_token_ids = list(eos_ids) @@ -330,6 +323,10 @@ class SamplingParams: return SamplingType.RANDOM_SEED return SamplingType.RANDOM + @property + def all_stop_token_ids(self) -> Set[int]: + return self._all_stop_token_ids + def clone(self) -> "SamplingParams": """Deep copy excluding LogitsProcessor objects. diff --git a/vllm/sequence.py b/vllm/sequence.py index b83e345235cd..b15955cde76c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -4,10 +4,11 @@ import enum from abc import ABC, abstractmethod from array import array from collections import defaultdict -from dataclasses import dataclass, field -from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, - Union, cast) +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, + Tuple, Union, cast) +import msgspec import numpy import torch @@ -16,13 +17,18 @@ from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if TYPE_CHECKING: from vllm.inputs import LLMInputs - from vllm.multimodal import MultiModalDataDict - from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + from vllm.multimodal.base import MultiModalDataDict + +VLLM_TOKEN_ID_ARRAY_TYPE = "l" +# We use dataclass for now because it is used for +# openai server output, and msgspec is not serializable. +# TODO(sang): Fix it. @dataclass class Logprob: """Infos for supporting OpenAI compatible logprobs and token ranks. @@ -112,7 +118,23 @@ class RequestMetrics: model_execute_time: Optional[float] = None -class SequenceData: +class SequenceDataDelta( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] + """Delta SequenceData to send to workers per step.""" + # A new token to be appended to existing SequenceData. + new_output_token_ids: List[int] + # Overwriting existing `cumulative_logprob` + new_cumulative_logprob: float + # Overwriting existing `num_computed_tokens`. + new_num_computed_tokens: int + # Overwriting existing `stage`. + new_stage: SequenceStage + + +class SequenceData(msgspec.Struct, + omit_defaults=True): # type: ignore[call-arg] """Data associated with a sequence. Args: @@ -125,40 +147,57 @@ class SequenceData: output_token_ids: The token IDs of the output. cumulative_logprob: The cumulative log probability of the output. """ + # NOTE: we cannot use Union[List, array] because msgspec cannot support + # union of 2 list types. + _prompt_token_ids: array + _output_token_ids: array = msgspec.field( + default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) - def __init__( - self, - prompt_token_ids: List[int], - output_token_ids: Optional[List[int]] = None, - ) -> None: - self._prompt_token_ids = array('l', prompt_token_ids) - self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids) - self._output_token_ids = array( - 'l', output_token_ids if output_token_ids is not None else []) + ### The below fields should not be passed as an argument ### + _cumulative_logprob: float = 0.0 + _prompt_token_ids_tuple: Tuple[int, + ...] = msgspec.field(default_factory=tuple) + # The number of tokens that are computed (that run against the model). + _num_computed_tokens: int = 0 + _stage: SequenceStage = SequenceStage.PREFILL + _cached_all_token_ids: List[int] = msgspec.field(default_factory=list) - self.cumulative_logprob = 0.0 - # The number of tokens that are computed (that run against the model). - self._num_computed_tokens = 0 - self._stage: SequenceStage = SequenceStage.PREFILL + # It is used to get delta input. It is reset when `get_delta_and_reset` + # is called. + _new_appended_tokens: List[int] = msgspec.field(default_factory=list) + def __post_init__(self) -> None: + assert self._prompt_token_ids.typecode == "l" + assert self._output_token_ids.typecode == "l" + self._prompt_token_ids_tuple: Tuple[int, ...] = tuple( + self._prompt_token_ids) self._update_cached_all_tokens() def _update_cached_all_tokens(self): + assert isinstance(self._prompt_token_ids, array) + assert isinstance(self._output_token_ids, array) self._cached_all_token_ids: List[int] = list(self._prompt_token_ids + self._output_token_ids) + @property + def cumulative_logprob(self) -> float: + return self._cumulative_logprob + @property def prompt_token_ids(self) -> Tuple[int, ...]: return self._prompt_token_ids_tuple @prompt_token_ids.setter def prompt_token_ids(self, new_prompt_token_ids) -> None: - self._prompt_token_ids = array('l', new_prompt_token_ids) - self._prompt_token_ids_tuple = tuple(new_prompt_token_ids) - self._update_cached_all_tokens() + raise NotImplementedError @property def prompt_token_ids_array(self) -> array: + """Return the prompt token ids in array type. + + Note that the array is in "I" type, and it is not compatible + with torch.long (2 bytes vs 4 bytes). So beware of the usage. + """ return self._prompt_token_ids @property @@ -166,18 +205,26 @@ class SequenceData: return tuple(self._output_token_ids) @output_token_ids.setter - def output_token_ids(self, new_output_token_ids) -> None: - self._output_token_ids = array('l', new_output_token_ids) + def output_token_ids(self, new_output_token_ids: List[int]) -> None: + self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + new_output_token_ids) self._update_cached_all_tokens() @property def output_token_ids_array(self) -> array: + """Return the prompt token ids in array type. + + Note that the array is in "I" type, and it is not compatible + with torch.long (2 bytes vs 4 bytes). So beware of the usage. + """ + assert isinstance(self._output_token_ids, array) return self._output_token_ids def append_token_id(self, token_id: int, logprob: float) -> None: self._output_token_ids.append(token_id) + self._new_appended_tokens.append(token_id) self._cached_all_token_ids.append(token_id) - self.cumulative_logprob += logprob + self._cumulative_logprob += logprob def get_len(self) -> int: return len(self._output_token_ids) + len(self._prompt_token_ids) @@ -222,6 +269,7 @@ class SequenceData: """ self._num_computed_tokens = 0 self._stage = SequenceStage.PREFILL + self._new_appended_tokens = [] def get_num_uncomputed_tokens(self) -> int: """Return the number of prefill tokens that are not computed.""" @@ -241,6 +289,21 @@ class SequenceData: def get_output_token_ids(self) -> Tuple[int, ...]: return self.output_token_ids + def get_delta_and_reset(self) -> SequenceDataDelta: + delta = SequenceDataDelta(self._new_appended_tokens, + self._cumulative_logprob, + self.get_num_computed_tokens(), self.stage) + # Reset delta state. + self._new_appended_tokens = [] + return delta + + def apply_delta(self, delta: SequenceDataDelta): + self._num_computed_tokens = delta.new_num_computed_tokens + self._cumulative_logprob = delta.new_cumulative_logprob + self._stage = delta.new_stage + self._output_token_ids.extend(delta.new_output_token_ids) + self._cached_all_token_ids.extend(delta.new_output_token_ids) + @property def stage(self) -> SequenceStage: return self._stage @@ -248,8 +311,9 @@ class SequenceData: def __repr__(self) -> str: return (f"SequenceData(" f"prompt_token_ids={self._prompt_token_ids}, " - f"output_token_ids={self._output_token_ids}, " - f"cumulative_logprob={self.cumulative_logprob})") + f"output_token_ids={self.output_token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"get_num_computed_tokens={self.get_num_computed_tokens()}") class Sequence: @@ -325,7 +389,8 @@ class Sequence: f"invalid input {inputs}; did you forget the " "encoder input prompt fields?") - self.data = SequenceData(self.prompt_token_ids) + self.data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids)) self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -490,8 +555,8 @@ class Sequence: f"num_blocks={self.n_blocks}, ") -@dataclass -class SequenceGroupState: +class SequenceGroupState(msgspec.Struct, + omit_defaults=True): # type: ignore[call-arg] """Mutable state tied to a specific sequence group""" # for multi-step decoding @@ -647,14 +712,19 @@ class SequenceGroup: if self.sampling_params and self.sampling_params.use_beam_search: # For beam search, maximally there will always be `best_of` beam # candidates running in the future. - return self.sampling_params.best_of + best_of = self.sampling_params.best_of + assert isinstance(best_of, int) + return best_of else: - if (self.sampling_params - and self.sampling_params.best_of > self.num_seqs()): - # At prompt stage, the sequence group is not yet filled up - # and only have one sequence running. However, in the - # generation stage, we will have `best_of` sequences running. - return self.sampling_params.best_of + if self.sampling_params: + best_of = self.sampling_params.best_of + assert isinstance(best_of, int) + if best_of > self.num_seqs(): + # At prompt stage, the sequence group is not yet filled up + # and only have one sequence running. However, in the + # generation stage, we will have `best_of` sequences + # running. + return best_of # At sampling stages, return the number of actual sequences # that are not finished yet. return self.num_unfinished_seqs() @@ -757,7 +827,32 @@ class SequenceGroup: f"num_seqs={len(self.seqs)})") -class SequenceGroupMetadata: +class SequenceGroupMetadataDelta( + msgspec.Struct, + tag=True, # type: ignore[call-arg] + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] + """Delta of SequenceGroupMetadata. + + After sending the first SequenceGroupMetadata, vLLM scheduler + only sends delta to reduce the data payload size. + """ + seq_data_delta: Dict[int, SequenceDataDelta] + request_id: str + block_tables: Dict[int, List[int]] + is_prompt: bool + do_sample: bool = True + token_chunk_size: Optional[int] = None + computed_block_nums: Optional[List[int]] = None + state: Optional[SequenceGroupState] = msgspec.field( + default_factory=lambda: SequenceGroupState()) + + +class SequenceGroupMetadata( + msgspec.Struct, + tag=True, # type: ignore[call-arg] + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] """Metadata for a sequence group. Used to create `AttentionMetadata`. Args: @@ -789,52 +884,39 @@ class SequenceGroupMetadata: prompt_adapter_request: Prompt Adapter request. """ - def __init__( - self, - request_id: str, - is_prompt: bool, - seq_data: Dict[int, SequenceData], - sampling_params: SamplingParams, - block_tables: Dict[int, List[int]], - do_sample: bool = True, - pooling_params: Optional[PoolingParams] = None, - token_chunk_size: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - computed_block_nums: Optional[List[int]] = None, - state: Optional[SequenceGroupState] = None, - multi_modal_data: Optional["MultiModalDataDict"] = None, - encoder_seq_data: Optional[SequenceData] = None, - cross_block_table: Optional[List[int]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> None: - self.request_id = request_id - self.is_prompt = is_prompt - self.seq_data = seq_data - self.sampling_params = sampling_params - self.block_tables = block_tables - self.pooling_params = pooling_params - self.lora_request = lora_request - self.prompt_adapter_request = prompt_adapter_request - self.computed_block_nums = computed_block_nums - self.multi_modal_data = multi_modal_data - self.state = SequenceGroupState() if state is None else state - self.encoder_seq_data = encoder_seq_data - self.cross_block_table = cross_block_table - self._token_chunk_size = token_chunk_size - self.do_sample = do_sample + request_id: str + is_prompt: bool + seq_data: Dict[int, SequenceData] + sampling_params: SamplingParams + block_tables: Dict[int, List[int]] + do_sample: bool = True + pooling_params: Optional[PoolingParams] = None + lora_request: Optional[LoRARequest] = None + computed_block_nums: Optional[List[int]] = None + state: Optional[SequenceGroupState] = msgspec.field( + default_factory=lambda: SequenceGroupState()) + # "MultiModalDataDict" types. We have to use Any due to msgspec + # doesn't allow to have union of 2 different dicts. + multi_modal_data: Optional[Any] = None + encoder_seq_data: Optional[SequenceData] = None + cross_block_table: Optional[List[int]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + token_chunk_size: Optional[int] = None - # The number of speculative tokens adopted in this request. - # None means specuative decoding is not used. - # Zero means speculative decoding is disabled for some reasons. - # TODO: We should maintain this states out of the sequence group. - self.num_speculative_tokens = None + ### Stateful fields that are lazily defined. ### + # The number of speculative tokens adopted in this request. + # None means specuative decoding is not used. + # Zero means speculative decoding is disabled for some reasons. + # TODO: We should maintain this states out of the sequence group. + num_speculative_tokens: Optional[int] = None - if seq_data is not None and self._token_chunk_size is None: - if is_prompt: - self._token_chunk_size = next(iter( - seq_data.values())).get_len() + def __post_init__(self): + if self.seq_data is not None and self.token_chunk_size is None: + if self.is_prompt: + self.token_chunk_size = next(iter( + self.seq_data.values())).get_len() else: - self._token_chunk_size = 1 + self.token_chunk_size = 1 @property def lora_int_id(self) -> int: @@ -850,18 +932,26 @@ class SequenceGroupMetadata: return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ if self.prompt_adapter_request else 0 - @property - def token_chunk_size(self) -> int: - """Return the number of tokens to be processed (chunk size).""" - assert self._token_chunk_size is not None - return self._token_chunk_size + def apply_delta(self, + sequence_group_metadata_delta: SequenceGroupMetadataDelta): + for id, delta in sequence_group_metadata_delta.seq_data_delta.items(): + self.seq_data[id].apply_delta(delta) + assert self.request_id == sequence_group_metadata_delta.request_id + self.block_tables = sequence_group_metadata_delta.block_tables + self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size + self.do_sample = sequence_group_metadata_delta.do_sample + self.is_prompt = sequence_group_metadata_delta.is_prompt def finish_step(self) -> None: + assert self.state is not None assert self.state.current_step < self.state.num_steps self.state.current_step += 1 -class SequenceOutput: +class SequenceOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] """The model output associated with a sequence. Args: @@ -871,16 +961,9 @@ class SequenceOutput: logprobs: The logprobs of the output token. (Token id -> logP(x_i+1 | x_0, ..., x_i)) """ - - def __init__( - self, - parent_seq_id: int, - output_token: int, - logprobs: Dict[int, Logprob], - ) -> None: - self.parent_seq_id = parent_seq_id - self.output_token = output_token - self.logprobs = logprobs + parent_seq_id: int + output_token: int + logprobs: Dict[int, Logprob] def __repr__(self) -> str: return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " @@ -908,17 +991,15 @@ class SequenceGroupOutput(ABC): pass -class CompletionSequenceGroupOutput(SequenceGroupOutput): +class CompletionSequenceGroupOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + __metaclass__ = SequenceGroupOutput """The model output associated with a completion sequence group.""" - - def __init__( - self, - samples: List[SequenceOutput], - prompt_logprobs: Optional[PromptLogprobs], - ) -> None: - self.samples = samples - # Prompt logprob for each prompt query token. - self.prompt_logprobs = prompt_logprobs + samples: List[SequenceOutput] + # Prompt logprob for each prompt query token. + prompt_logprobs: Optional[PromptLogprobs] def __repr__(self) -> str: return (f"CompletionSequenceGroupOutput(samples={self.samples}, " @@ -931,14 +1012,14 @@ class CompletionSequenceGroupOutput(SequenceGroupOutput): and self.prompt_logprobs == other.prompt_logprobs) -class EmbeddingSequenceGroupOutput(SequenceGroupOutput): +class EmbeddingSequenceGroupOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True, # type: ignore[call-arg] +): """The model output associated with an embedding sequence group.""" - - def __init__( - self, - embeddings: List[float], - ) -> None: - self.embeddings = embeddings + __metaclass__ = SequenceGroupOutput + embeddings: List[int] def __repr__(self) -> str: return (f"EmbeddingSequenceGroupOutput(" @@ -950,8 +1031,10 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput): return self.embeddings == other.embeddings -@dataclass -class IntermediateTensors: +class IntermediateTensors( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] """For all pipeline stages except the last, we need to return the hidden states and residuals to be sent to the next stage. This data structure contains the hidden states and residuals for a request. @@ -978,8 +1061,10 @@ class IntermediateTensors: return f"IntermediateTensors(tensors={self.tensors})" -@dataclass -class SamplerOutput: +class SamplerOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] """For each sequence group, we generate a list of SequenceOutput object, each of which contains one possible candidate for the next token. @@ -1000,7 +1085,7 @@ class SamplerOutput: sampled_token_ids_numpy: Optional[numpy.ndarray] = None # Spec decode metrics populated by workers. - spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None + spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None # Optional last hidden states from the model. hidden_states: Optional[torch.Tensor] = None @@ -1039,12 +1124,14 @@ class SamplerOutput: f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") -@dataclass -class PoolerOutput: +class PoolerOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] """The output from a pooling operation in the embedding model.""" outputs: List[EmbeddingSequenceGroupOutput] - spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None + spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None def __getitem__(self, idx: int): return self.outputs[idx] @@ -1083,7 +1170,8 @@ def get_all_seq_ids_and_request_ids( return seq_ids, request_id_seq_ids_mapping -class HiddenStates: +class HiddenStates(msgspec.Struct, array_like=True, + omit_defaults=True): # type: ignore[call-arg] """Hidden states corresponding to in-progress sequences. Used in speculative decoding to pass hidden states from the target model to the proposer model in the subsequent step. @@ -1091,42 +1179,53 @@ class HiddenStates: seq_ids are the sequence ids of each entry of the batch dimension of the hidden_states tensor""" - def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata], - hidden_states: torch.Tensor): - assert len(seq_group_metadata_list) == len(hidden_states) - self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list) - self.hidden_states: torch.Tensor = hidden_states + seq_group_metadata_list: List[SequenceGroupMetadata] + hidden_states: torch.Tensor + _seq_ids: List[int] = msgspec.field(default_factory=list) + + def __post_init__(self): + self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) + assert len(self.seq_group_metadata_list) == len(self.hidden_states) + + @property + def seq_ids(self) -> List[int]: + return self._seq_ids def update(self, seq_group_metadata_list: List[SequenceGroupMetadata], hidden_states: torch.Tensor) -> None: """Update hidden states from target model invocation.""" assert len(seq_group_metadata_list) == len(hidden_states) - self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) + self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) self.hidden_states = torch.cat([self.hidden_states, hidden_states]) def prune(self, seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: """Prune to provided list of sequence ids.""" seq_ids = get_all_seq_ids(seq_group_metadata_list) - if seq_ids != self.seq_ids: + if seq_ids != self._seq_ids: # Batch contents changed - prune removed sequences. - index = [self.seq_ids.index(seq_id) for seq_id in seq_ids] + index = [self._seq_ids.index(seq_id) for seq_id in seq_ids] self.hidden_states = self.hidden_states[index] - self.seq_ids = seq_ids + self._seq_ids = seq_ids -@dataclass -class ExecuteModelRequest: +class ExecuteModelRequest( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] """The model execution request, containing CPU metadata only. The LLM engine should create an instance of this class for each request batch.""" # The sequence group metadata list. - seq_group_metadata_list: List[SequenceGroupMetadata] + seq_group_metadata_list: List[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]] # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list) + blocks_to_swap_in: List[Tuple[int, + int]] = msgspec.field(default_factory=list) # Blocks to swap out. List of GPU -> CPU block number. - blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list) + blocks_to_swap_out: List[Tuple[int, + int]] = msgspec.field(default_factory=list) # Blocks to copy. Source to dest block. - blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) + blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list) # Virtual engine ID for pipeline parallel. virtual_engine: int = 0 # The number of slots for lookahead decoding. @@ -1138,7 +1237,7 @@ class ExecuteModelRequest: # The number of forward steps to run. num_steps: int = 1 # Finished request ids since last step. - finished_requests_ids: List[str] = field(default_factory=list) + finished_requests_ids: List[str] = msgspec.field(default_factory=list) # The last sampled token ids for multi step decoding. last_sampled_token_ids: Optional[torch.Tensor] = None @@ -1148,6 +1247,7 @@ class ExecuteModelRequest: # steps assert len(self.seq_group_metadata_list) > 0 first_seq_group = self.seq_group_metadata_list[0] + assert first_seq_group.state is not None return first_seq_group.state.current_step == 0 @property @@ -1156,6 +1256,7 @@ class ExecuteModelRequest: # steps assert len(self.seq_group_metadata_list) > 0 first_seq_group = self.seq_group_metadata_list[0] + assert first_seq_group.state is not None num_steps = first_seq_group.state.num_steps current_step = first_seq_group.state.current_step return num_steps - current_step == 1 @@ -1165,10 +1266,13 @@ class ExecuteModelRequest: # TODO(will) make this be able to handle batches with variable number of # steps assert len(self.seq_group_metadata_list) > 0 - return self.seq_group_metadata_list[0].state.current_step + state = self.seq_group_metadata_list[0].state + assert state is not None + return state.current_step def clone( - self, seq_group_metadata_list: List[SequenceGroupMetadata] + self, seq_group_metadata_list: List[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]] ) -> "ExecuteModelRequest": """Clone the request with a new sequence group metadata list.""" return ExecuteModelRequest( diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 45eaeb51c5c0..aec4847b96c3 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -1,11 +1,13 @@ +from array import array from itertools import chain, count from typing import Iterator, List, Tuple import torch from vllm import SamplingParams -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData, - SequenceGroupMetadata, get_all_seq_ids) +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest, + SamplerOutput, SequenceData, SequenceGroupMetadata, + get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch, @@ -293,14 +295,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): input sequence. """ seq_data = seq_group_metadata.seq_data[seq_id] - prompt_token_ids = seq_data.get_prompt_token_ids() + prompt_token_ids = seq_data.prompt_token_ids_array new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] new_seq_data_dict = { target_seq_id: SequenceData( - prompt_token_ids=prompt_token_ids, - output_token_ids=new_output_token_ids, + prompt_token_ids, + _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, + new_output_token_ids), ), } # This is a hack. Technically, spec decoding should compute diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 9036d117041f..ad4e2dc879d7 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -1,7 +1,7 @@ import time -from dataclasses import dataclass from typing import Callable, Optional +import msgspec import torch from vllm.model_executor.layers.spec_decode_base_sampler import ( @@ -9,8 +9,10 @@ from vllm.model_executor.layers.spec_decode_base_sampler import ( from vllm.utils import is_pin_memory_available -@dataclass -class SpecDecodeWorkerMetrics: +class SpecDecodeWorkerMetrics( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] """Dataclass holding metrics emitted from the spec decode worker. """ diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index ffe6216d3ed6..97be68934be4 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import List, Optional, Set, Tuple, Type +from typing import Dict, List, Optional, Set, Tuple, Type, Union import torch import torch.distributed @@ -18,7 +18,9 @@ from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest +from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, + SamplerOutput, SequenceGroupMetadata, + SequenceGroupMetadataDelta) from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner @@ -109,6 +111,7 @@ class Worker(LocalOrDistributedWorkerBase): self.cache_engine: List[CacheEngine] # Initialize gpu_cache as embedding models don't initialize kv_caches self.gpu_cache: Optional[List[List[torch.Tensor]]] = None + self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} def _is_encoder_decoder_model(self): return self.model_config.is_encoder_decoder_model @@ -303,6 +306,63 @@ class Worker(LocalOrDistributedWorkerBase): and worker_input.blocks_to_copy.numel() > 0): self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) + def _get_cached_seq_group_metadata( + self, + seq_group_metadata_list: List[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]], + finished_request_ids: List[str]) -> List[SequenceGroupMetadata]: + """Return a list of cached Sequence Group Metadata after updating its + state. + + It is used because scheduler only sends delta to workers to reduce + the data payload size. The function also cleans up cache based on + a given `finished_request_ids`. + """ + new_seq_group_metadata_list = [] + for metadata_or_delta in seq_group_metadata_list: + request_id = metadata_or_delta.request_id + if request_id not in self._seq_group_metadata_cache: + # The first prefill. + assert isinstance(metadata_or_delta, SequenceGroupMetadata) + self._seq_group_metadata_cache[request_id] = metadata_or_delta + else: + # The first prefill is already cached. + if isinstance(metadata_or_delta, SequenceGroupMetadataDelta): + self._seq_group_metadata_cache[request_id].apply_delta( + metadata_or_delta) + else: + # If metadata snapshot is sent again, it is + # preempted. Reset the cache because we need to start + # from scratch. + assert isinstance(metadata_or_delta, SequenceGroupMetadata) + self._seq_group_metadata_cache[ + request_id] = metadata_or_delta + + new_seq_group_metadata_list.append( + self._seq_group_metadata_cache[request_id]) + + # Clean up finished ids + for finished_id in finished_request_ids: + del self._seq_group_metadata_cache[finished_id] + + return new_seq_group_metadata_list + + def _execute_model_spmd( + self, + execute_model_req: ExecuteModelRequest, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Optional[List[SamplerOutput]]: + if execute_model_req is not None: + new_seq_group_metadata_list = self._get_cached_seq_group_metadata( + execute_model_req.seq_group_metadata_list, + execute_model_req.finished_requests_ids) + + execute_model_req.seq_group_metadata_list = ( + new_seq_group_metadata_list) + output = super()._execute_model_spmd(execute_model_req, + intermediate_tensors) + return output + def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request)