diff --git a/vllm/config.py b/vllm/config.py
index fce8011be4015..6b65e36dcce60 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2305,8 +2305,8 @@ class CompilationConfig(BaseModel):
     backend: str = ""
     custom_ops: List[str] = Field(default_factory=list)
     splitting_ops: List[str] = Field(default_factory=lambda: [
-        "vllm.unified_attention",
-        "vllm.unified_attention_with_output",
+        # "vllm.unified_attention",
+        # "vllm.unified_attention_with_output",
     ])
 
     use_inductor: bool = True
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 67166fb05085c..dc75183ca17b3 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -118,6 +118,20 @@ class GPUModelRunner:
             dtype=self.dtype,
             device=self.device)
 
+        # Attention metadata related persistent buffers
+        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.seq_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                         dtype=torch.int32,
+                                         device=self.device)
+        self.slot_mapping = torch.zeros(
+            self.max_num_tokens,
+            # CPU slot_mapping is int32, but
+            # this one must be int64
+            dtype=torch.int64,
+            device=self.device)
+
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         self.arange_np = np.arange(max(self.max_num_reqs, self.max_model_len),
                                    dtype=np.int32)
@@ -337,27 +351,30 @@ class GPUModelRunner:
             self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
         self.positions[:total_num_scheduled_tokens].copy_(
             self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
-        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
-            self.device, non_blocking=True).long()
+
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.seq_start_loc[:num_reqs + 1].copy_(
+            self.seq_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)
+
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=total_num_scheduled_tokens,
             max_query_len=max_num_scheduled_tokens,
-            query_start_loc=query_start_loc,
+            query_start_loc=self.query_start_loc,
             max_seq_len=max_seq_len,
-            seq_start_loc=seq_start_loc,
+            seq_start_loc=self.seq_start_loc,
             block_table=self.input_batch.block_table[:num_reqs],
-            slot_mapping=slot_mapping,
+            slot_mapping=self.slot_mapping,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
         # partial request, we do so for simplicity. We will ignore the sampled
         # token from the partial request.
         # TODO: Support prompt logprobs.
-        logits_indices = query_start_loc[1:] - 1
+        logits_indices = self.query_start_loc[1:num_reqs + 1] - 1
         return attn_metadata, logits_indices
 
     def _prepare_sampling(
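
The diff above replaces per-step device allocations (`.to(self.device)`) with persistent device buffers that are refreshed in place via `copy_`, so the tensor addresses stay stable across steps. Below is a minimal standalone sketch of that pattern; `MAX_REQS`, `DEVICE`, and `prepare_step` are illustrative names and not part of the vLLM API.

```python
import torch

# Hypothetical sizes; in vLLM these come from the scheduler/runner config.
MAX_REQS = 256
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Allocate the device buffer once (persistent across steps), mirroring
# self.query_start_loc in the diff above.
query_start_loc = torch.zeros(MAX_REQS + 1, dtype=torch.int32, device=DEVICE)

# Pinned CPU staging buffer so non_blocking=True copies can overlap with compute.
query_start_loc_cpu = torch.zeros(MAX_REQS + 1, dtype=torch.int32,
                                  pin_memory=torch.cuda.is_available())


def prepare_step(num_reqs: int) -> torch.Tensor:
    """Illustrative per-step update: copy into a slice of the persistent
    buffer instead of materializing a fresh device tensor every step."""
    query_start_loc[:num_reqs + 1].copy_(query_start_loc_cpu[:num_reqs + 1],
                                         non_blocking=True)
    # Because the buffer's address never changes, it can be reused by captured
    # CUDA graphs / compiled artifacts; only its contents change each step.
    return query_start_loc[1:num_reqs + 1] - 1  # e.g. logits indices
```

Note that `Tensor.copy_` casts between dtypes, which is why the diff can copy the int32 CPU `slot_mapping` into an int64 device buffer without an explicit `.long()` conversion.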