[Attention] MLA move rotary embedding to cuda-graph region (#17668)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
parent 760e3ecc8f
commit 5e6f939484
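This commit removes the input_positions plumbing and the rotary_emb handle from the MLA attention backends (V0 and V1) and instead applies the rotary embedding in the model's forward pass, where it is captured inside the CUDA-graph region. Below is a minimal, self-contained sketch of the resulting pattern; it is not vLLM code — only the names rotary_emb, qk_nope_head_dim, and k_pe come from the hunks that follow, while toy_rotary_emb and all sizes are made up for illustration.

import torch

num_tokens, num_heads = 4, 2
qk_nope_head_dim, qk_rope_head_dim = 8, 4
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

def toy_rotary_emb(positions, q_pe, k_pe):
    # Stand-in for RotaryEmbedding.forward: keeps the (positions, q, k) ->
    # (q, k) interface but applies no real rotation.
    return q_pe, k_pe

q = torch.randn(num_tokens, num_heads, qk_head_dim)
k_pe = torch.randn(num_tokens, qk_rope_head_dim).unsqueeze(1)  # add head dim of 1
positions = torch.arange(num_tokens)

# Rotary embedding applied in the model's forward (inside the CUDA-graph
# region), as in the DeepseekV2MLAAttention hunk below; the attention backend
# then receives q / k_pe with rope already applied and no longer needs
# input_positions metadata or a rotary_emb handle.
q[..., qk_nope_head_dim:], k_pe = toy_rotary_emb(
    positions, q[..., qk_nope_head_dim:], k_pe)

With rope already applied to q and k_pe before the attention custom op is called, the backend metadata no longer has to carry per-token positions into the graph replay.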
@@ -211,8 +211,6 @@ from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
@@ -377,7 +375,6 @@ class MLACommonState(AttentionState, Generic[T]):
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self._graph_block_tables[:batch_size],
input_positions=self._positions[:batch_size],
head_dim=self.runner.model_config.get_head_size())

if is_encoder_decoder_model:
@@ -393,7 +390,6 @@ class MLACommonState(AttentionState, Generic[T]):
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
"input_positions": attn_metadata.decode_metadata.input_positions,
}
if is_encoder_decoder_model:
raise NotImplementedError(
@@ -405,16 +401,10 @@ class MLACommonState(AttentionState, Generic[T]):
input_buffers,
attn_metadata,
is_encoder_decoder_model: bool = False):
input_positions = attn_metadata.input_positions
num_positions = input_positions.shape[0]
input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
# CUDA graph buffer is padded so only perform a partial copy based on
# num_positions
input_buffers["input_positions"][:num_positions].copy_(
input_positions, non_blocking=True)
if is_encoder_decoder_model:
raise NotImplementedError(
"TritonMLAState does not support encoder/decoder yet")
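The removed graph-input handling above copied positions into a padded CUDA-graph input buffer; only the first num_positions entries are overwritten, because the captured graph always reads a fixed-size buffer. A small illustrative sketch of that partial-copy idea, using toy tensors rather than the runner's real buffers:

import torch

max_batch_size = 8                           # size the CUDA graph was captured with
input_positions = torch.tensor([5, 6, 7])    # this step's actual positions
num_positions = input_positions.shape[0]

# The graph buffer is padded to the captured batch size, so only the first
# num_positions entries are refreshed before replaying the graph.
positions_buffer = torch.zeros(max_batch_size, dtype=torch.long)
positions_buffer[:num_positions].copy_(input_positions, non_blocking=True)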
@@ -456,11 +446,6 @@ class MLACommonMetadata(AttentionMetadata):
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool

# New for MLA (compared to FlashAttention)
# Input positions for rotary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor

# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
@@ -563,8 +548,6 @@ class MLACommonMetadata(AttentionMetadata):
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
input_positions = (None if self.input_positions is None else
self.input_positions[:self.num_prefill_tokens])

self._cached_prefill_metadata = self.__class__(
# Required by ModelRunner
@@ -578,7 +561,6 @@ class MLACommonMetadata(AttentionMetadata):
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
# MLACommonMetadata
input_positions=input_positions,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
@@ -615,8 +597,6 @@ class MLACommonMetadata(AttentionMetadata):
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
input_positions = (None if self.input_positions is None else
self.input_positions[self.num_prefill_tokens:])

self._cached_decode_metadata = self.__class__(
# Required by ModelRunner
@@ -646,7 +626,6 @@ class MLACommonMetadata(AttentionMetadata):
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
input_positions=input_positions,
head_dim=self.head_dim,
is_profile_run=self.is_profile_run)
return self._cached_decode_metadata
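Note the two different split points in the cached metadata above: per-request fields such as block_tables and seq_lens_tensor are sliced with num_prefills, while the per-token input_positions (now removed) was sliced with num_prefill_tokens. A toy illustration of why the two counts differ; the numbers are made up:

import torch

# Two prefill requests with 3 and 2 tokens plus two decode requests with one
# token each: 2 prefill requests, but 5 prefill tokens.
num_prefills, num_prefill_tokens = 2, 5

input_positions = torch.tensor([0, 1, 2, 0, 1, 7, 9])   # one entry per token
seq_lens_tensor = torch.tensor([3, 2, 8, 10])           # one entry per request

prefill_positions = input_positions[:num_prefill_tokens]   # first 5 tokens
decode_positions = input_positions[num_prefill_tokens:]    # remaining tokens
prefill_seq_lens = seq_lens_tensor[:num_prefills]          # first 2 requests
decode_seq_lens = seq_lens_tensor[num_prefills:]           # remaining requests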
@@ -765,7 +744,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.input_positions: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
@@ -786,13 +764,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
block_tables = inter_data.block_tables

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block, input_positions) in zip(
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks,
inter_data.input_positions):
self.input_positions.extend(input_positions)
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
@@ -912,8 +888,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
input_positions = async_tensor_h2d(self.input_positions, torch.long,
device, self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
@@ -987,7 +961,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
multi_modal_placeholder_index_maps=None, # Not Attention Related
enable_kv_scales_calculation=False,
# MLACommonMetadata
input_positions=input_positions,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
@@ -1033,7 +1006,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
rotary_emb: RotaryEmbedding,
kv_b_proj: ColumnParallelLinear,
) -> None:
self.num_heads = num_heads
@@ -1048,10 +1020,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim

self.rotary_emb = rotary_emb
self.use_yarn_rope = isinstance(rotary_emb,
DeepseekScalingRotaryEmbedding)
self.kv_b_proj = kv_b_proj

self.triton_fa_func = triton_attention
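The builder above ships its CPU-side lists to the GPU with async_tensor_h2d (pinned host memory plus a non-blocking copy); the removed lines did the same for input_positions. A rough stand-in for that helper, written as an assumption rather than the real vllm.utils signature:

import torch

def async_tensor_h2d_sketch(data, dtype, device, pin_memory):
    # Rough stand-in for vllm.utils.async_tensor_h2d: build the tensor in
    # (optionally pinned) host memory, then issue a non-blocking copy to the
    # target device.
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
    return t.to(device=device, non_blocking=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens_tensor = async_tensor_h2d_sketch([3, 2, 8], torch.int, device,
                                          pin_memory=torch.cuda.is_available())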
@@ -1367,41 +1335,15 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
has_decode = attn_metadata.decode_metadata is not None
has_prefill = attn_metadata.prefill_metadata is not None

# Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1)
assert hasattr(attn_metadata, "input_positions")

num_prefill_tokens: int = attn_metadata.num_prefill_tokens
q = q.view(-1, self.num_heads, self.qk_head_dim)

decode_q = q[num_prefill_tokens:]
decode_k_pe = k_pe[num_prefill_tokens:]
decode_input_positions = \
attn_metadata.input_positions[num_prefill_tokens:]

prefill_q = q[:num_prefill_tokens]
prefill_k_pe = k_pe[:num_prefill_tokens]
prefill_input_positions = \
attn_metadata.input_positions[:num_prefill_tokens]
prefill_k_c_normed = k_c_normed[:num_prefill_tokens]

if has_decode:
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
decode_input_positions, decode_q_pe, decode_k_pe)

if has_prefill:
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
prefill_input_positions, prefill_q_pe, prefill_k_pe)

# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
@@ -1424,6 +1366,15 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
attn_metadata)

if has_decode:
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)

output[num_prefill_tokens:] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
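The decode path above projects the "nope" part of the query through the absorbed W_UK_T matrix with a per-head batched matmul. A standalone sketch of the shape bookkeeping, with B = tokens, N = heads, P = qk_nope_head_dim, L = the latent rank; all sizes here are made up:

import torch

B, N, P, L = 4, 2, 8, 16          # tokens, heads, nope head dim, latent rank
decode_q_nope = torch.randn(B, N, P)
W_UK_T = torch.randn(N, P, L)     # per-head absorbed projection

q = decode_q_nope.transpose(0, 1)      # (B, N, P) -> (N, B, P)
ql = torch.bmm(q, W_UK_T)              # (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = ql.transpose(0, 1)    # (N, B, L) -> (B, N, L)
assert decode_ql_nope.shape == (B, N, L)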
@@ -148,13 +148,11 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
block_tables = inter_data.block_tables

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block, input_positions) in zip(
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks,
inter_data.input_positions):
self.input_positions.extend(input_positions)
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1

@@ -808,8 +808,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]

self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
if self.cos_sin_cache.device != positions.device:
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)
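The rotary-embedding hunk above replaces an unconditional .to(positions.device) with a device check, so the cos/sin cache is moved at most once instead of on every forward call now that this path runs on the hot, CUDA-graph-facing side of the model. A sketch of the guard with a toy cache (not the real DeepseekScalingRotaryEmbedding state):

import torch

cos_sin_cache = torch.randn(32, 8)   # toy cache, built on the CPU
positions = torch.arange(4)          # wherever the model inputs live

# Move the cache only when it is actually on the wrong device, instead of
# calling .to(...) unconditionally on every forward pass.
if cos_sin_cache.device != positions.device:
    cos_sin_cache = cos_sin_cache.to(positions.device)

The guard keeps the steady-state forward free of redundant device-transfer calls once the cache has landed on the right device.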
@@ -453,7 +453,6 @@ class DeepseekV2MLAAttention(nn.Module):
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb,
kv_b_proj=self.kv_b_proj,
)

@@ -475,6 +474,13 @@ class DeepseekV2MLAAttention(nn.Module):
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())

q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)

q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)

attn_out = self.mla_attn(
q,
kv_c_normed,
@@ -204,7 +204,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -269,9 +268,6 @@ class MLACommonPrefillMetadata:
max_seq_lens: list[int]
workspace: torch.Tensor

# Input positions for rotary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
block_table: torch.Tensor
query_start_loc: torch.Tensor
max_query_len: int
@@ -280,9 +276,6 @@ class MLACommonPrefillMetadata:

@dataclass
class MLACommonDecodeMetadata:
# Input positions for rotary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
block_table: torch.Tensor
seq_lens: torch.Tensor

@@ -443,10 +436,8 @@ class MLACommonMetadataBuilder(Generic[M]):

return modified_batch

def _build_decode(self, input_positions: torch.Tensor,
block_table: torch.Tensor, seq_lens: torch.Tensor):
def _build_decode(self, block_table: torch.Tensor, seq_lens: torch.Tensor):
return MLACommonDecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
)
@@ -464,8 +455,6 @@ class MLACommonMetadataBuilder(Generic[M]):
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
device, non_blocking=True).long()
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
device, non_blocking=True).long()

query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
@@ -473,7 +462,6 @@ class MLACommonMetadataBuilder(Generic[M]):
prefill_metadata = None
if self._num_prefills > 0:
reqs_start = self._num_decodes  # prefill_start
tokens_start = self._num_decode_tokens

context_lens_cpu = self.runner.input_batch.\
num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
@@ -542,7 +530,6 @@ class MLACommonMetadataBuilder(Generic[M]):
self.chunked_prefill_workspace_size

prefill_metadata = MLACommonPrefillMetadata(
input_positions=input_positions[tokens_start:],
block_table=block_table[reqs_start:, ...],
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
@@ -552,7 +539,6 @@ class MLACommonMetadataBuilder(Generic[M]):
decode_metadata = None
if self._num_decodes > 0:
decode_metadata = self._build_decode(
input_positions=input_positions[:self._num_decode_tokens],
block_table=block_table[:self._num_decodes, ...],
seq_lens=seq_lens[:self._num_decodes],
)
@@ -599,7 +585,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
rotary_emb: RotaryEmbedding,
kv_b_proj: ColumnParallelLinear,
) -> None:
self.num_heads = num_heads
@@ -614,15 +599,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim

# Hack for V1 for now to avoid torch library overhead (since we are
# already inside an attention custom op), pull out the forward
# method from the rotary embedding and call it directly
# TODO(lucas): we should probably find a cleaner way to do this
self.rotary_emb = rotary_emb.forward_native
if current_platform.is_cuda():
self.rotary_emb = rotary_emb.forward_cuda

self.kv_b_proj = kv_b_proj
self.vllm_flash_attn_version = get_flash_attn_version()
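The removed __init__ lines above bound the rotary embedding's forward_native / forward_cuda method directly so the call skipped the torch custom-op dispatch, since it already ran inside an attention custom op. A sketch of that (now unnecessary) pattern, assuming a RotaryEmbedding-like module that exposes both methods; the class and the availability check are stand-ins, not vLLM code:

import torch

class ToyRotaryEmbedding(torch.nn.Module):
    # Stand-ins for RotaryEmbedding.forward_native / forward_cuda.
    def forward_native(self, positions, q, k):
        return q, k

    def forward_cuda(self, positions, q, k):
        return q, k

rotary_emb = ToyRotaryEmbedding()
# Pull the concrete forward out once and call it directly, bypassing the
# torch library / custom-op dispatch layer on every attention call.
rope_fn = (rotary_emb.forward_cuda
           if torch.cuda.is_available() else rotary_emb.forward_native)
q, k = rope_fn(torch.arange(2), torch.randn(2, 4), torch.randn(2, 4))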
@@ -894,9 +870,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]

# Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1)

assert attn_metadata.num_decodes is not None and \
attn_metadata.num_prefills is not None and \
attn_metadata.num_decode_tokens is not None
@@ -905,35 +878,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens

q = q.view(-1, self.num_heads, self.qk_head_dim)
decode_q = q[:num_decode_tokens]
decode_k_pe = k_pe[:num_decode_tokens]

prefill_q = q[num_decode_tokens:]
prefill_k_pe = k_pe[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]

if has_decode:
assert attn_metadata.decode is not None
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)

if has_prefill:
assert attn_metadata.prefill is not None
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]

prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
attn_metadata.prefill.input_positions, prefill_q_pe,
prefill_k_pe)

# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
@@ -951,6 +901,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata)

if has_decode:
assert attn_metadata.decode is not None
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)

output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
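In the V1 forward above, decode tokens are packed first in the batch (the mirror image of the V0 code earlier, where prefill tokens come first), so the split point is num_decode_tokens. A toy illustration of that layout, with made-up sizes:

import torch

num_decode_tokens, num_prefill_tokens, num_heads, head_dim = 3, 5, 2, 12
q = torch.randn(num_decode_tokens + num_prefill_tokens, num_heads, head_dim)

decode_q = q[:num_decode_tokens]     # decode tokens come first in V1
prefill_q = q[num_decode_tokens:]    # prefill tokens follow
assert decode_q.shape[0] == num_decode_tokens
assert prefill_q.shape[0] == num_prefill_tokens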
@@ -58,8 +58,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config)

def _build_decode(self, input_positions: torch.Tensor,
block_table: torch.Tensor,
def _build_decode(self, block_table: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
@@ -69,7 +68,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
)

return FlashMLADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
tile_scheduler_metadata=tile_scheduler_metadata,