Fix param_sink_key initialization and the block_table used for cascade attention; refactor the forward call in unified_attention_with_output

Signed-off-by: yuantao <2422264527@qq.com>
This commit is contained in:
yuantao 2025-11-15 15:36:53 +08:00
parent 8de4315229
commit b0e880632a
3 changed files with 32 additions and 44 deletions

View File

@ -945,36 +945,27 @@ def unified_attention_with_output(
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> None: ) -> None:
attn_metadata, self, kv_cache = get_attention_context(layer_name) attn_metadata, self, kv_cache = get_attention_context(layer_name)
if sink_key is None and sink_value is None: kwargs = {}
self.impl.forward( if sink_key is not None or sink_value is not None:
self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
else:
assert sink_key is not None and sink_value is not None, ( assert sink_key is not None and sink_value is not None, (
"Currently, it is only supported when " "Currently, it is only supported when "
"sink_key and sink_value are both not None" "sink_key and sink_value are both not None"
) )
self.impl.forward( kwargs["sink_key"] = sink_key
self, kwargs["sink_value"] = sink_value
query,
key, self.impl.forward(
value, self,
kv_cache, query,
attn_metadata, key,
output=output, value,
output_scale=output_scale, kv_cache,
output_block_scale=output_block_scale, attn_metadata,
sink_key=sink_key, output=output,
sink_value=sink_value, output_scale=output_scale,
) output_block_scale=output_block_scale,
**kwargs,
)
def unified_attention_with_output_fake( def unified_attention_with_output_fake(

View File

@ -854,15 +854,13 @@ class OpenPanguSinkAttention(nn.Module):
) )
else: else:
self.param_sink_value = torch.zeros( self.param_sink_value = torch.zeros(
torch.empty( (
( self.param_sink_number,
self.param_sink_number, self.num_kv_heads,
self.num_kv_heads, self.v_channels,
self.v_channels, ),
), device=torch.cuda.current_device(),
device=torch.cuda.current_device(), dtype=config.torch_dtype,
dtype=config.torch_dtype,
)
) )
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):

View File

@ -614,7 +614,8 @@ class FlashSinkAttentionImpl(AttentionImpl):
assert sink_len % block_size == 0 assert sink_len % block_size == 0
num_sink_blocks = sink_len // block_size num_sink_blocks = sink_len // block_size
sink_kv_slot_mapping = torch.arange( sink_kv_slot_mapping = torch.arange(
sink_len, block_size,
sink_len + block_size,
device=attn_metadata.slot_mapping.device, device=attn_metadata.slot_mapping.device,
dtype=attn_metadata.slot_mapping.dtype, dtype=attn_metadata.slot_mapping.dtype,
) )
@ -654,7 +655,10 @@ class FlashSinkAttentionImpl(AttentionImpl):
block_table = attn_metadata.block_table block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata scheduler_metadata = attn_metadata.scheduler_metadata
sink_block_table = torch.arange( sink_block_table = torch.arange(
num_sink_blocks, device=block_table.device, dtype=block_table.dtype 1,
num_sink_blocks + 1,
device=block_table.device,
dtype=block_table.dtype,
) )
sink_block_table = sink_block_table[None, :].expand( sink_block_table = sink_block_table[None, :].expand(
block_table.shape[0], -1 block_table.shape[0], -1
@ -939,13 +943,8 @@ def cascade_attention(
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
num_sink_blocks = sink_len // block_size num_sink_blocks = sink_len // block_size
block_table = block_table + num_sink_blocks sink_block_table = torch.arange(
block_table[block_table == num_sink_blocks] = 0 1, num_sink_blocks + 1, device=block_table.device, dtype=block_table.dtype
sink_block_table = (
torch.arange(
num_sink_blocks, device=block_table.device, dtype=block_table.dtype
)
+ 1
) )
sink_block_table = sink_block_table[None, :].expand(block_table.shape[0], -1) sink_block_table = sink_block_table[None, :].expand(block_table.shape[0], -1)
block_table = torch.cat((sink_block_table, block_table), dim=1) block_table = torch.cat((sink_block_table, block_table), dim=1)