diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 9ed3dec7f2695..2c8636a5aa087 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -31,6 +31,8 @@ if current_platform.is_cuda():
     from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                       get_scheduler_metadata)
 
+from vllm.v1.attention.backends.utils import slice_query_start_locs
+
 logger = init_logger(__name__)
 
 
@@ -322,23 +324,23 @@ class FlashAttentionMetadataBuilder:
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata):
-        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
-        query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
+    def build_slice(self, max_query_len: int, common_prefix_len: int,
+                    common_attn_metadata: CommonAttentionMetadata,
+                    req_slice: slice,
+                    token_slice: slice) -> FlashAttentionMetadata:
+        num_reqs = req_slice.stop - req_slice.start
+        num_tokens = token_slice.stop - token_slice.start
+
+        max_seq_len = self.runner.seq_lens_np[req_slice].max()
+        query_start_loc = slice_query_start_locs(
+            common_attn_metadata.query_start_loc, req_slice)
+        seq_lens = common_attn_metadata.seq_lens[req_slice]
         block_table = self.block_table
-        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
+        block_table_tensor = block_table.get_device_tensor()[req_slice]
 
-        block_table.slot_mapping[:num_actual_tokens].copy_(
-            block_table.slot_mapping_cpu[:num_actual_tokens],
-            non_blocking=True)
-        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-        # mode.
-        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
-
-        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
+        block_table.slot_mapping[token_slice].copy_(
+            block_table.slot_mapping_cpu[token_slice], non_blocking=True)
+        slot_mapping = block_table.slot_mapping[token_slice]
 
         if self.aot_sliding_window is None:
             self.aot_sliding_window = (-1, -1)
@@ -380,8 +382,8 @@ class FlashAttentionMetadataBuilder:
             seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                 virt_block_table_tensor = make_local_attention_virtual_batches(
                     self.runner.attention_chunk_size,
-                    self.runner.query_start_loc_np[:num_reqs + 1],
-                    self.runner.seq_lens_np[:num_reqs],
+                    query_start_loc,
+                    seq_lens,
                     block_table_tensor,
                     self.block_size,
                 )
@@ -411,20 +413,20 @@ class FlashAttentionMetadataBuilder:
 
         use_cascade = common_prefix_len > 0
         if use_cascade:
-            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
+            cu_prefix_query_lens = torch.tensor([0, num_tokens],
                                                 dtype=torch.int32,
                                                 device=self.runner.device)
             prefix_kv_lens = torch.tensor([common_prefix_len],
                                           dtype=torch.int32,
                                           device=self.runner.device)
-            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
+            suffix_kv_lens = (self.runner.seq_lens_np[req_slice] -
                               common_prefix_len)
             suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                 self.runner.device)
             prefix_scheduler_metadata = schedule(
                 batch_size=1,
                 cu_query_lens=cu_prefix_query_lens,
-                max_query_len=num_actual_tokens,
+                max_query_len=num_tokens,
                 seqlens=prefix_kv_lens,
                 max_seq_len=common_prefix_len,
                 causal=False)
@@ -448,7 +450,7 @@ class FlashAttentionMetadataBuilder:
                 causal=True)
 
         attn_metadata = FlashAttentionMetadata(
-            num_actual_tokens=num_actual_tokens,
+            num_actual_tokens=num_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
@@ -466,6 +468,17 @@ class FlashAttentionMetadataBuilder:
         )
         return attn_metadata
 
+    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata):
+        return self.build_slice(
+            max_query_len=max_query_len,
+            common_prefix_len=common_prefix_len,
+            common_attn_metadata=common_attn_metadata,
+            req_slice=slice(0, num_reqs),
+            token_slice=slice(0, num_actual_tokens),
+        )
+
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         return use_cascade_attention(*args, **kwargs)
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 10a771e830b68..97f56898f7dcb 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -16,3 +16,11 @@ class CommonAttentionMetadata:
     seq_lens: torch.Tensor
     """(batch_size,), the length of each request including both computed tokens
     and newly scheduled tokens"""
+
+
+def slice_query_start_locs(
+    query_start_loc: torch.Tensor,
+    req_slice: slice,
+) -> torch.Tensor:
+    return query_start_loc[req_slice.start: req_slice.stop + 1] -\
+        query_start_loc[req_slice.start]
\ No newline at end of file
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 201796c96ee5c..34abd898d02bc 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4,7 +4,7 @@ import copy
 import gc
 import time
 import weakref
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional, TypeAlias, Union
 
 import numpy as np
 import torch
@@ -69,6 +69,11 @@ else:
 
 logger = init_logger(__name__)
 
+AttnMetadataDict: TypeAlias = dict[str, FlashAttentionMetadata]
+# list when ubatching is enabled
+PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
+                                        AttnMetadataDict]
+
 
 class GPUModelRunner(LoRAModelRunnerMixin):
 
@@ -491,7 +496,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor,
+        ubatch_slices: Optional[list[tuple[
+            slice, slice]]] = None,  # req_slice, token_slice
+    ) -> tuple[PerLayerAttnMetadata, torch.Tensor,
                Optional[SpecDecodeMetadata]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -612,7 +619,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)
 
-        attn_metadata: dict[str, FlashAttentionMetadata] = {}
+        attn_metadata: PerLayerAttnMetadata = {}
+        if ubatch_slices is not None:
+            attn_metadata = [dict() for _ in range(len(ubatch_slices))]
+
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -629,15 +639,37 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 self.attn_metadata_builders[kv_cache_group_id],
             )
 
-            attn_metadata_i = (
-                self.attn_metadata_builders[kv_cache_group_id].build(
-                    num_reqs=num_reqs,
-                    num_actual_tokens=total_num_scheduled_tokens,
-                    max_query_len=max_num_scheduled_tokens,
-                    common_prefix_len=common_prefix_len,
-                    common_attn_metadata=common_attn_metadata))
-            for layer_name in kv_cache_group_spec.layer_names:
-                attn_metadata[layer_name] = attn_metadata_i
+            # Fill unused with -1. Needed for reshape_and_cache in full cuda
+            # graph mode.
+            if self.vllm_config.compilation_config.full_cuda_graph:
+                self.input_batch.block_table[kv_cache_group_id]\
+                    .slot_mapping.fill_(-1)
+
+            if ubatch_slices is not None:
+                for ubid, (req_slice, token_slice) in enumerate(ubatch_slices):
+                    attn_metadata_i = (
+                        self.attn_metadata_builders[kv_cache_group_id].
+                        build_slice(
+                            max(tokens[req_slice]),
+                            common_prefix_len=common_prefix_len,
+                            common_attn_metadata=common_attn_metadata,
+                            req_slice=req_slice,
+                            token_slice=token_slice,
+                        ))
+                    for layer_name in kv_cache_group_spec.layer_names:
+                        assert type(attn_metadata) is list
+                        attn_metadata[ubid][layer_name] = attn_metadata_i
+            else:
+                attn_metadata_i = (
+                    self.attn_metadata_builders[kv_cache_group_id].build(
+                        num_reqs=num_reqs,
+                        num_actual_tokens=total_num_scheduled_tokens,
+                        max_query_len=max_num_scheduled_tokens,
+                        common_prefix_len=common_prefix_len,
+                        common_attn_metadata=common_attn_metadata))
+                for layer_name in kv_cache_group_spec.layer_names:
+                    assert type(attn_metadata) is dict
+                    attn_metadata[layer_name] = attn_metadata_i
 
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -1057,10 +1089,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         )
 
     def sync_and_slice_intermediate_tensors(
-        self, num_tokens: int, intermediate_tensors: IntermediateTensors,
+        self, tokens_slice: slice,
+        intermediate_tensors: IntermediateTensors,
         sync_self: bool) -> IntermediateTensors:
 
         assert self.intermediate_tensors is not None
 
+        num_tokens = tokens_slice.stop - tokens_slice.start
         tp = self.vllm_config.parallel_config.tensor_parallel_size
         enabled_sp = self.vllm_config.compilation_config.pass_config. \
@@ -1072,49 +1106,37 @@
         is_residual_scattered = tp > 1 and enabled_sp \
             and num_tokens % tp == 0
 
+        def copy_slice(is_scattered: bool) -> slice:
+            if is_scattered:
+                return slice(tokens_slice.start // tp, tokens_slice.stop // tp)
+            else:
+                return tokens_slice
+
         # When sequence parallelism is enabled, the "residual" tensor is sharded
         # across tensor parallel ranks, so each rank only needs its own slice.
         if sync_self:
             assert intermediate_tensors is not None
             for k, v in intermediate_tensors.items():
                 is_scattered = k == "residual" and is_residual_scattered
-                copy_len = num_tokens // tp if is_scattered else \
-                    num_tokens
-                self.intermediate_tensors[k][:copy_len].copy_(
-                    v[:copy_len], non_blocking=True)
+                _copy_slice = copy_slice(is_scattered)
+                self.intermediate_tensors[k][_copy_slice].copy_(
+                    v[_copy_slice], non_blocking=True)
 
         return IntermediateTensors({
             k:
-            v[:num_tokens // tp]
-            if k == "residual" and is_residual_scattered else v[:num_tokens]
+            v[copy_slice(k == "residual" and is_residual_scattered)]
             for k, v in self.intermediate_tensors.items()
         })
 
-    @torch.inference_mode()
-    def execute_model(
-        self,
-        scheduler_output: "SchedulerOutput",
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
-
-        self._update_states(scheduler_output)
-        if not scheduler_output.total_num_scheduled_tokens:
-            if not has_kv_transfer_group():
-                # Return empty ModelRunnerOutput if there's no work to do.
-                return EMPTY_MODEL_RUNNER_OUTPUT
-
-            return self.kv_connector_no_forward(scheduler_output)
-
-        # Prepare the decoder inputs.
-        attn_metadata, logits_indices, spec_decode_metadata = (
-            self._prepare_inputs(scheduler_output))
-        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+    def _get_model_inputs(self, tokens_slice: slice,
+                          scheduler_output: "SchedulerOutput",
+                          intermediate_tensors: Optional[
+                              IntermediateTensors] = None):
+        num_tokens = tokens_slice.stop - tokens_slice.start
         if (self.use_cuda_graph
-                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
+                and num_tokens <= self.cudagraph_batch_sizes[-1]):
             # Use piecewise CUDA graphs.
             # Add padding to the batch size.
-            num_input_tokens = self.vllm_config.pad_for_cudagraph(
-                num_scheduled_tokens)
+            tokens_slice = \
+                slice(tokens_slice.start, self.vllm_config.pad_for_cudagraph(num_tokens))
         else:
             # Eager mode.
             # Pad tokens to multiple of tensor_parallel_size when
@@ -1123,9 +1145,9 @@
             if self.vllm_config.compilation_config.pass_config. \
                 enable_sequence_parallelism and tp_size > 1:
                 from vllm.utils import round_up
-                num_input_tokens = round_up(num_scheduled_tokens, tp_size)
-            else:
-                num_input_tokens = num_scheduled_tokens
+                tokens_slice = slice(tokens_slice.start,
+                                     round_up(num_tokens, tp_size))
+        num_tokens = tokens_slice.stop - tokens_slice.start
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
@@ -1140,51 +1162,94 @@
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
-            input_ids = self.input_ids[:num_scheduled_tokens]
+            input_ids = self.input_ids[tokens_slice]
             if mm_embeds:
                 inputs_embeds = self.model.get_input_embeddings(
                     input_ids, mm_embeds)
             else:
                 inputs_embeds = self.model.get_input_embeddings(input_ids)
             # TODO(woosuk): Avoid the copy. Optimize.
-            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
-            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            self.inputs_embeds[tokens_slice].copy_(inputs_embeds)
+            inputs_embeds = self.inputs_embeds[tokens_slice]
             input_ids = None
         else:
             # For text-only models, we use token ids as input.
             # While it is possible to use embeddings as input just like the
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
-            input_ids = self.input_ids[:num_input_tokens]
+            input_ids = self.input_ids[tokens_slice]
             inputs_embeds = None
         if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_input_tokens]
+            positions = self.mrope_positions[:, tokens_slice]
         else:
-            positions = self.positions[:num_input_tokens]
+            positions = self.positions[tokens_slice]
 
         if get_pp_group().is_first_rank:
             intermediate_tensors = None
         else:
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
-                num_input_tokens, intermediate_tensors, True)
+                tokens_slice, intermediate_tensors, True)
+        return input_ids, positions, inputs_embeds, intermediate_tensors
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+        ubatch_slices: Optional[list[tuple[slice, slice]]] = None
+
+        self._update_states(scheduler_output)
+        if not scheduler_output.total_num_scheduled_tokens:
+            if not has_kv_transfer_group():
+                # Return empty ModelRunnerOutput if there's no work to do.
+                return EMPTY_MODEL_RUNNER_OUTPUT
+
+            return self.kv_connector_no_forward(scheduler_output)
+
+        # Prepare the decoder inputs.
+        attn_metadata, logits_indices, spec_decode_metadata = (
+            self._prepare_inputs(scheduler_output))
+        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
 
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
-        with set_forward_context(attn_metadata,
-                                 self.vllm_config,
-                                 num_tokens=num_input_tokens):
-            self.maybe_setup_kv_connector(scheduler_output)
+        self.maybe_setup_kv_connector(scheduler_output)
 
-            model_output = self.model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds,
-            )
+        if ubatch_slices is not None:
+            for i, (_, tokens_slice) in enumerate(ubatch_slices):
+                input_ids, positions, inputs_embeds, intermediate_tensors = \
+                    self._get_model_inputs(tokens_slice, scheduler_output,
+                                           intermediate_tensors)
+                num_input_tokens = tokens_slice.stop - tokens_slice.start
 
-            self.maybe_wait_for_kv_save()
-            finished_sending, finished_recving = (
-                self.get_finished_kv_transfers(scheduler_output))
+                with set_forward_context(attn_metadata[i],
+                                         self.vllm_config,
+                                         num_tokens=num_input_tokens):
+
+                    model_output = self.model(
+                        input_ids=input_ids,
+                        positions=positions,
+                        intermediate_tensors=intermediate_tensors,
+                        inputs_embeds=inputs_embeds,
+                    )
+        else:
+            input_ids, positions, inputs_embeds, intermediate_tensors = \
+                self._get_model_inputs(slice(0, num_scheduled_tokens),
+                                       scheduler_output, intermediate_tensors)
+
+            with set_forward_context(attn_metadata,
+                                     self.vllm_config,
+                                     num_tokens=num_scheduled_tokens):
+                model_output = self.model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds,
+                )
+
+        self.maybe_wait_for_kv_save()
+        finished_sending, finished_recving = (
+            self.get_finished_kv_transfers(scheduler_output))
 
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = model_output
diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py
new file mode 100644
index 0000000000000..a08bf4b41d587
--- /dev/null
+++ b/vllm/v1/worker/ubatching.py
@@ -0,0 +1,5 @@
+# SPDX-License-Identifier: Apache-2.0
+class UBatchContext:
+
+    def __init__(self, ubatch_id: int):
+        self.ubatch_id = ubatch_id
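A quick standalone sketch of the slicing convention this patch introduces (not part of the diff). The `slice_query_start_locs` body is copied from the new `utils.py` helper; the way `ubatch_slices` is built below is purely hypothetical, since the diff itself never constructs it (`ubatch_slices` stays `None` in `execute_model`):

```python
import torch


def slice_query_start_locs(query_start_loc: torch.Tensor,
                           req_slice: slice) -> torch.Tensor:
    # Keep the trailing boundary (stop + 1 entries) and re-base to zero so
    # each micro-batch gets a self-contained cumulative-length tensor.
    return (query_start_loc[req_slice.start:req_slice.stop + 1] -
            query_start_loc[req_slice.start])


# Four requests scheduling 3, 2, 5 and 1 new tokens respectively.
query_start_loc = torch.tensor([0, 3, 5, 10, 11], dtype=torch.int32)

# Hypothetical split into two micro-batches: requests [0, 2) and [2, 4),
# with token ranges read off the request boundaries in query_start_loc.
ubatch_slices = [
    (slice(0, 2), slice(0, int(query_start_loc[2]))),
    (slice(2, 4), slice(int(query_start_loc[2]), int(query_start_loc[4]))),
]

for req_slice, token_slice in ubatch_slices:
    print(slice_query_start_locs(query_start_loc, req_slice),
          token_slice.stop - token_slice.start)
# tensor([0, 3, 5], dtype=torch.int32) 5
# tensor([0, 5, 6], dtype=torch.int32) 6
```

Each micro-batch thus sees a zero-based `query_start_loc` whose last entry equals the length of its `token_slice`, which is the invariant `build_slice` appears to assume when it slices `slot_mapping`, `seq_lens`, and the block table.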