mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-31 10:07:04 +08:00
Fix param_sink_key initialization and the block_table used for cascade attention, and refactor the forward call in unified_attention_with_output
Signed-off-by: yuantao <2422264527@qq.com>
This commit is contained in:
parent
8de4315229
commit
b0e880632a
@ -945,36 +945,27 @@ def unified_attention_with_output(
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
attn_metadata, self, kv_cache = get_attention_context(layer_name)
|
||||
if sink_key is None and sink_value is None:
|
||||
self.impl.forward(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output,
|
||||
output_scale=output_scale,
|
||||
output_block_scale=output_block_scale,
|
||||
)
|
||||
else:
|
||||
kwargs = {}
|
||||
if sink_key is not None or sink_value is not None:
|
||||
assert sink_key is not None and sink_value is not None, (
|
||||
"Currently, it is only supported when "
|
||||
"sink_key and sink_value are both not None"
|
||||
)
|
||||
self.impl.forward(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output,
|
||||
output_scale=output_scale,
|
||||
output_block_scale=output_block_scale,
|
||||
sink_key=sink_key,
|
||||
sink_value=sink_value,
|
||||
)
|
||||
kwargs["sink_key"] = sink_key
|
||||
kwargs["sink_value"] = sink_value
|
||||
|
||||
self.impl.forward(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output,
|
||||
output_scale=output_scale,
|
||||
output_block_scale=output_block_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def unified_attention_with_output_fake(
|
||||
|
||||
@ -854,15 +854,13 @@ class OpenPanguSinkAttention(nn.Module):
|
||||
)
|
||||
else:
|
||||
self.param_sink_value = torch.zeros(
|
||||
torch.empty(
|
||||
(
|
||||
self.param_sink_number,
|
||||
self.num_kv_heads,
|
||||
self.v_channels,
|
||||
),
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=config.torch_dtype,
|
||||
)
|
||||
(
|
||||
self.param_sink_number,
|
||||
self.num_kv_heads,
|
||||
self.v_channels,
|
||||
),
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=config.torch_dtype,
|
||||
)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
|
||||
@ -614,7 +614,8 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
assert sink_len % block_size == 0
|
||||
num_sink_blocks = sink_len // block_size
|
||||
sink_kv_slot_mapping = torch.arange(
|
||||
sink_len,
|
||||
block_size,
|
||||
sink_len + block_size,
|
||||
device=attn_metadata.slot_mapping.device,
|
||||
dtype=attn_metadata.slot_mapping.dtype,
|
||||
)
|
||||
@ -654,7 +655,10 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
block_table = attn_metadata.block_table
|
||||
scheduler_metadata = attn_metadata.scheduler_metadata
|
||||
sink_block_table = torch.arange(
|
||||
num_sink_blocks, device=block_table.device, dtype=block_table.dtype
|
||||
1,
|
||||
num_sink_blocks + 1,
|
||||
device=block_table.device,
|
||||
dtype=block_table.dtype,
|
||||
)
|
||||
sink_block_table = sink_block_table[None, :].expand(
|
||||
block_table.shape[0], -1
|
||||
@ -939,13 +943,8 @@ def cascade_attention(
|
||||
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
|
||||
|
||||
num_sink_blocks = sink_len // block_size
|
||||
block_table = block_table + num_sink_blocks
|
||||
block_table[block_table == num_sink_blocks] = 0
|
||||
sink_block_table = (
|
||||
torch.arange(
|
||||
num_sink_blocks, device=block_table.device, dtype=block_table.dtype
|
||||
)
|
||||
+ 1
|
||||
sink_block_table = torch.arange(
|
||||
1, num_sink_blocks + 1, device=block_table.device, dtype=block_table.dtype
|
||||
)
|
||||
sink_block_table = sink_block_table[None, :].expand(block_table.shape[0], -1)
|
||||
block_table = torch.cat((sink_block_table, block_table), dim=1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user