mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 00:31:19 +08:00
add cuda graph support to triton_mla attention
This commit is contained in:
parent
135c404fbb
commit
984ffddda6
@ -168,7 +168,8 @@ class AttentionState(ABC, Generic[T]):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def graph_capture(self, max_batch_size: int):
|
def graph_capture(self, max_batch_size: int,
|
||||||
|
positions: Optional[torch.Tensor]):
|
||||||
"""Context manager used when capturing CUDA graphs."""
|
"""Context manager used when capturing CUDA graphs."""
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|||||||
@ -213,7 +213,10 @@ class FlashInferState(AttentionState):
|
|||||||
return self._decode_wrapper
|
return self._decode_wrapper
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def graph_capture(self, max_batch_size: int):
|
def graph_capture(self, max_batch_size: int,
|
||||||
|
positions: Optional[torch.Tensor]):
|
||||||
|
assert positions is None
|
||||||
|
|
||||||
self._is_graph_capturing = True
|
self._is_graph_capturing = True
|
||||||
self._graph_decode_wrapper = None
|
self._graph_decode_wrapper = None
|
||||||
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
||||||
|
|||||||
@ -90,33 +90,93 @@ class TritonMLAState(AttentionState):
|
|||||||
|
|
||||||
def __init__(self, runner):
|
def __init__(self, runner):
|
||||||
self.runner = runner
|
self.runner = runner
|
||||||
|
self._is_graph_capturing = False
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def graph_capture(self, max_batch_size: int):
|
def graph_capture(self, max_batch_size: int,
|
||||||
raise NotImplementedError(
|
positions: Optional[torch.Tensor]):
|
||||||
"TritonMLAState does not support graph capture")
|
self._is_graph_capturing = True
|
||||||
|
|
||||||
|
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
||||||
|
PAD_SLOT_ID,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.runner.device)
|
||||||
|
self._graph_seq_lens = torch.ones(max_batch_size,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.runner.device)
|
||||||
|
self._graph_block_tables = torch.from_numpy(
|
||||||
|
self.runner.graph_block_tables).to(device=self.runner.device)
|
||||||
|
|
||||||
|
assert positions is not None
|
||||||
|
self._positions = positions
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
self._is_graph_capturing = False
|
||||||
|
del self._graph_slot_mapping
|
||||||
|
del self._graph_seq_lens
|
||||||
|
del self._graph_block_tables
|
||||||
|
|
||||||
def graph_clone(self, batch_size: int):
|
def graph_clone(self, batch_size: int):
|
||||||
raise NotImplementedError(
|
assert self._is_graph_capturing
|
||||||
"TritonMLAState does not support graph capture")
|
return self.__class__(self.runner)
|
||||||
|
|
||||||
def graph_capture_get_metadata_for_batch(
|
def graph_capture_get_metadata_for_batch(
|
||||||
self, batch_size: int, is_encoder_decoder_model: bool = False):
|
self, batch_size: int, is_encoder_decoder_model: bool = False):
|
||||||
raise NotImplementedError(
|
assert self._is_graph_capturing
|
||||||
"TritonMLAState does not support graph capture")
|
|
||||||
|
attn_metadata = self.runner.attn_backend.make_metadata(
|
||||||
|
num_prefills=0,
|
||||||
|
num_prefill_tokens=0,
|
||||||
|
num_decode_tokens=batch_size,
|
||||||
|
slot_mapping=self._graph_slot_mapping[:batch_size],
|
||||||
|
multi_modal_placeholder_index_maps=None,
|
||||||
|
enable_kv_scales_calculation=True,
|
||||||
|
seq_lens=None,
|
||||||
|
seq_lens_tensor=self._graph_seq_lens[:batch_size],
|
||||||
|
max_query_len=1,
|
||||||
|
max_decode_query_len=1,
|
||||||
|
max_prefill_seq_len=0,
|
||||||
|
max_decode_seq_len=self.runner.max_seq_len_to_capture,
|
||||||
|
query_start_loc=None,
|
||||||
|
seq_start_loc=None,
|
||||||
|
context_lens_tensor=None,
|
||||||
|
block_tables=self._graph_block_tables[:batch_size],
|
||||||
|
use_cuda_graph=True,
|
||||||
|
input_positions=self._positions[:batch_size],
|
||||||
|
head_dim=self.runner.model_config.get_head_size())
|
||||||
|
|
||||||
|
if is_encoder_decoder_model:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"TritonMLAState does not support encoder/decoder yet")
|
||||||
|
|
||||||
|
return attn_metadata
|
||||||
|
|
||||||
def get_graph_input_buffers(self,
|
def get_graph_input_buffers(self,
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
is_encoder_decoder_model: bool = False):
|
is_encoder_decoder_model: bool = False):
|
||||||
raise NotImplementedError(
|
input_buffers = {
|
||||||
"TritonMLAState does not support graph capture")
|
"slot_mapping": attn_metadata.slot_mapping,
|
||||||
|
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
|
||||||
|
"block_tables": attn_metadata.decode_metadata.block_tables,
|
||||||
|
}
|
||||||
|
if is_encoder_decoder_model:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"TritonMLAState does not support encoder/decoder yet")
|
||||||
|
|
||||||
|
return input_buffers
|
||||||
|
|
||||||
def prepare_graph_input_buffers(self,
|
def prepare_graph_input_buffers(self,
|
||||||
input_buffers,
|
input_buffers,
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
is_encoder_decoder_model: bool = False):
|
is_encoder_decoder_model: bool = False):
|
||||||
raise NotImplementedError(
|
input_buffers["seq_lens_tensor"].copy_(
|
||||||
"TritonMLAState does not support graph capture")
|
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
|
||||||
|
input_buffers["block_tables"].copy_(
|
||||||
|
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
||||||
|
if is_encoder_decoder_model:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"TritonMLAState does not support encoder/decoder yet")
|
||||||
|
|
||||||
def begin_forward(self, model_input):
|
def begin_forward(self, model_input):
|
||||||
return
|
return
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from itertools import accumulate
|
from itertools import accumulate
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
|
from typing import (TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union, Optional)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -288,8 +288,11 @@ class CommonAttentionState(AttentionState):
|
|||||||
self._is_graph_capturing = False
|
self._is_graph_capturing = False
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def graph_capture(self, max_batch_size: int):
|
def graph_capture(self, max_batch_size: int, positions: Optional[torch.Tensor]):
|
||||||
|
assert positions is None
|
||||||
|
|
||||||
self._is_graph_capturing = True
|
self._is_graph_capturing = True
|
||||||
|
|
||||||
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
||||||
PAD_SLOT_ID,
|
PAD_SLOT_ID,
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
@ -299,7 +302,9 @@ class CommonAttentionState(AttentionState):
|
|||||||
device=self.runner.device)
|
device=self.runner.device)
|
||||||
self._graph_block_tables = torch.from_numpy(
|
self._graph_block_tables = torch.from_numpy(
|
||||||
self.runner.graph_block_tables).to(device=self.runner.device)
|
self.runner.graph_block_tables).to(device=self.runner.device)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
self._is_graph_capturing = False
|
self._is_graph_capturing = False
|
||||||
del self._graph_slot_mapping
|
del self._graph_slot_mapping
|
||||||
del self._graph_seq_lens
|
del self._graph_seq_lens
|
||||||
|
|||||||
@ -1468,8 +1468,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
dtype=self.model_config.dtype,
|
dtype=self.model_config.dtype,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
|
||||||
with self.attn_state.graph_capture(max_batch_size), graph_capture(
|
with self.attn_state.graph_capture(
|
||||||
self.device) as graph_capture_context:
|
max_batch_size, input_positions), graph_capture(
|
||||||
|
self.device) as graph_capture_context:
|
||||||
# NOTE: Capturing the largest batch size first may help reduce the
|
# NOTE: Capturing the largest batch size first may help reduce the
|
||||||
# memory usage of CUDA graph.
|
# memory usage of CUDA graph.
|
||||||
for virtual_engine in range(
|
for virtual_engine in range(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user