flash attn changes

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith 2025-01-02 13:26:21 -05:00
parent 0c7e6c1e36
commit 7eba374599

View File

@@ -162,7 +162,7 @@ class FlashAttentionImpl(AttentionImpl):
value, value,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.slot_mapping, attn_metadata.slot_mapping,#[:num_actual_tokens],
self.kv_cache_dtype, self.kv_cache_dtype,
k_scale, k_scale,
v_scale, v_scale,
@@ -174,9 +174,9 @@ class FlashAttentionImpl(AttentionImpl):
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
out=output[:num_actual_tokens], out=output[:num_actual_tokens],
cu_seqlens_q=attn_metadata.query_start_loc, cu_seqlens_q=attn_metadata.query_start_loc[:num_actual_tokens],
max_seqlen_q=attn_metadata.max_query_len, max_seqlen_q=attn_metadata.max_query_len,
cu_seqlens_k=attn_metadata.seq_start_loc, cu_seqlens_k=attn_metadata.seq_start_loc[:num_actual_tokens],
max_seqlen_k=attn_metadata.max_seq_len, max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,