From ee70ce0e4ebed6a46c8d54891ff45747324850f8 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Fri, 25 Jul 2025 19:26:01 +0000
Subject: [PATCH] added splitting

Signed-off-by: Sage Moore
---
 vllm/v1/attention/backends/utils.py | 116 +++++++++++++++++++++++++++++
 1 file changed, 116 insertions(+)

diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 43d3d1273c1ea..16b99aab842b8 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -60,6 +60,122 @@ class CommonAttentionMetadata:
     slot_mapping: torch.Tensor
 
 
+@dataclass
+class UbatchSlice:
+    """Pair of slices describing one ubatch of a larger batch.
+
+    request_slice is used to index the request-indexed fields (e.g.
+    seq_lens, block_table_tensor); token_slice is used to index
+    slot_mapping.
+    """
+    request_slice: slice
+    token_slice: slice
+
+
+def slice_query_start_locs(
+    query_start_loc: torch.Tensor,
+    request_slice: slice,
+) -> torch.Tensor:
+    """
+    Creates a new query_start_loc that corresponds to the requests in
+    request_slice.
+
+    The selected entries are shifted by query_start_loc[request_slice.start]
+    so that the returned tensor starts at 0.
+
+    Note: This function creates a new tensor to hold the new query_start_locs.
+    This will break cudagraph compatibility.
+
+    Args:
+        query_start_loc: Start offsets of the full batch. The entry at
+            index request_slice.stop is included in the result.
+        request_slice: Requests to keep. start and stop must both be
+            integers; None is not supported.
+
+    Returns:
+        The re-based start offsets for the selected requests.
+    """
+    first = request_slice.start
+    last = request_slice.stop
+    return query_start_loc[first:last + 1] - query_start_loc[first]
+
+
+def _make_metadata_with_slice(
+        ubatch_slice: UbatchSlice,
+        attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
+    """
+    This function creates a new CommonAttentionMetadata that corresponds to
+    the requests included in ubatch_slice.
+
+    Args:
+        ubatch_slice: Request and token ranges to extract.
+        attn_metadata: Metadata of the full batch.
+
+    Returns:
+        A new CommonAttentionMetadata built from the sliced fields.
+    """
+    request_slice = ubatch_slice.request_slice
+    token_slice = ubatch_slice.token_slice
+
+    query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc,
+                                             request_slice)
+    # At least one request is required: max_query_len below reduces over the
+    # differences of adjacent start locations, which would be empty.
+    assert len(query_start_loc) >= 2, (
+        f"request_slice {request_slice} must select at least one request")
+    query_start_loc_cpu = slice_query_start_locs(
+        attn_metadata.query_start_loc_cpu, request_slice)
+
+    seq_lens = attn_metadata.seq_lens[request_slice]
+    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
+    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
+        request_slice]
+
+    num_requests = request_slice.stop - request_slice.start
+    num_actual_tokens = token_slice.stop - token_slice.start
+    # Longest query in the slice: the largest difference between adjacent
+    # entries of query_start_loc_cpu.
+    max_query_len = int(
+        torch.max(torch.abs(query_start_loc_cpu[1:] -
+                            query_start_loc_cpu[:-1])).item())
+
+    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
+    slot_mapping = attn_metadata.slot_mapping[token_slice]
+
+    return CommonAttentionMetadata(
+        query_start_loc=query_start_loc,
+        query_start_loc_cpu=query_start_loc_cpu,
+        seq_lens=seq_lens,
+        seq_lens_cpu=seq_lens_cpu,
+        num_computed_tokens_cpu=num_computed_tokens_cpu,
+        num_reqs=num_requests,
+        num_actual_tokens=num_actual_tokens,
+        max_query_len=max_query_len,
+        block_table_tensor=block_table_tensor,
+        slot_mapping=slot_mapping,
+    )
+
+
+def split_attn_metadata(
+    ubatch_slices: list[UbatchSlice],
+    common_attn_metadata: CommonAttentionMetadata,
+) -> list[CommonAttentionMetadata]:
+    """
+    Creates a new CommonAttentionMetadata instance that corresponds to the
+    requests for each UbatchSlice in ubatch_slices.
+
+    Note: This function does not modify common_attn_metadata
+
+    Returns:
+        One CommonAttentionMetadata per entry of ubatch_slices, in the same
+        order.
+    """
+    return [
+        _make_metadata_with_slice(ubatch_slice, common_attn_metadata)
+        for ubatch_slice in ubatch_slices
+    ]
+
+
 M = TypeVar("M")
 
 