Fill TorchSDPAAttentionMetadata seq_lens_field for prefill (#10799)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
Maximilien de Bayser 2024-12-01 23:05:32 -03:00 committed by GitHub
parent 073a4bd1c0
commit e25810ae29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -341,7 +341,11 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
)
else:
block_tables = torch.tensor([])
seq_lens_tensor = torch.tensor([])
seq_lens_tensor = torch.tensor(
input_data.seq_lens[:input_data.num_prefills],
dtype=torch.int32,
device="cpu",
)
# For multi-modal models
placeholder_index_maps = None