add cuda graph support to triton_mla attention
commit 984ffddda6
parent 135c404fbb
@@ -168,7 +168,8 @@ class AttentionState(ABC, Generic[T]):

     @abstractmethod
     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
         """Context manager used when capturing CUDA graphs."""
         yield
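The hunk above widens the abstract graph_capture signature so that decode-time position ids can be threaded into backends that keep them as a persistent CUDA graph input (the Triton MLA backend below). A minimal sketch of the resulting contract, using a toy class rather than vLLM's real code: backends with no positional graph input buffer are expected to reject positions, exactly as FlashInferState and CommonAttentionState do in the following hunks.

# Toy illustration of the widened graph_capture contract; the class name
# is hypothetical, only the method signature mirrors the diff.
from contextlib import contextmanager
from typing import Optional

import torch


class ToyAttentionState:

    @contextmanager
    def graph_capture(self, max_batch_size: int,
                      positions: Optional[torch.Tensor]):
        # This backend keeps no positional graph input buffer, so the
        # caller must not pass positions to it.
        assert positions is None
        yield


with ToyAttentionState().graph_capture(8, None):
    pass  # CUDA graphs would be recorded inside this window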
@@ -213,7 +213,10 @@ class FlashInferState(AttentionState):
         return self._decode_wrapper

     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
+        assert positions is None
+
         self._is_graph_capturing = True
         self._graph_decode_wrapper = None
         self._graph_slot_mapping = torch.full((max_batch_size, ),
@@ -90,33 +90,93 @@ class TritonMLAState(AttentionState):

     def __init__(self, runner):
         self.runner = runner
         self._is_graph_capturing = False

     @contextmanager
-    def graph_capture(self, max_batch_size: int):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
+        self._is_graph_capturing = True
+
+        self._graph_slot_mapping = torch.full((max_batch_size, ),
+                                              PAD_SLOT_ID,
+                                              dtype=torch.long,
+                                              device=self.runner.device)
+        self._graph_seq_lens = torch.ones(max_batch_size,
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+        self._graph_block_tables = torch.from_numpy(
+            self.runner.graph_block_tables).to(device=self.runner.device)
+
+        assert positions is not None
+        self._positions = positions
+
+        yield
+
+        self._is_graph_capturing = False
+        del self._graph_slot_mapping
+        del self._graph_seq_lens
+        del self._graph_block_tables

     def graph_clone(self, batch_size: int):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        assert self._is_graph_capturing
+        return self.__class__(self.runner)

     def graph_capture_get_metadata_for_batch(
             self, batch_size: int, is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        assert self._is_graph_capturing
+
+        attn_metadata = self.runner.attn_backend.make_metadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=batch_size,
+            slot_mapping=self._graph_slot_mapping[:batch_size],
+            multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
+            seq_lens=None,
+            seq_lens_tensor=self._graph_seq_lens[:batch_size],
+            max_query_len=1,
+            max_decode_query_len=1,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.runner.max_seq_len_to_capture,
+            query_start_loc=None,
+            seq_start_loc=None,
+            context_lens_tensor=None,
+            block_tables=self._graph_block_tables[:batch_size],
+            use_cuda_graph=True,
+            input_positions=self._positions[:batch_size],
+            head_dim=self.runner.model_config.get_head_size())
+
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")
+
+        return attn_metadata

     def get_graph_input_buffers(self,
                                 attn_metadata,
                                 is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        input_buffers = {
+            "slot_mapping": attn_metadata.slot_mapping,
+            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
+            "block_tables": attn_metadata.decode_metadata.block_tables,
+        }
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")
+
+        return input_buffers

     def prepare_graph_input_buffers(self,
                                     input_buffers,
                                     attn_metadata,
                                     is_encoder_decoder_model: bool = False):
-        raise NotImplementedError(
-            "TritonMLAState does not support graph capture")
+        input_buffers["seq_lens_tensor"].copy_(
+            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
+        input_buffers["block_tables"].copy_(
+            attn_metadata.decode_metadata.block_tables, non_blocking=True)
+        if is_encoder_decoder_model:
+            raise NotImplementedError(
+                "TritonMLAState does not support encoder/decoder yet")

     def begin_forward(self, model_input):
         return
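For context: CUDA graphs replay a fixed kernel sequence over fixed device addresses, which is why the new TritonMLAState.graph_capture allocates persistent slot_mapping, seq_lens, block_tables, and positions buffers before yield, and why prepare_graph_input_buffers copies fresh values into those buffers (with non_blocking=True) instead of rebinding the metadata tensors. A self-contained sketch of that pattern in plain PyTorch (generic illustration, not vLLM code; requires a CUDA GPU):

import torch

if torch.cuda.is_available():
    # Persistent buffers: the graph is captured against these addresses.
    static_in = torch.zeros(8, device="cuda")
    static_out = torch.zeros(8, device="cuda")

    # Warm up on a side stream before capture (standard PyTorch practice).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(static_in * 2)
    torch.cuda.current_stream().wait_stream(s)

    # Record the kernel sequence once.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(static_in * 2)

    # Replay with new data: copy into the captured input, then replay.
    static_in.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
    g.replay()
    assert torch.equal(static_out, static_in * 2)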
@@ -2,7 +2,7 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from itertools import accumulate
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union, Optional)

 import numpy as np
 import torch
@@ -288,8 +288,11 @@ class CommonAttentionState(AttentionState):
         self._is_graph_capturing = False

     @contextmanager
-    def graph_capture(self, max_batch_size: int):
+    def graph_capture(self, max_batch_size: int,
+                      positions: Optional[torch.Tensor]):
+        assert positions is None
+
         self._is_graph_capturing = True
         self._graph_slot_mapping = torch.full((max_batch_size, ),
                                               PAD_SLOT_ID,
                                               dtype=torch.long,
@@ -299,7 +302,9 @@ class CommonAttentionState(AttentionState):
                                           device=self.runner.device)
         self._graph_block_tables = torch.from_numpy(
             self.runner.graph_block_tables).to(device=self.runner.device)
+
         yield
+
         self._is_graph_capturing = False
         del self._graph_slot_mapping
         del self._graph_seq_lens
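Both CommonAttentionState hunks mirror the interface change: the signature gains positions (rejected with an assert, since this backend has no positional graph input) and the capture body gains blank lines around yield. A toy sketch, with hypothetical names, of the setup/teardown shape that CommonAttentionState and the new TritonMLAState share, where buffers are created before yield and dropped afterwards:

from contextlib import contextmanager

import torch


class ToyState:
    """Hypothetical stand-in for an AttentionState subclass."""

    def __init__(self, device: str = "cpu"):
        self.device = device
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        self._is_graph_capturing = True
        self._slot_mapping = torch.full((max_batch_size, ), -1,
                                        dtype=torch.long, device=self.device)
        yield
        # Teardown: the buffers only outlive the capture window through
        # the CUDA graphs recorded against them.
        self._is_graph_capturing = False
        del self._slot_mapping


state = ToyState()
with state.graph_capture(4):
    assert state._is_graph_capturing
assert not state._is_graph_capturing
assert not hasattr(state, "_slot_mapping")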
@@ -1468,8 +1468,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             dtype=self.model_config.dtype,
             device=self.device)

-        with self.attn_state.graph_capture(max_batch_size), graph_capture(
-                self.device) as graph_capture_context:
+        with self.attn_state.graph_capture(
+                max_batch_size, input_positions), graph_capture(
+                    self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(
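At the call site, the runner now passes input_positions into attn_state.graph_capture, and TritonMLAState later hands each captured graph the slice self._positions[:batch_size]. Slices share storage with the parent tensor, so every per-batch-size graph views the same persistent buffer and a single copy into it feeds them all. A small illustration of that property (buffer size and names are hypothetical):

import torch

# One persistent buffer sized for the largest capturable batch.
positions = torch.zeros(256, dtype=torch.long)

# Each captured batch size gets a view, not a copy.
view_bs8 = positions[:8]
assert view_bs8.data_ptr() == positions.data_ptr()

# Writing through the parent buffer is visible to every view.
positions[:8].copy_(torch.arange(8))
assert torch.equal(view_bs8, torch.arange(8))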