mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 05:27:02 +08:00
added splitting
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
1ba3ae80bf
commit
ee70ce0e4e
@ -60,6 +60,87 @@ class CommonAttentionMetadata:
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class UbatchSlice:
|
||||
request_slice: slice
|
||||
token_slice: slice
|
||||
|
||||
|
||||
def slice_query_start_locs(
|
||||
query_start_loc: torch.Tensor,
|
||||
request_slice: slice,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Creates a new query_start_loc that corresponds to the requests in
|
||||
request_slice.
|
||||
Note: This function creates a new tensor to hold the new query_start_locs.
|
||||
This will break cudagraph compatibility.
|
||||
"""
|
||||
return query_start_loc[request_slice.start: request_slice.stop + 1] -\
|
||||
query_start_loc[request_slice.start]
|
||||
|
||||
|
||||
def _make_metadata_with_slice(
|
||||
ubatch_slice: UbatchSlice,
|
||||
attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
|
||||
"""
|
||||
This function creates a new CommonAttentionMetadata that corresponds to
|
||||
the requests included in ubatch_slice
|
||||
"""
|
||||
|
||||
request_slice = ubatch_slice.request_slice
|
||||
token_slice = ubatch_slice.token_slice
|
||||
|
||||
query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc,
|
||||
request_slice)
|
||||
assert len(query_start_loc >= 2)
|
||||
query_start_loc_cpu = slice_query_start_locs(
|
||||
attn_metadata.query_start_loc_cpu, request_slice)
|
||||
|
||||
seq_lens = attn_metadata.seq_lens[request_slice]
|
||||
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
|
||||
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
|
||||
request_slice]
|
||||
|
||||
num_requests = request_slice.stop - request_slice.start
|
||||
num_actual_tokens = token_slice.stop - token_slice.start
|
||||
max_query_len = int(
|
||||
torch.max(torch.abs(query_start_loc_cpu[1:] -
|
||||
query_start_loc_cpu[:-1])).item())
|
||||
|
||||
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
|
||||
slot_mapping = attn_metadata.slot_mapping[token_slice]
|
||||
|
||||
return CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
num_reqs=num_requests,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
block_table_tensor=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
)
|
||||
|
||||
|
||||
def split_attn_metadata(
|
||||
ubatch_slices: list[UbatchSlice],
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> list[CommonAttentionMetadata]:
|
||||
"""
|
||||
Creates a new CommonAttentionMetadata instance that corresponds to the
|
||||
requests for each UbatchSlice in ubatch_slices.
|
||||
Note: This function does not modify common_attn_metadata
|
||||
"""
|
||||
results = []
|
||||
for ubatch_slice in ubatch_slices:
|
||||
results.append(
|
||||
_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
|
||||
return results
|
||||
|
||||
|
||||
M = TypeVar("M")
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user