diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index a82f1bb20fba3..527b31153410b 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -266,7 +266,7 @@ class FlashAttentionMetadataBuilder(
         use_cascade = common_prefix_len > 0
 
         if use_cascade:
-            cu_prefix_query_lens = torch.tensor([0,num_actual_tokens],
+            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                 dtype=torch.int32,
                                                 device=self.runner.device)
             prefix_kv_lens = torch.tensor([common_prefix_len],