mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-01 23:50:04 +08:00
Bugfix for param_sink_key initialization and block_table for cascade attention; refactor forward call in unified_attention_with_output
Signed-off-by: yuantao <2422264527@qq.com>
This commit is contained in:
parent
8de4315229
commit
b0e880632a
@ -945,36 +945,27 @@ def unified_attention_with_output(
|
|||||||
output_block_scale: torch.Tensor | None = None,
|
output_block_scale: torch.Tensor | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
attn_metadata, self, kv_cache = get_attention_context(layer_name)
|
attn_metadata, self, kv_cache = get_attention_context(layer_name)
|
||||||
if sink_key is None and sink_value is None:
|
kwargs = {}
|
||||||
self.impl.forward(
|
if sink_key is not None or sink_value is not None:
|
||||||
self,
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
attn_metadata,
|
|
||||||
output=output,
|
|
||||||
output_scale=output_scale,
|
|
||||||
output_block_scale=output_block_scale,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert sink_key is not None and sink_value is not None, (
|
assert sink_key is not None and sink_value is not None, (
|
||||||
"Currently, it is only supported when "
|
"Currently, it is only supported when "
|
||||||
"sink_key and sink_value are both not None"
|
"sink_key and sink_value are both not None"
|
||||||
)
|
)
|
||||||
self.impl.forward(
|
kwargs["sink_key"] = sink_key
|
||||||
self,
|
kwargs["sink_value"] = sink_value
|
||||||
query,
|
|
||||||
key,
|
self.impl.forward(
|
||||||
value,
|
self,
|
||||||
kv_cache,
|
query,
|
||||||
attn_metadata,
|
key,
|
||||||
output=output,
|
value,
|
||||||
output_scale=output_scale,
|
kv_cache,
|
||||||
output_block_scale=output_block_scale,
|
attn_metadata,
|
||||||
sink_key=sink_key,
|
output=output,
|
||||||
sink_value=sink_value,
|
output_scale=output_scale,
|
||||||
)
|
output_block_scale=output_block_scale,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def unified_attention_with_output_fake(
|
def unified_attention_with_output_fake(
|
||||||
|
|||||||
@ -854,15 +854,13 @@ class OpenPanguSinkAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.param_sink_value = torch.zeros(
|
self.param_sink_value = torch.zeros(
|
||||||
torch.empty(
|
(
|
||||||
(
|
self.param_sink_number,
|
||||||
self.param_sink_number,
|
self.num_kv_heads,
|
||||||
self.num_kv_heads,
|
self.v_channels,
|
||||||
self.v_channels,
|
),
|
||||||
),
|
device=torch.cuda.current_device(),
|
||||||
device=torch.cuda.current_device(),
|
dtype=config.torch_dtype,
|
||||||
dtype=config.torch_dtype,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||||
|
|||||||
@ -614,7 +614,8 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
|||||||
assert sink_len % block_size == 0
|
assert sink_len % block_size == 0
|
||||||
num_sink_blocks = sink_len // block_size
|
num_sink_blocks = sink_len // block_size
|
||||||
sink_kv_slot_mapping = torch.arange(
|
sink_kv_slot_mapping = torch.arange(
|
||||||
sink_len,
|
block_size,
|
||||||
|
sink_len + block_size,
|
||||||
device=attn_metadata.slot_mapping.device,
|
device=attn_metadata.slot_mapping.device,
|
||||||
dtype=attn_metadata.slot_mapping.dtype,
|
dtype=attn_metadata.slot_mapping.dtype,
|
||||||
)
|
)
|
||||||
@ -654,7 +655,10 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
|||||||
block_table = attn_metadata.block_table
|
block_table = attn_metadata.block_table
|
||||||
scheduler_metadata = attn_metadata.scheduler_metadata
|
scheduler_metadata = attn_metadata.scheduler_metadata
|
||||||
sink_block_table = torch.arange(
|
sink_block_table = torch.arange(
|
||||||
num_sink_blocks, device=block_table.device, dtype=block_table.dtype
|
1,
|
||||||
|
num_sink_blocks + 1,
|
||||||
|
device=block_table.device,
|
||||||
|
dtype=block_table.dtype,
|
||||||
)
|
)
|
||||||
sink_block_table = sink_block_table[None, :].expand(
|
sink_block_table = sink_block_table[None, :].expand(
|
||||||
block_table.shape[0], -1
|
block_table.shape[0], -1
|
||||||
@ -939,13 +943,8 @@ def cascade_attention(
|
|||||||
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
|
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
|
||||||
|
|
||||||
num_sink_blocks = sink_len // block_size
|
num_sink_blocks = sink_len // block_size
|
||||||
block_table = block_table + num_sink_blocks
|
sink_block_table = torch.arange(
|
||||||
block_table[block_table == num_sink_blocks] = 0
|
1, num_sink_blocks + 1, device=block_table.device, dtype=block_table.dtype
|
||||||
sink_block_table = (
|
|
||||||
torch.arange(
|
|
||||||
num_sink_blocks, device=block_table.device, dtype=block_table.dtype
|
|
||||||
)
|
|
||||||
+ 1
|
|
||||||
)
|
)
|
||||||
sink_block_table = sink_block_table[None, :].expand(block_table.shape[0], -1)
|
sink_block_table = sink_block_table[None, :].expand(block_table.shape[0], -1)
|
||||||
block_table = torch.cat((sink_block_table, block_table), dim=1)
|
block_table = torch.cat((sink_block_table, block_table), dim=1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user