From 2ecf7b175703de020943b33532baaf6a31f69d3a Mon Sep 17 00:00:00 2001
From: William Lin
Date: Wed, 14 Aug 2024 12:32:45 -0700
Subject: [PATCH] [core] [3/N] multi-step args and sequence.py (#7452)

---
 vllm/config.py           | 14 +++++++++-
 vllm/core/scheduler.py   |  5 ++++
 vllm/engine/arg_utils.py | 28 ++++++++++++++++---
 vllm/sequence.py         | 58 +++++++++++++++++++++++++++++++++++++++-
 4 files changed, 100 insertions(+), 5 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 15d17d5e42a54..b564a0c68cef8 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -847,7 +847,8 @@ class SchedulerConfig:
                  delay_factor: float = 0.0,
                  enable_chunked_prefill: bool = False,
                  embedding_mode: Optional[bool] = False,
-                 preemption_mode: Optional[str] = None) -> None:
+                 preemption_mode: Optional[str] = None,
+                 num_scheduler_steps: int = 1) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
         else:
@@ -876,6 +877,7 @@ class SchedulerConfig:
         self.chunked_prefill_enabled = enable_chunked_prefill
         self.embedding_mode = embedding_mode
         self.preemption_mode = preemption_mode
+        self.num_scheduler_steps = num_scheduler_steps
         self._verify_args()
 
     def _verify_args(self) -> None:
@@ -901,6 +903,16 @@ class SchedulerConfig:
                 f"({self.num_lookahead_slots}) must be greater than or "
                 "equal to 0.")
 
+        if self.num_scheduler_steps < 1:
+            raise ValueError(
+                "num_scheduler_steps "
+                f"({self.num_scheduler_steps}) must be greater than or "
+                "equal to 1.")
+
+    @property
+    def is_multi_step(self) -> bool:
+        return self.num_scheduler_steps > 1
+
 
 class DeviceConfig:
     device: Optional[torch.device]
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index b16850c7eb9f8..6ed75a6e2ea6b 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -805,6 +805,9 @@ class Scheduler:
                 curr_loras.add(lora_int_id)
             waiting_queue.popleft()
             self._allocate_and_set_running(seq_group)
+            seq_group.init_multi_step(
+                num_scheduler_steps=self._get_num_lookahead_slots(
+                    is_prefill=True) + 1)
             seq_groups.append(
                 ScheduledSequenceGroup(seq_group=seq_group,
                                        token_chunk_size=num_new_tokens))
@@ -1108,6 +1111,7 @@ class Scheduler:
                 computed_block_nums=common_computed_block_nums,
                 encoder_seq_data=encoder_seq_data,
                 cross_block_table=cross_block_table,
+                state=seq_group.state,
                 # `multi_modal_data` will only be present for the 1st comm
                 # between engine and worker.
                 # the subsequent comms can still use delta, but
@@ -1184,6 +1188,7 @@ class Scheduler:
                 slots.
""" num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) + seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): cows = self.block_manager.append_slots(seq, num_lookahead_slots) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 76bd3b630c54b..d99387542da18 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -115,6 +115,7 @@ class EngineArgs: lora_dtype: str = 'auto' max_cpu_loras: Optional[int] = None device: str = 'auto' + num_scheduler_steps: int = 1 ray_workers_use_nsight: bool = False num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 @@ -543,6 +544,11 @@ class EngineArgs: "tpu", "xpu" ], help='Device type for vLLM execution.') + parser.add_argument('--num-scheduler-steps', + type=int, + default=1, + help=('Maximum number of forward steps per ' + 'scheduler call.')) parser.add_argument( '--scheduler-delay-factor', @@ -858,18 +864,34 @@ class EngineArgs: disable_logprobs=self.disable_logprobs_during_spec_decoding, ) + if self.num_scheduler_steps > 1: + raise NotImplementedError("Multi-step is not yet supported.") + if speculative_config is not None: + raise ValueError("Speculative decoding is not supported with " + "multi-step (--num-scheduler-steps > 1)") + if self.enable_chunked_prefill: + raise ValueError("Chunked prefill is not supported with " + "multi-step (--num-scheduler-steps > 1)") + + # make sure num_lookahead_slots is set the higher value depending on + # if we are using speculative decoding or multi-step + num_lookahead_slots = max(self.num_lookahead_slots, + self.num_scheduler_steps - 1) + num_lookahead_slots = num_lookahead_slots \ + if speculative_config is None \ + else speculative_config.num_lookahead_slots + scheduler_config = SchedulerConfig( max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, max_model_len=model_config.max_model_len, use_v2_block_manager=self.use_v2_block_manager, - num_lookahead_slots=(self.num_lookahead_slots - if speculative_config is None else - speculative_config.num_lookahead_slots), + num_lookahead_slots=num_lookahead_slots, delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, embedding_mode=model_config.embedding_mode, preemption_mode=self.preemption_mode, + num_scheduler_steps=self.num_scheduler_steps, ) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, diff --git a/vllm/sequence.py b/vllm/sequence.py index 7349bc6f13bd6..b83e345235cdd 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, field from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, Union, cast) +import numpy import torch from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs @@ -489,6 +490,19 @@ class Sequence: f"num_blocks={self.n_blocks}, ") +@dataclass +class SequenceGroupState: + """Mutable state tied to a specific sequence group""" + + # for multi-step decoding + num_steps: int = 1 + current_step: int = 0 + + @property + def remaining_steps(self) -> int: + return self.num_steps - self.current_step + + class SequenceGroup: """A group of sequences that are generated from the same prompt. 
@@ -534,6 +548,7 @@ class SequenceGroup:
                                       time_in_queue=None)
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
+        self.state = SequenceGroupState()
         self.embeddings = embeddings
         self.pooling_params = pooling_params
         self.prompt_adapter_request = prompt_adapter_request
@@ -588,6 +603,10 @@ class SequenceGroup:
         return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
             if self.prompt_adapter_request else 0
 
+    def init_multi_step(self, num_scheduler_steps: int) -> None:
+        self.state.num_steps = num_scheduler_steps
+        self.state.current_step = 0
+
     def get_last_latency(self, now: float) -> Optional[float]:
         """Sets the last token time for Request level timings."""
         # If still in prefill phase, raise Error.
@@ -756,6 +775,7 @@ class SequenceGroupMetadata:
         lora_request: LoRA request.
         computed_block_nums: The block numbers that are already computed,
             used in prefix caching.
+        state: Internal state tied to this sequence group.
         multi_modal_data: Multi modal data.
         encoder_seq_data: Optional sequence data for encoder prompt
             (SequenceGroup.encoder_seq). Should be None
@@ -781,6 +801,7 @@ class SequenceGroupMetadata:
         token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
+        state: Optional[SequenceGroupState] = None,
         multi_modal_data: Optional["MultiModalDataDict"] = None,
         encoder_seq_data: Optional[SequenceData] = None,
         cross_block_table: Optional[List[int]] = None,
@@ -796,6 +817,7 @@ class SequenceGroupMetadata:
         self.prompt_adapter_request = prompt_adapter_request
         self.computed_block_nums = computed_block_nums
         self.multi_modal_data = multi_modal_data
+        self.state = SequenceGroupState() if state is None else state
         self.encoder_seq_data = encoder_seq_data
         self.cross_block_table = cross_block_table
         self._token_chunk_size = token_chunk_size
@@ -834,6 +856,10 @@ class SequenceGroupMetadata:
         assert self._token_chunk_size is not None
         return self._token_chunk_size
 
+    def finish_step(self) -> None:
+        assert self.state.current_step < self.state.num_steps
+        self.state.current_step += 1
+
 
 class SequenceOutput:
     """The model output associated with a sequence.
@@ -971,6 +997,7 @@ class SamplerOutput:
 
     # On-device tensor containing the sampled token ids.
     sampled_token_ids: Optional[torch.Tensor] = None
+    sampled_token_ids_numpy: Optional[numpy.ndarray] = None
 
     # Spec decode metrics populated by workers.
     spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
@@ -1112,6 +1139,33 @@ class ExecuteModelRequest:
     num_steps: int = 1
     # Finished request ids since last step.
     finished_requests_ids: List[str] = field(default_factory=list)
+    # The last sampled token ids for multi step decoding.
+    last_sampled_token_ids: Optional[torch.Tensor] = None
+
+    @property
+    def is_first_multi_step(self) -> bool:
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        return first_seq_group.state.current_step == 0
+
+    @property
+    def is_last_step(self) -> bool:
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        num_steps = first_seq_group.state.num_steps
+        current_step = first_seq_group.state.current_step
+        return num_steps - current_step == 1
+
+    @property
+    def current_step(self) -> int:
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        return self.seq_group_metadata_list[0].state.current_step
 
     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -1127,4 +1181,6 @@
             running_queue_size=self.running_queue_size,
             previous_hidden_states=self.previous_hidden_states,
             num_steps=self.num_steps,
-            finished_requests_ids=self.finished_requests_ids)
+            finished_requests_ids=self.finished_requests_ids,
+            last_sampled_token_ids=self.last_sampled_token_ids.clone()
+            if self.last_sampled_token_ids is not None else None)
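
Usage sketch (illustrative, not part of the diff): the snippet below shows how the multi-step bookkeeping added above is expected to advance. SequenceGroupState and remaining_steps are copied from the vllm/sequence.py hunk; drive_steps() is a hypothetical stand-in for the engine/worker loop that later PRs in this series wire up, not vLLM code.

# Self-contained sketch of the multi-step state machine, assuming the
# semantics implied by init_multi_step(), finish_step() and the
# ExecuteModelRequest properties in the patch.
from dataclasses import dataclass


@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group"""

    # for multi-step decoding
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        return self.num_steps - self.current_step


def drive_steps(state: SequenceGroupState, num_scheduler_steps: int) -> None:
    # init_multi_step: the scheduler sets num_steps to num_lookahead_slots + 1
    # and resets the counter before the sequence group starts running.
    state.num_steps = num_scheduler_steps
    state.current_step = 0

    while state.remaining_steps > 0:
        # ExecuteModelRequest.is_first_multi_step corresponds to
        # current_step == 0; is_last_step to remaining_steps == 1.
        # ... one forward pass and sampling step would run here ...
        # finish_step: advance the counter once the forward pass completes.
        assert state.current_step < state.num_steps
        state.current_step += 1


state = SequenceGroupState()
drive_steps(state, num_scheduler_steps=4)
assert state.current_step == 4 and state.remaining_steps == 0

Keeping the counters on the per-group state (and forwarding state=seq_group.state into SequenceGroupMetadata in the scheduler hunk) is what lets workers call finish_step() and lets ExecuteModelRequest read the current step without extra plumbing.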