From 984ffddda650b4cabf87c73806e52407ede6b620 Mon Sep 17 00:00:00 2001
From: Alexander Matveev
Date: Thu, 30 Jan 2025 21:12:00 +0000
Subject: [PATCH] add cuda graph support to triton_mla attention

---
NOTE(review): this section is ignored by `git am`.
NOTE(review): line breaks and indentation were restored from a
  whitespace-collapsed copy; context-line indentation is assumed and should be
  confirmed against the target tree before applying.
NOTE(review): hunk sizes and the diffstat match the revised hunks below; the
  `index` blob ids and the commit id above still refer to the original patch.

 vllm/attention/backends/abstract.py   |  3 +-
 vllm/attention/backends/flashinfer.py |  5 +-
 vllm/attention/backends/triton_mla.py | 85 ++++++++++++++++++++++++---
 vllm/attention/backends/utils.py      | 11 ++-
 vllm/worker/model_runner.py           |  5 +-
 5 files changed, 92 insertions(+), 17 deletions(-)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 75885947edda9..d05d1ebaac414 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -168,7 +168,8 @@ class AttentionState(ABC, Generic[T]):
 
     @abstractmethod
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
         """Context manager used when capturing CUDA graphs."""
         yield
 
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 7cccef9608218..d50a727c0de4b 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -213,7 +213,10 @@ class FlashInferState(AttentionState):
         return self._decode_wrapper
 
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
+        assert positions is None
+
         self._is_graph_capturing = True
         self._graph_decode_wrapper = None
         self._graph_slot_mapping = torch.full((max_batch_size, ),
diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py
index 3514b18df2d6d..21a6826c4e5e6 100644
--- a/vllm/attention/backends/triton_mla.py
+++ b/vllm/attention/backends/triton_mla.py
@@ -90,33 +90,96 @@ class TritonMLAState(AttentionState):
 
     def __init__(self, runner):
         self.runner = runner
+        self._is_graph_capturing = False
 
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
+        """Hold the shared buffers used while capturing CUDA graphs."""
+        # Validate first so a bad call leaves the state untouched.
+        assert positions is not None
+
+        self._graph_slot_mapping = torch.full((max_batch_size, ),
+                                              PAD_SLOT_ID,
+                                              dtype=torch.long,
+                                              device=self.runner.device)
+        self._graph_seq_lens = torch.ones(max_batch_size,
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+        self._graph_block_tables = torch.from_numpy(
+            self.runner.graph_block_tables).to(device=self.runner.device)
+        self._positions = positions
+        self._is_graph_capturing = True
+
+        try:
+            yield
+        finally:
+            # Always reset, even if capture raises inside the with-body.
+            self._is_graph_capturing = False
+            del self._graph_slot_mapping
+            del self._graph_seq_lens
+            del self._graph_block_tables
+            del self._positions
 
     def graph_clone(self, batch_size: int):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        assert self._is_graph_capturing
+        return self.__class__(self.runner)
 
     def graph_capture_get_metadata_for_batch(
             self, batch_size: int, is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        assert self._is_graph_capturing
+        # Reject unsupported models before building any metadata.
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")
+
+        return self.runner.attn_backend.make_metadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=batch_size,
+            slot_mapping=self._graph_slot_mapping[:batch_size],
+            multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
+            seq_lens=None,
+            seq_lens_tensor=self._graph_seq_lens[:batch_size],
+            max_query_len=1,
+            max_decode_query_len=1,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.runner.max_seq_len_to_capture,
+            query_start_loc=None,
+            seq_start_loc=None,
+            context_lens_tensor=None,
+            block_tables=self._graph_block_tables[:batch_size],
+            use_cuda_graph=True,
+            input_positions=self._positions[:batch_size],
+            head_dim=self.runner.model_config.get_head_size())
 
     def get_graph_input_buffers(self,
                                 attn_metadata,
                                 is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")
+
+        return {
+            "slot_mapping": attn_metadata.slot_mapping,
+            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
+            "block_tables": attn_metadata.decode_metadata.block_tables,
+        }
 
     def prepare_graph_input_buffers(self,
                                     input_buffers,
                                     attn_metadata,
                                     is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        # Reject unsupported models before mutating any buffer.
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")
+
+        input_buffers["seq_lens_tensor"].copy_(
+            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
+        input_buffers["block_tables"].copy_(
+            attn_metadata.decode_metadata.block_tables, non_blocking=True)
 
     def begin_forward(self, model_input):
         return
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index 84fe89b7df360..6949991d2f5d8 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -2,7 +2,8 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from itertools import accumulate
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
+                    TypeVar, Union)
 
 import numpy as np
 import torch
@@ -288,8 +289,12 @@ class CommonAttentionState(AttentionState):
         self._is_graph_capturing = False
 
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
+        assert positions is None
+
         self._is_graph_capturing = True
+
         self._graph_slot_mapping = torch.full((max_batch_size, ),
                                               PAD_SLOT_ID,
                                               dtype=torch.long,
@@ -299,7 +304,9 @@ class CommonAttentionState(AttentionState):
                                           device=self.runner.device)
         self._graph_block_tables = torch.from_numpy(
             self.runner.graph_block_tables).to(device=self.runner.device)
+
         yield
+
         self._is_graph_capturing = False
         del self._graph_slot_mapping
         del self._graph_seq_lens
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index b6ed3abab4247..0f654576aa67d 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1468,8 +1468,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 dtype=self.model_config.dtype,
                 device=self.device)
 
-        with self.attn_state.graph_capture(max_batch_size), graph_capture(
-                self.device) as graph_capture_context:
+        with self.attn_state.graph_capture(
+                max_batch_size, input_positions), graph_capture(
+                    self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(