add cuda graph support to triton_mla attention

This commit is contained in:
Alexander Matveev 2025-01-30 21:12:00 +00:00
parent 135c404fbb
commit 984ffddda6
5 changed files with 87 additions and 17 deletions

View File

@@ -168,7 +168,8 @@ class AttentionState(ABC, Generic[T]):
@abstractmethod @abstractmethod
@contextmanager @contextmanager
def graph_capture(self, max_batch_size: int): def graph_capture(self, max_batch_size: int,
positions: Optional[torch.Tensor]):
"""Context manager used when capturing CUDA graphs.""" """Context manager used when capturing CUDA graphs."""
yield yield

View File

@@ -213,7 +213,10 @@ class FlashInferState(AttentionState):
return self._decode_wrapper return self._decode_wrapper
@contextmanager @contextmanager
def graph_capture(self, max_batch_size: int): def graph_capture(self, max_batch_size: int,
positions: Optional[torch.Tensor]):
assert positions is None
self._is_graph_capturing = True self._is_graph_capturing = True
self._graph_decode_wrapper = None self._graph_decode_wrapper = None
self._graph_slot_mapping = torch.full((max_batch_size, ), self._graph_slot_mapping = torch.full((max_batch_size, ),

View File

@@ -90,33 +90,93 @@ class TritonMLAState(AttentionState):
def __init__(self, runner): def __init__(self, runner):
self.runner = runner self.runner = runner
self._is_graph_capturing = False
@contextmanager @contextmanager
def graph_capture(self, max_batch_size: int): def graph_capture(self, max_batch_size: int,
raise NotImplementedError( positions: Optional[torch.Tensor]):
"TritonMLAState does not support graph capture") self._is_graph_capturing = True
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
assert positions is not None
self._positions = positions
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
def graph_clone(self, batch_size: int): def graph_clone(self, batch_size: int):
raise NotImplementedError( assert self._is_graph_capturing
"TritonMLAState does not support graph capture") return self.__class__(self.runner)
def graph_capture_get_metadata_for_batch( def graph_capture_get_metadata_for_batch(
self, batch_size: int, is_encoder_decoder_model: bool = False): self, batch_size: int, is_encoder_decoder_model: bool = False):
raise NotImplementedError( assert self._is_graph_capturing
"TritonMLAState does not support graph capture")
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=1,
max_decode_query_len=1,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self._graph_block_tables[:batch_size],
use_cuda_graph=True,
input_positions=self._positions[:batch_size],
head_dim=self.runner.model_config.get_head_size())
if is_encoder_decoder_model:
raise NotImplementedError(
"TritonMLAState does not support encoder/decoder yet")
return attn_metadata
def get_graph_input_buffers(self, def get_graph_input_buffers(self,
attn_metadata, attn_metadata,
is_encoder_decoder_model: bool = False): is_encoder_decoder_model: bool = False):
raise NotImplementedError( input_buffers = {
"TritonMLAState does not support graph capture") "slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
if is_encoder_decoder_model:
raise NotImplementedError(
"TritonMLAState does not support encoder/decoder yet")
return input_buffers
def prepare_graph_input_buffers(self, def prepare_graph_input_buffers(self,
input_buffers, input_buffers,
attn_metadata, attn_metadata,
is_encoder_decoder_model: bool = False): is_encoder_decoder_model: bool = False):
raise NotImplementedError( input_buffers["seq_lens_tensor"].copy_(
"TritonMLAState does not support graph capture") attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
if is_encoder_decoder_model:
raise NotImplementedError(
"TritonMLAState does not support encoder/decoder yet")
def begin_forward(self, model_input): def begin_forward(self, model_input):
return return

View File

@@ -2,7 +2,7 @@
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from itertools import accumulate from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union from typing import (TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union, Optional)
import numpy as np import numpy as np
import torch import torch
@@ -288,8 +288,11 @@ class CommonAttentionState(AttentionState):
self._is_graph_capturing = False self._is_graph_capturing = False
@contextmanager @contextmanager
def graph_capture(self, max_batch_size: int): def graph_capture(self, max_batch_size: int, positions: Optional[torch.Tensor]):
assert positions is None
self._is_graph_capturing = True self._is_graph_capturing = True
self._graph_slot_mapping = torch.full((max_batch_size, ), self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID, PAD_SLOT_ID,
dtype=torch.long, dtype=torch.long,
@@ -299,7 +302,9 @@ class CommonAttentionState(AttentionState):
device=self.runner.device) device=self.runner.device)
self._graph_block_tables = torch.from_numpy( self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device) self.runner.graph_block_tables).to(device=self.runner.device)
yield yield
self._is_graph_capturing = False self._is_graph_capturing = False
del self._graph_slot_mapping del self._graph_slot_mapping
del self._graph_seq_lens del self._graph_seq_lens

View File

@@ -1468,8 +1468,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dtype=self.model_config.dtype, dtype=self.model_config.dtype,
device=self.device) device=self.device)
with self.attn_state.graph_capture(max_batch_size), graph_capture( with self.attn_state.graph_capture(
self.device) as graph_capture_context: max_batch_size, input_positions), graph_capture(
self.device) as graph_capture_context:
# NOTE: Capturing the largest batch size first may help reduce the # NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph. # memory usage of CUDA graph.
for virtual_engine in range( for virtual_engine in range(