Bugfix for param_sink_key initialization and block_table for cascade; refactor forward in unified_attention_with_output

Signed-off-by: yuantao <2422264527@qq.com>
yuantao 2025-11-15 15:36:53 +08:00
parent 8de4315229
commit b0e880632a
3 changed files with 32 additions and 44 deletions


@@ -945,36 +945,27 @@ def unified_attention_with_output(
     output_block_scale: torch.Tensor | None = None,
 ) -> None:
     attn_metadata, self, kv_cache = get_attention_context(layer_name)
-    if sink_key is None and sink_value is None:
-        self.impl.forward(
-            self,
-            query,
-            key,
-            value,
-            kv_cache,
-            attn_metadata,
-            output=output,
-            output_scale=output_scale,
-            output_block_scale=output_block_scale,
-        )
-    else:
+    kwargs = {}
+    if sink_key is not None or sink_value is not None:
         assert sink_key is not None and sink_value is not None, (
             "Currently, it is only supported when "
             "sink_key and sink_value are both not None"
         )
-        self.impl.forward(
-            self,
-            query,
-            key,
-            value,
-            kv_cache,
-            attn_metadata,
-            output=output,
-            output_scale=output_scale,
-            output_block_scale=output_block_scale,
-            sink_key=sink_key,
-            sink_value=sink_value,
-        )
+        kwargs["sink_key"] = sink_key
+        kwargs["sink_value"] = sink_value
+    self.impl.forward(
+        self,
+        query,
+        key,
+        value,
+        kv_cache,
+        attn_metadata,
+        output=output,
+        output_scale=output_scale,
+        output_block_scale=output_block_scale,
+        **kwargs,
+    )
 
 
 def unified_attention_with_output_fake(
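The refactor above removes the duplicated self.impl.forward call: the optional sink arguments are collected in a kwargs dict and unpacked into a single call site. A minimal sketch of the same pattern, with a stand-in forward rather than the real vLLM impl signature:

def forward(query, key, value, sink_key=None, sink_value=None):
    # Stand-in for self.impl.forward; only the call pattern matters here.
    return query, key, value, sink_key, sink_value

def call_forward(query, key, value, sink_key=None, sink_value=None):
    kwargs = {}
    if sink_key is not None or sink_value is not None:
        # Both must be provided together, mirroring the assert in the diff.
        assert sink_key is not None and sink_value is not None
        kwargs["sink_key"] = sink_key
        kwargs["sink_value"] = sink_value
    # One call site regardless of whether sinks are in use.
    return forward(query, key, value, **kwargs)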


@@ -854,15 +854,13 @@ class OpenPanguSinkAttention(nn.Module):
             )
         else:
             self.param_sink_value = torch.zeros(
-                torch.empty(
-                    (
-                        self.param_sink_number,
-                        self.num_kv_heads,
-                        self.v_channels,
-                    ),
-                    device=torch.cuda.current_device(),
-                    dtype=config.torch_dtype,
-                )
+                (
+                    self.param_sink_number,
+                    self.num_kv_heads,
+                    self.v_channels,
+                ),
+                device=torch.cuda.current_device(),
+                dtype=config.torch_dtype,
             )
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
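The fix above removes an accidental nesting: torch.zeros was being handed the tensor produced by torch.empty instead of a shape, which fails at construction time because torch.zeros expects a size, not a tensor. A short illustration with arbitrary toy dimensions:

import torch

# Old, broken form: passing a tensor where a size is expected
# raises a TypeError (exact message depends on the PyTorch version).
#   torch.zeros(torch.empty((4, 8, 128), dtype=torch.float16))
# Fixed form: pass the shape tuple directly.
param_sink_value = torch.zeros((4, 8, 128), dtype=torch.float16)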


@@ -614,7 +614,8 @@ class FlashSinkAttentionImpl(AttentionImpl):
         assert sink_len % block_size == 0
         num_sink_blocks = sink_len // block_size
         sink_kv_slot_mapping = torch.arange(
-            sink_len,
+            block_size,
+            sink_len + block_size,
             device=attn_metadata.slot_mapping.device,
             dtype=attn_metadata.slot_mapping.dtype,
         )
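This change shifts the sink KV slots up by one block: the old mapping filled slots 0 .. sink_len-1 (physical blocks 0 .. num_sink_blocks-1), while the new one fills slots block_size .. sink_len+block_size-1 (physical blocks 1 .. num_sink_blocks), keeping physical block 0 out of the sink range and consistent with the 1-based sink block table in the next hunk. A toy check of the arithmetic, with illustrative values:

import torch

block_size, sink_len = 16, 32  # toy values; sink_len % block_size == 0
old = torch.arange(sink_len)                           # slots 0..31 -> blocks 0, 1
new = torch.arange(block_size, sink_len + block_size)  # slots 16..47 -> blocks 1, 2
assert (new // block_size).unique().tolist() == [1, 2]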
@@ -654,7 +655,10 @@ class FlashSinkAttentionImpl(AttentionImpl):
         block_table = attn_metadata.block_table
         scheduler_metadata = attn_metadata.scheduler_metadata
         sink_block_table = torch.arange(
-            num_sink_blocks, device=block_table.device, dtype=block_table.dtype
+            1,
+            num_sink_blocks + 1,
+            device=block_table.device,
+            dtype=block_table.dtype,
         )
         sink_block_table = sink_block_table[None, :].expand(
             block_table.shape[0], -1
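Correspondingly, sink_block_table now enumerates physical blocks 1 .. num_sink_blocks instead of 0 .. num_sink_blocks-1, and is broadcast across the batch dimension. A toy reconstruction with illustrative sizes:

import torch

num_sink_blocks, batch_size = 2, 3  # toy values
sink_block_table = torch.arange(1, num_sink_blocks + 1)  # tensor([1, 2])
sink_block_table = sink_block_table[None, :].expand(batch_size, -1)
# tensor([[1, 2],
#         [1, 2],
#         [1, 2]])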
@@ -939,13 +943,8 @@ def cascade_attention(
     descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
     num_sink_blocks = sink_len // block_size
-    block_table = block_table + num_sink_blocks
-    block_table[block_table == num_sink_blocks] = 0
-    sink_block_table = (
-        torch.arange(
-            num_sink_blocks, device=block_table.device, dtype=block_table.dtype
-        )
-        + 1
-    )
+    sink_block_table = torch.arange(
+        1, num_sink_blocks + 1, device=block_table.device, dtype=block_table.dtype
+    )
     sink_block_table = sink_block_table[None, :].expand(block_table.shape[0], -1)
     block_table = torch.cat((sink_block_table, block_table), dim=1)
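As read from the new code, the cascade path drops the block offset and null-block remapping and simply prepends the shared sink blocks 1 .. num_sink_blocks to each request's block table, matching the table built in FlashSinkAttentionImpl. A toy reconstruction, assuming illustrative block numbers:

import torch

num_sink_blocks = 2
block_table = torch.tensor([[3, 4, 0],   # toy per-request block tables;
                            [5, 0, 0]])  # 0 is a padding entry
sink = torch.arange(1, num_sink_blocks + 1)[None, :].expand(block_table.shape[0], -1)
block_table = torch.cat((sink, block_table), dim=1)
# tensor([[1, 2, 3, 4, 0],
#         [1, 2, 5, 0, 0]])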