add cuda graph support to triton_mla attention

Alexander Matveev 2025-01-30 21:12:00 +00:00
parent 135c404fbb
commit 984ffddda6
5 changed files with 87 additions and 17 deletions


@@ -168,7 +168,8 @@ class AttentionState(ABC, Generic[T]):
 
     @abstractmethod
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
         """Context manager used when capturing CUDA graphs."""
         yield
 
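The signature change above means every `AttentionState` implementation must now accept the runner's position tensor during capture, even when it ignores it. A minimal sketch of the contract, using a hypothetical `SketchState` class that is not part of this commit:

from contextlib import contextmanager
from typing import Optional

import torch


class SketchState:
    """Hypothetical AttentionState subclass showing the new contract."""

    @contextmanager
    def graph_capture(self, max_batch_size: int,
                      positions: Optional[torch.Tensor]):
        # Backends that bake positions into the captured CUDA graph keep a
        # reference to the runner's shared buffer; all other backends are
        # handed None and assert on it, as FlashInferState does below.
        self._positions = positions
        yield
        self._positions = None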


@@ -213,7 +213,10 @@ class FlashInferState(AttentionState):
         return self._decode_wrapper
 
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
+        assert positions is None
+
         self._is_graph_capturing = True
         self._graph_decode_wrapper = None
         self._graph_slot_mapping = torch.full((max_batch_size, ),


@@ -90,33 +90,93 @@ class TritonMLAState(AttentionState):
     def __init__(self, runner):
         self.runner = runner
         self._is_graph_capturing = False
 
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
+        self._is_graph_capturing = True
+
+        self._graph_slot_mapping = torch.full((max_batch_size, ),
+                                              PAD_SLOT_ID,
+                                              dtype=torch.long,
+                                              device=self.runner.device)
+        self._graph_seq_lens = torch.ones(max_batch_size,
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+        self._graph_block_tables = torch.from_numpy(
+            self.runner.graph_block_tables).to(device=self.runner.device)
+
+        assert positions is not None
+        self._positions = positions
+
+        yield
+
+        self._is_graph_capturing = False
+        del self._graph_slot_mapping
+        del self._graph_seq_lens
+        del self._graph_block_tables
 
     def graph_clone(self, batch_size: int):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        assert self._is_graph_capturing
+        return self.__class__(self.runner)
 
     def graph_capture_get_metadata_for_batch(
             self, batch_size: int, is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        assert self._is_graph_capturing
+
+        attn_metadata = self.runner.attn_backend.make_metadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=batch_size,
+            slot_mapping=self._graph_slot_mapping[:batch_size],
+            multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
+            seq_lens=None,
+            seq_lens_tensor=self._graph_seq_lens[:batch_size],
+            max_query_len=1,
+            max_decode_query_len=1,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.runner.max_seq_len_to_capture,
+            query_start_loc=None,
+            seq_start_loc=None,
+            context_lens_tensor=None,
+            block_tables=self._graph_block_tables[:batch_size],
+            use_cuda_graph=True,
+            input_positions=self._positions[:batch_size],
+            head_dim=self.runner.model_config.get_head_size())
+
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")
+
+        return attn_metadata
 
     def get_graph_input_buffers(self,
                                 attn_metadata,
                                 is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        input_buffers = {
+            "slot_mapping": attn_metadata.slot_mapping,
+            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
+            "block_tables": attn_metadata.decode_metadata.block_tables,
+        }
+
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")
+
+        return input_buffers
 
     def prepare_graph_input_buffers(self,
                                     input_buffers,
                                     attn_metadata,
                                     is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        input_buffers["seq_lens_tensor"].copy_(
+            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
+        input_buffers["block_tables"].copy_(
+            attn_metadata.decode_metadata.block_tables, non_blocking=True)
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")
 
     def begin_forward(self, model_input):
         return
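The buffers allocated in `graph_capture` above are the tensors whose addresses get baked into the captured graph; `prepare_graph_input_buffers` then overwrites them in place before each replay. A standalone PyTorch sketch of that pattern (illustrative only, not vLLM code):

import torch

# Static buffer captured into the graph, analogous to _graph_seq_lens.
static_seq_lens = torch.ones(8, dtype=torch.int32, device="cuda")
static_out = torch.empty(8, dtype=torch.int32, device="cuda")

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # Work recorded here is bound to static_seq_lens' memory address.
    torch.add(static_seq_lens, 1, out=static_out)

# Replay with real per-step values: copy into the captured buffer in place,
# which is what prepare_graph_input_buffers does for seq lens / block tables.
real_seq_lens = torch.randint(1, 100, (8, ), dtype=torch.int32, device="cuda")
static_seq_lens.copy_(real_seq_lens, non_blocking=True)
graph.replay()  # static_out now holds real_seq_lens + 1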


@@ -2,7 +2,7 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from itertools import accumulate
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union, Optional)
 
 import numpy as np
 import torch
@@ -288,8 +288,11 @@ class CommonAttentionState(AttentionState):
         self._is_graph_capturing = False
 
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
+        assert positions is None
+
         self._is_graph_capturing = True
         self._graph_slot_mapping = torch.full((max_batch_size, ),
                                               PAD_SLOT_ID,
                                               dtype=torch.long,
@@ -299,7 +302,9 @@
                                           device=self.runner.device)
         self._graph_block_tables = torch.from_numpy(
             self.runner.graph_block_tables).to(device=self.runner.device)
+
         yield
+
         self._is_graph_capturing = False
         del self._graph_slot_mapping
         del self._graph_seq_lens


@@ -1468,8 +1468,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 dtype=self.model_config.dtype,
                 device=self.device)
 
-        with self.attn_state.graph_capture(max_batch_size), graph_capture(
-                self.device) as graph_capture_context:
+        with self.attn_state.graph_capture(
+                max_batch_size, input_positions), graph_capture(
+                    self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(
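Because TritonMLAState captures `self._positions[:batch_size]` for each batch size, every captured graph views a prefix of the single `input_positions` buffer the runner passes in here. A small sketch of why prefix slices make replay updates cheap (names are illustrative, not vLLM code):

import torch

max_batch_size = 8
input_positions = torch.zeros(max_batch_size, dtype=torch.long, device="cuda")

# Each captured batch size records a view into the same storage...
views = {bs: input_positions[:bs] for bs in (8, 4, 2, 1)}

# ...so one in-place write before replay updates what every graph reads.
input_positions[:4].copy_(torch.tensor([3, 7, 11, 2], device="cuda"))
assert views[4].tolist() == [3, 7, 11, 2]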