diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 9b69b28858fe3..837d7faf43708 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -179,7 +179,7 @@ class FlashAttentionImpl(AttentionImpl):
         assert output is not None, "Output tensor must be provided."
 
         if attn_metadata is None:
-            # Dynamic shape profiling run.
+            # Profiling run.
             return output
 
         # IMPORTANT!
@@ -193,13 +193,17 @@ class FlashAttentionImpl(AttentionImpl):
 
         num_actual_tokens = attn_metadata.num_actual_tokens
         # Reshape the input keys and values and store them in the cache.
+        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+        # not padded. However, we don't need to do key[:num_actual_tokens] and
+        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
+        # the slot_mapping's shape to determine the number of actual tokens.
         key_cache, value_cache = kv_cache.unbind(0)
         torch.ops._C_cache_ops.reshape_and_cache_flash(
             key,
             value,
             key_cache,
             value_cache,
-            attn_metadata.slot_mapping[:num_actual_tokens],
+            attn_metadata.slot_mapping,
             self.kv_cache_dtype,
             layer._k_scale,
             layer._v_scale,
@@ -208,17 +212,14 @@ class FlashAttentionImpl(AttentionImpl):
         # Compute attention and update output up to `num_actual_tokens`.
         if not attn_metadata.use_cascade:
             # Regular attention (common case).
-            batch_size = attn_metadata.block_table.shape[0]
-
-            #TODO: Do we need to slice by [:batch_size+1]?
             flash_attn_varlen_func(
                 q=query[:num_actual_tokens],
                 k=key_cache,
                 v=value_cache,
                 out=output[:num_actual_tokens],
-                cu_seqlens_q=attn_metadata.query_start_loc[:batch_size+1],
+                cu_seqlens_q=attn_metadata.query_start_loc,
                 max_seqlen_q=attn_metadata.max_query_len,
-                seqused_k=attn_metadata.seq_lens[:batch_size],
+                seqused_k=attn_metadata.seq_lens,
                 max_seqlen_k=attn_metadata.max_seq_len,
                 softmax_scale=self.scale,
                 causal=True,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 8bc0da8e3a953..ffbd1451a0a5c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -11,7 +11,7 @@ import torch.nn as nn
 
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
@@ -122,9 +122,6 @@ class GPUModelRunner:
             vocab_size=model_config.get_vocab_size(),
         )
 
-        # self.use_cuda_graph = (self.vllm_config.compilation_config.level
-        #                        == CompilationLevel.PIECEWISE
-        #                        and not self.model_config.enforce_eager)
         self.use_cuda_graph = not self.model_config.enforce_eager
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
@@ -467,7 +464,8 @@ class GPUModelRunner:
         self.input_batch.block_table.get_device_tensor()[num_reqs:].fill_(-1)
 
         # Fill with -1s -- needed for reshape_and_cache
-        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1) # Definitely needed
+        self.slot_mapping[total_num_scheduled_tokens:].fill_(
+            -1)  # Definitely needed
 
         # Prepare for cascade attention if needed.
         common_prefix_len = (scheduler_output.num_common_prefix_blocks *
@@ -550,12 +548,12 @@ class GPUModelRunner:
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=total_num_scheduled_tokens,
             max_query_len=max_num_scheduled_tokens,
-            query_start_loc=self.query_start_loc,
+            query_start_loc=self.query_start_loc[:num_reqs + 1],
             max_seq_len=max_seq_len,
-            seq_lens=self.seq_lens,
+            seq_lens=self.seq_lens[:num_reqs],
             block_table=(
                 self.input_batch.block_table.get_device_tensor()[:num_reqs]),
-            slot_mapping=self.slot_mapping,
+            slot_mapping=self.slot_mapping[:total_num_scheduled_tokens],
             # Cascade stuff
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
@@ -914,12 +912,12 @@ class GPUModelRunner:
         return FlashAttentionMetadata(
             num_actual_tokens=num_tokens,
             max_query_len=max_query_len,
-            query_start_loc=self.query_start_loc,
+            query_start_loc=self.query_start_loc[:num_reqs + 1],
             max_seq_len=max_seq_len,
-            seq_lens=self.seq_lens,
+            seq_lens=self.seq_lens[:num_reqs],
             block_table=(
                 self.input_batch.block_table.get_device_tensor()[:num_reqs]),
-            slot_mapping=self.slot_mapping,
+            slot_mapping=self.slot_mapping[:max_seq_len],
             # Cascade stuff. Non-piecewise CUDA graphs NYI
             use_cascade=False,
             common_prefix_len=0,
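
To make the intent of the slicing easier to see, here is a small, self-contained sketch (not vLLM code; the buffer sizes and variable names are made up for illustration) of how the persistent padded buffers in the model runner relate to the trimmed views that the metadata now carries:

```python
import torch

# Hypothetical capacities; the real runner preallocates its buffers once at
# maximum batch/token counts so CUDA graphs can capture fixed addresses.
max_num_reqs, max_num_tokens = 8, 64
num_reqs, num_scheduled_tokens = 3, 10  # what this step actually uses

# Persistent, padded buffers (analogous to self.query_start_loc,
# self.seq_lens, and self.slot_mapping in gpu_model_runner.py).
query_start_loc = torch.zeros(max_num_reqs + 1, dtype=torch.int32)
seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32)
slot_mapping = torch.full((max_num_tokens,), -1, dtype=torch.int64)

# ... the runner fills only the first num_reqs / num_scheduled_tokens entries ...

# With this patch, the metadata holds views trimmed to the live portion, so
# the attention backend no longer re-slices (e.g. [:batch_size]) on every call.
metadata_views = {
    "query_start_loc": query_start_loc[:num_reqs + 1],
    "seq_lens": seq_lens[:num_reqs],
    "slot_mapping": slot_mapping[:num_scheduled_tokens],
}

# A leading slice is just a view over the same storage: no copy is made and
# the padded base tensors remain valid at their captured addresses.
assert metadata_views["seq_lens"].data_ptr() == seq_lens.data_ptr()
```

Because the slices are views, the backend reads only the live entries while the underlying buffers stay allocated at their maximum size, which is also why the unpadded `slot_mapping` can be passed straight to `reshape_and_cache_flash` as described in the NOTE above.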