From 79acf80471860133ea48c975186059ae5f12f17b Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 2 May 2025 17:42:45 +0000 Subject: [PATCH] Fast decode prepare path for prepare_inputs logic Signed-off-by: Alexander Matveev --- examples/offline_inference/basic/basic.py | 4 +- vllm/envs.py | 3 + vllm/v1/worker/block_table.py | 28 +++- vllm/v1/worker/gpu_model_runner.py | 194 ++++++++++++++++++++++ 4 files changed, 225 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index ae5ae7cb48346..4c17afd0c2ec2 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -10,12 +10,12 @@ prompts = [ "The future of AI is", ] # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=10) def main(): # Create an LLM. - llm = LLM(model="facebook/opt-125m") + llm = LLM(model="facebook/opt-125m", disable_cascade_attn=True) # Generate texts from the prompts. # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. diff --git a/vllm/envs.py b/vllm/envs.py index ea40bfff11b5b..9d051bc7fb755 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -85,6 +85,7 @@ if TYPE_CHECKING: VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True + VLLM_ENABLE_V1_ADVANCE_STEP: bool = False VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False Q_SCALE_CONSTANT: int = 200 @@ -600,6 +601,8 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")), "VLLM_DISABLE_COMPILE_CACHE": lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))), + "VLLM_ENABLE_V1_ADVANCE_STEP": + lambda: bool(int(os.getenv("VLLM_ENABLE_V1_ADVANCE_STEP", "0"))), # If set, vllm will run in development mode, which will enable # some additional endpoints for developing and debugging, diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 7d4082b73992b..82f197dbc52ae 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -3,6 +3,7 @@ import numpy as np import torch +import vllm.envs as envs from vllm.logger import init_logger logger = init_logger(__name__) @@ -36,6 +37,9 @@ class BlockTable: self.block_table_np = self.block_table_cpu.numpy() self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) + self.prev_num_reqs = 0 + self.is_updated = True + def append_row( self, block_ids: list[int], @@ -48,16 +52,22 @@ class BlockTable: self.num_blocks_per_row[row_idx] += num_blocks self.block_table_np[row_idx, start:start + num_blocks] = block_ids + self.is_updated = True + def add_row(self, block_ids: list[int], row_idx: int) -> None: self.num_blocks_per_row[row_idx] = 0 self.append_row(block_ids, row_idx) + self.is_updated = True + def move_row(self, src: int, tgt: int) -> None: num_blocks = self.num_blocks_per_row[src] self.block_table_np[tgt, :num_blocks] = self.block_table_np[ src, :num_blocks] self.num_blocks_per_row[tgt] = num_blocks + self.is_updated = True + def swap_row(self, src: int, tgt: int) -> None: num_blocks_src = self.num_blocks_per_row[src] num_blocks_tgt = self.num_blocks_per_row[tgt] @@ -66,14 +76,28 @@ class BlockTable: self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]] + self.is_updated = True + def commit(self, num_reqs: 
int) -> None:
-        self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
-                                          non_blocking=True)
+        if envs.VLLM_ENABLE_V1_ADVANCE_STEP:
+            # Incremental copy: skip the H2D copy if the table is unchanged
+            if self.prev_num_reqs != num_reqs or self.is_updated:
+                self.block_table[:num_reqs].copy_(
+                    self.block_table_cpu[:num_reqs], non_blocking=True)
+
+            self.prev_num_reqs = num_reqs
+            self.is_updated = False
+        else:
+            # Always copy
+            self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
+                                              non_blocking=True)
 
     def clear(self) -> None:
         self.block_table.fill_(0)
         self.block_table_cpu.fill_(0)
 
+        self.is_updated = True
+
     def get_device_tensor(self) -> torch.Tensor:
         """Returns the device tensor of the block table."""
         return self.block_table
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e3d8b94fe9d7e..94fbd65aaed11 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -10,6 +10,7 @@ import torch
 import torch.distributed
 import torch.nn as nn
 
+import vllm.envs as envs
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import (CompilationLevel, VllmConfig,
@@ -142,6 +143,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             weakref.proxy(self))
         self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
 
+        if envs.VLLM_ENABLE_V1_ADVANCE_STEP:
+            logger.info("Advance_step is enabled")
+            if self.cascade_attn_enabled:
+                logger.warning(
+                    "Disabling cascade attn (since advance_step is on)")
+                self.cascade_attn_enabled = False
+        else:
+            logger.info("Advance_step is disabled")
+
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
@@ -271,16 +281,51 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                             device="cpu",
                                             pin_memory=self.pin_memory)
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
+        self.slot_mapping_gpu = torch.zeros(self.max_num_tokens,
+                                            dtype=torch.int64,
+                                            device=self.device)
+
         self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                                dtype=torch.int32,
                                                device="cpu",
                                                pin_memory=self.pin_memory)
         self.query_start_loc_np = self.query_start_loc_cpu.numpy()
+        self.query_start_loc_gpu = torch.zeros(self.max_num_reqs + 1,
+                                               dtype=torch.int32,
+                                               device=self.device)
+
         self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()
+        self.seq_lens_gpu = torch.zeros(self.max_num_reqs,
+                                        dtype=torch.int32,
+                                        device=self.device)
+
+        # Tensors cached on the GPU for the advance-step (decode) fast path
+        self.prev_num_reqs = 0
+        self.req_indices_gpu = torch.arange(self.max_num_reqs,
+                                            dtype=torch.int32,
+                                            device=self.device)
+
+        self.req_indices_block_table_offsets_gpu = (
+            self.req_indices_gpu * self.max_num_blocks_per_req)
+
+        self.num_scheduled_tokens_gpu = torch.ones(self.max_num_reqs,
+                                                   dtype=torch.int32,
+                                                   device=self.device)
+
+        self.cu_num_tokens_gpu = torch.cumsum(self.num_scheduled_tokens_gpu, 0)
+
+        self.query_start_loc_gpu[0] = 0
+        self.query_start_loc_gpu[1:self.max_num_reqs +
+                                 1] = self.cu_num_tokens_gpu
+
+        self.logits_indices_gpu = self.query_start_loc_gpu[1:] - 1
+
+        self.prev_sampled_token_ids: Optional[torch.Tensor] = None
+        self.prev_attn_metadata = None
+        self.is_first_advance_decode = True
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -485,6 +530,119 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if batch_changed or batch_reordered:
self.input_batch.refresh_sampling_metadata()
 
+    def _advance_decode_step(
+        self,
+        scheduler_output,
+        num_scheduled_tokens,
+    ):
+        # Fast decode path: all tensors already live on the GPU and are
+        # updated in place, avoiding the CPU-side prepare and H2D copies.
+        num_reqs = self.input_batch.num_reqs
+        assert num_reqs > 0
+
+        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        assert total_num_scheduled_tokens == num_reqs
+
+        # TODO: Add if needed
+        # Get request indices.
+        # E.g., num_reqs == 3 -> [0, 1, 2]
+        # req_indices_gpu = self.req_indices_gpu[:num_reqs]
+        # Get cu_sums
+        # cu_num_tokens = self.cu_num_tokens_gpu[:num_reqs]
+
+        # Increment positions (each request decodes exactly one token)
+        positions_gpu = self.positions[:total_num_scheduled_tokens]
+        positions_gpu += 1
+
+        # TODO: Verify M-RoPE is ok here
+        # Calculate M-RoPE positions.
+        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+        if self.uses_mrope:
+            self._calc_mrope_positions(scheduler_output)
+
+        # Set the next input tokens (the previous iteration's sampled
+        # tokens are cached in the prev_sampled_token_ids tensor)
+        assert self.prev_sampled_token_ids is not None
+        self.input_ids[:total_num_scheduled_tokens] = \
+            self.prev_sampled_token_ids[:, 0]
+
+        # Calculate the slot mapping
+        block_table_indices_gpu = (
+            self.req_indices_block_table_offsets_gpu[:num_reqs] +
+            positions_gpu // self.block_size)
+        block_table_gpu = self.input_batch.block_table.get_device_tensor()
+        # Note: The block table tensor is async copied from CPU to GPU
+        # (inside the .commit() call) if it was previously modified
+        block_numbers_gpu = block_table_gpu.flatten()[block_table_indices_gpu]
+
+        block_offsets_gpu = positions_gpu % self.block_size
+
+        slot_mapping_gpu = self.slot_mapping_gpu[:total_num_scheduled_tokens]
+        slot_mapping_gpu[:] = (block_numbers_gpu * self.block_size +
+                               block_offsets_gpu)
+
+        # Prepare the attention metadata.
+
+        # query_start_loc is always the same for all decode iterations
+        query_start_loc_gpu = self.query_start_loc_gpu[:num_reqs + 1]
+
+        if self.uses_mrope:
+            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
+                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
+                non_blocking=True)
+
+        # TODO: Add cascade attn support
+        # Verify cascade attention is disabled
+        assert not self.cascade_attn_enabled
+
+        # TODO: Add support for other attn backends
+        assert self.prev_attn_metadata is not None
+        assert isinstance(self.prev_attn_metadata, FlashAttentionMetadata)
+
+        # Reuse the previous step's metadata: seq_lens and max_seq_len are
+        # bumped in place; the other fields point at persistent GPU tensors.
+        attn_metadata = self.prev_attn_metadata
+        attn_metadata.max_seq_len += 1
+        attn_metadata.query_start_loc = query_start_loc_gpu
+        attn_metadata.seq_lens += 1
+        attn_metadata.slot_mapping = slot_mapping_gpu
+
+        use_spec_decode = len(
+            scheduler_output.scheduled_spec_decode_tokens) > 0
+        if not use_spec_decode:
+            # NOTE(woosuk): Due to chunked prefills, the batch may contain
+            # partial requests. While we should not sample any token
+            # from these partial requests, we do so for simplicity.
+            # We will ignore the sampled tokens from the partial requests.
+            # TODO: Support prompt logprobs.
+            logits_indices = self.logits_indices_gpu[:num_reqs]
+            spec_decode_metadata = None
+        else:
+            # TODO: Check if spec_decode can be enabled here
+            raise NotImplementedError("advance_step has no spec_decode support yet")
+            # # Get the number of draft tokens for each request.
+            # # Iterate over the dictionary rather than all requests since
+            # # not all requests have draft tokens.
+            # num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
+            # for req_id, draft_token_ids in (
+            #         scheduler_output.scheduled_spec_decode_tokens.items()):
+            #     req_idx = self.input_batch.req_id_to_index[req_id]
+            #     num_draft_tokens[req_idx] = len(draft_token_ids)
+
+            # spec_decode_metadata = self._calc_spec_decode_metadata(
+            #     num_draft_tokens, cu_num_tokens)
+            # logits_indices = spec_decode_metadata.logits_indices
+
+        # Hot-swap LoRA model
+        if self.lora_config:
+            # TODO: Check if this works
+            raise NotImplementedError("advance_step has no LoRA support yet")
+            self.set_active_loras(self.input_batch, num_scheduled_tokens)
+
+        return attn_metadata, logits_indices, spec_decode_metadata
+
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
@@ -505,6 +663,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         num_scheduled_tokens = np.array(tokens, dtype=np.int32)
         max_num_scheduled_tokens = max(tokens)
 
+        # Determine if the advance-step fast path can be used
+        use_spec_decode = len(
+            scheduler_output.scheduled_spec_decode_tokens) > 0
+
+        is_flash_attn = self.prev_attn_metadata is not None and isinstance(
+            self.prev_attn_metadata, FlashAttentionMetadata)
+
+        is_advance_decode = (envs.VLLM_ENABLE_V1_ADVANCE_STEP
+                             and self.prev_num_reqs == num_reqs
+                             and max_num_scheduled_tokens == 1
+                             and not use_spec_decode
+                             and not self.cascade_attn_enabled
+                             and is_flash_attn)
+
+        if is_advance_decode:
+            if self.is_first_advance_decode:
+                # The first time advance_step can be used, run the usual
+                # prepare path so that the positions tensor is initialized
+                self.is_first_advance_decode = False
+            else:
+                # This is the fast-path advance_step
+                # (all tensors are on the GPU and are updated on the GPU)
+                (attn_metadata, logits_indices,
+                 spec_decode_metadata) = self._advance_decode_step(
+                     scheduler_output, num_scheduled_tokens)
+                return attn_metadata, logits_indices, spec_decode_metadata
+        else:
+            self.is_first_advance_decode = True
+
+        self.prev_num_reqs = num_reqs
+
         # Get request indices.
         # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
         req_indices = np.repeat(self.arange_np[:num_reqs],
@@ -523,6 +713,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
 
         # Get positions.
+
         positions_np = self.positions_np[:total_num_scheduled_tokens]
         np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
                arange,
@@ -599,6 +790,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             max_query_len=max_num_scheduled_tokens,
             common_prefix_len=common_prefix_len,
         )
+        self.prev_attn_metadata = attn_metadata
 
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -1177,6 +1369,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         # Get the valid generated tokens.
         sampled_token_ids = sampler_output.sampled_token_ids
+        self.prev_sampled_token_ids = sampled_token_ids
+
         max_gen_len = sampled_token_ids.shape[-1]
         if max_gen_len == 1:
             # No spec decode tokens.
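
--
Usage sketch: a minimal way to exercise this fast path, assuming
VLLM_ENABLE_V1_ADVANCE_STEP is set before the engine is constructed
(envs.py reads it lazily via os.getenv), mirroring the patched
examples/offline_inference/basic/basic.py:

    import os
    # Gate added by this patch; must be set before the engine is created.
    os.environ["VLLM_ENABLE_V1_ADVANCE_STEP"] = "1"

    from vllm import LLM, SamplingParams

    # disable_cascade_attn matches the example change above; the fast path
    # asserts that cascade attention is off.
    llm = LLM(model="facebook/opt-125m", disable_cascade_attn=True)
    params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=10)

    for out in llm.generate(["The future of AI is"], params):
        print(out.outputs[0].text)

The same can be done from the shell:
VLLM_ENABLE_V1_ADVANCE_STEP=1 python examples/offline_inference/basic/basic.py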