Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Revert "fix ubatch datatype issue"
This reverts commit 9e16220e4e8a736f26ea93e355fe820de9c58264, reversing changes made to 5215c80a4988e81d2f5971e02d50d3785cab5ae8.
This commit is contained in:
parent 143b09e6be
commit a0a11bc0b5
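In effect, this revert takes the GPU model runner back from the `UbatchSlice` dataclass to a plain `(request_slice, token_slice)` tuple, rebuilding the dataclass only at the attention-utils boundary. A minimal sketch of the two shapes, assuming only what the hunks below show (the field names come from the `request_slice=`/`token_slice=` keywords in the diff; `UbatchSliceTuple` is an illustrative name, not from the source):

from dataclasses import dataclass
from typing import TypeAlias


# Shape before the revert: a named dataclass (still defined in the
# attention utils; only the model runner stops using it directly).
@dataclass
class UbatchSlice:
    request_slice: slice
    token_slice: slice


# Shape after the revert: the runner goes back to a bare tuple,
# indexed positionally ([0] = requests, [1] = tokens).
UbatchSliceTuple: TypeAlias = tuple[slice, slice]

tup: UbatchSliceTuple = (slice(0, 2), slice(0, 8))  # hypothetical values
as_dataclass = UbatchSlice(request_slice=tup[0], token_slice=tup[1])
assert as_dataclass.token_slice.stop - as_dataclass.token_slice.start == 8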
@@ -1,157 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import pytest
-import torch
-
-from tests.v1.attention.test_attention_backends import BATCH_SPECS
-from tests.v1.attention.utils import create_common_attn_metadata
-from vllm.v1.attention.backends.utils import (UbatchSlice,
-                                              _make_metadata_with_slice,
-                                              slice_query_start_locs,
-                                              split_attn_metadata)
-
-
-@pytest.fixture
-def sample_query_start_loc():
-    """Sample query_start_loc tensor for testing"""
-    return torch.tensor([0, 5, 12, 20, 35, 50])
-
-
-def test_basic_slice_middle(sample_query_start_loc):
-    """Test slicing from the middle of the tensor"""
-    req_slice = slice(1, 3)  # slice from index 1 to 3
-    result = slice_query_start_locs(sample_query_start_loc, req_slice)
-
-    expected = torch.tensor([0, 7, 15])
-    assert torch.equal(result, expected)
-
-
-def test_slice_from_beginning(sample_query_start_loc):
-    """Test slicing from the beginning of the tensor"""
-    req_slice = slice(0, 2)  # slice from index 0 to 2
-    result = slice_query_start_locs(sample_query_start_loc, req_slice)
-
-    expected = torch.tensor([0, 5, 12])
-    assert torch.equal(result, expected)
-
-
-def test_slice_to_end(sample_query_start_loc):
-    """Test slicing to the end of the tensor"""
-    req_slice = slice(3, 5)  # slice from index 3 to 5 (last index)
-    result = slice_query_start_locs(sample_query_start_loc, req_slice)
-
-    expected = torch.tensor([0, 15, 30])
-    assert torch.equal(result, expected)
-
-
-def test_single_element_slice(sample_query_start_loc):
-    """Test slice that results in a single element"""
-    req_slice = slice(2, 3)  # slice from index 2 to 3
-    result = slice_query_start_locs(sample_query_start_loc, req_slice)
-
-    expected = torch.tensor([0, 8])
-    assert torch.equal(result, expected)
-
-
-def test_full_tensor_slice(sample_query_start_loc):
-    """Test slicing the entire tensor"""
-    req_slice = slice(0, 5)  # slice entire tensor
-    result = slice_query_start_locs(sample_query_start_loc, req_slice)
-
-    expected = torch.tensor([0, 5, 12, 20, 35, 50])
-    assert torch.equal(result, expected)
-
-
-def test_slice_bounds_edge_cases(sample_query_start_loc):
-    # Test slice that goes exactly to the last element
-    req_slice = slice(4, 5)  # Last index
-    result = slice_query_start_locs(sample_query_start_loc, req_slice)
-
-    expected = torch.tensor([0, 15])
-    assert torch.equal(result, expected)
-
-
-@pytest.fixture
-def small_decode_metadata():
-    """Create metadata for small decode batch"""
-    batch_spec = BATCH_SPECS["small_decode"]
-    device = torch.device("cpu")
-    return create_common_attn_metadata(batch_spec,
-                                       block_size=16,
-                                       device=device)
-
-
-@pytest.fixture
-def large_decode_metadata():
-    """Create metadata for large decode batch"""
-    batch_spec = BATCH_SPECS["large_decode"]
-    device = torch.device("cpu")
-    return create_common_attn_metadata(batch_spec,
-                                       block_size=16,
-                                       device=device)
-
-
-@pytest.fixture
-def mixed_small_metadata():
-    """Create metadata for mixed small batch"""
-    batch_spec = BATCH_SPECS["mixed_small"]
-    device = torch.device("cpu")
-    return create_common_attn_metadata(batch_spec,
-                                       block_size=16,
-                                       device=device)
-
-
-# Tests for _make_metadata_with_slice
-def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
-    """Test slicing decode batch metadata"""
-    # Split first request only
-    ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1))
-
-    result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata)
-
-    # Check sliced results
-    assert result.num_reqs == 1  # slice(0, 1) gives 1 request
-    assert result.num_actual_tokens == 1  # slice(0, 1) gives 1 token
-    assert result.max_query_len == 1
-    assert torch.equal(result.query_start_loc, torch.tensor([0, 1]))
-    assert torch.equal(result.seq_lens, torch.tensor([32]))
-
-
-def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata):
-    """Test slicing mixed batch metadata"""
-    ubatch_slice = UbatchSlice(slice(1, 3),
-                               slice(1, 7))  # Requests 1-3, tokens 1-7
-
-    result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata)
-
-    assert result.num_reqs == 2  # slice(1, 3) gives 2 requests
-    assert result.num_actual_tokens == 6  # slice(1, 7) gives 6 tokens
-    assert result.max_query_len == 5
-    assert torch.equal(result.query_start_loc, torch.tensor([0, 1, 6]))
-    assert torch.equal(result.seq_lens, torch.tensor([40, 48]))
-
-
-def test_split_attn_metadata_decode_batch(large_decode_metadata):
-    """Test splitting decode batch into two equal parts"""
-    num_tokens = large_decode_metadata.num_reqs
-    mid_point = num_tokens // 2
-    ubatch_slices = [
-        UbatchSlice(slice(0, mid_point), slice(0, mid_point)),
-        UbatchSlice(slice(mid_point, num_tokens), slice(mid_point,
-                                                        num_tokens)),
-    ]
-
-    results = split_attn_metadata(ubatch_slices, large_decode_metadata)
-
-    assert len(results) == 2
-
-    # Check first split
-    assert results[0].num_reqs == mid_point
-    assert results[0].num_actual_tokens == mid_point
-    assert torch.equal(results[0].seq_lens, torch.tensor([2048] * mid_point))
-
-    # Check second split
-    assert results[1].num_reqs == mid_point
-    assert results[1].num_actual_tokens == mid_point
-    assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
@@ -76,7 +76,6 @@ def slice_query_start_locs(
     """
     Creates a new query_start_loc that corresponds to the requests in
     request_slice.

     Note: This function creates a new tensor to hold the new query_start_locs.
     This will break cudagraph compatibility.
     """
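The re-basing this docstring describes is easiest to see with the fixture from the deleted tests above: slicing requests 1..3 keeps the cumulative boundaries [5, 12, 20] and shifts them to start at zero. A one-line sketch of that behavior, offered as an illustration rather than the actual vLLM implementation:

import torch


def slice_query_start_locs_sketch(query_start_loc: torch.Tensor,
                                  request_slice: slice) -> torch.Tensor:
    # Keep boundaries start..stop inclusive, then re-base to zero.
    return (query_start_loc[request_slice.start:request_slice.stop + 1] -
            query_start_loc[request_slice.start])


qsl = torch.tensor([0, 5, 12, 20, 35, 50])  # fixture from the deleted tests
assert torch.equal(slice_query_start_locs_sketch(qsl, slice(1, 3)),
                   torch.tensor([0, 7, 15]))
assert torch.equal(slice_query_start_locs_sketch(qsl, slice(3, 5)),
                   torch.tensor([0, 15, 30]))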
@@ -130,19 +129,19 @@ def _make_metadata_with_slice(


 def split_attn_metadata(
-    ubatch_slices: list[UbatchSlice],
+    ubatch_slices: list[tuple[slice, slice]],
     common_attn_metadata: CommonAttentionMetadata,
 ) -> list[CommonAttentionMetadata]:
     """
     Creates a new CommonAttentionMetadata instance that corresponds to the
     requests for each UbatchSlice in ubatch_slices.

     Note: This function does not modify common_attn_metadata
     """
     results = []
     for ubatch_slice in ubatch_slices:
-        results.append(
-            _make_metadata_with_slice(ubatch_slice, common_attn_metadata))
+        s = UbatchSlice(request_slice=ubatch_slice[0],
+                        token_slice=ubatch_slice[1])
+        results.append(_make_metadata_with_slice(s, common_attn_metadata))
     return results

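After this change, callers hand `split_attn_metadata` plain `(request_slice, token_slice)` tuples and the function rebuilds the dataclass internally. A usage sketch reusing the helpers imported by the deleted test file above (their signatures assumed from that file):

import torch

from tests.v1.attention.test_attention_backends import BATCH_SPECS
from tests.v1.attention.utils import create_common_attn_metadata
from vllm.v1.attention.backends.utils import split_attn_metadata

meta = create_common_attn_metadata(BATCH_SPECS["small_decode"],
                                   block_size=16,
                                   device=torch.device("cpu"))
# Tuples in, one CommonAttentionMetadata out per micro-batch.
results = split_attn_metadata([(slice(0, 1), slice(0, 1))], meta)
assert len(results) == 1 and results[0].num_reqs == 1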
@@ -52,7 +52,7 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
-    UbatchSlice, make_kv_sharing_fast_prefill_attention_metadata,
+    make_kv_sharing_fast_prefill_attention_metadata,
     make_local_attention_virtual_batches, split_attn_metadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -100,6 +100,7 @@ AttnMetadataDict: TypeAlias = dict[str, FlashAttentionMetadata]
 PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
                                         AttnMetadataDict]

+UbatchSlice: TypeAlias = tuple[slice, slice]
 UBatchSlices: TypeAlias = list[UbatchSlice]

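With the tuple alias there are no named fields, so the remaining hunks switch from attribute access to positional indexing. The pattern, with hypothetical values:

ubatch_slice = (slice(0, 4), slice(0, 64))  # (request_slice, token_slice)

# Before the revert: ubatch_slice.token_slice.stop - ubatch_slice.token_slice.start
# After the revert, access is positional:
num_tokens = ubatch_slice[1].stop - ubatch_slice[1].start
assert num_tokens == 64

# Or by unpacking, as the forward-context loop further down does:
_, tokens_slice = ubatch_slice
assert tokens_slice.stop - tokens_slice.start == num_tokens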
@@ -655,9 +656,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert b0_reqs_end < num_reqs and \
             b0_tokens_end < total_num_scheduled_tokens
         ubatch_slices = [
-            UbatchSlice(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
-            UbatchSlice(slice(b0_reqs_end, num_reqs),
-                        slice(b0_tokens_end, total_num_scheduled_tokens)),
+            (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
+            (slice(b0_reqs_end,
+                   num_reqs), slice(b0_tokens_end,
+                                    total_num_scheduled_tokens)),
         ]

         # Compute ubatch padding. This currently only accounts for DP padding
@@ -1593,10 +1595,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         first_ubatch_slice = ubatch_slices[0]
         second_ubatch_slice = ubatch_slices[1]

-        first_ubatch_num_tokens = first_ubatch_slice.token_slice.stop - \
-            first_ubatch_slice.token_slice.start
-        second_ubatch_num_tokens = second_ubatch_slice.token_slice.stop - \
-            second_ubatch_slice.token_slice.start
+        first_ubatch_num_tokens = first_ubatch_slice[
+            1].stop - first_ubatch_slice[1].start
+        second_ubatch_num_tokens = second_ubatch_slice[
+            1].stop - second_ubatch_slice[1].start
         # We don't support prefills yet so the two ubatches should only differ
         # by at most one token
         assert abs(first_ubatch_num_tokens - second_ubatch_num_tokens) <= 1
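The assert holds because a decode-only batch is split at the token midpoint: `t // 2` and `t - t // 2` differ by at most one. A quick check:

# Both even and odd totals split into halves differing by at most one.
for total_tokens in (6, 7):
    first = total_tokens // 2
    second = total_tokens - first
    assert abs(first - second) <= 1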
@@ -1633,7 +1635,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     # slicing but before attention meta data creation
     def pad_out_ubatch_first_stage(self, ubatch_slices: UBatchSlices,
                                    num_pad_tokens: int):
-        original_num_tokens = ubatch_slices[1].token_slice.stop
+        original_num_tokens = ubatch_slices[1][1].stop
         assert num_pad_tokens < original_num_tokens
         total_num_tokens_per_ubatch = (original_num_tokens +
                                        num_pad_tokens) // 2
@@ -1641,10 +1643,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
                                            original_num_tokens)

-        ubatch_slices[0] = UbatchSlice(padded_first_ubatch_slice,
-                                       padded_first_ubatch_slice)
-        ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice,
-                                       padded_second_ubatch_slice)
+        ubatch_slices[0] = (padded_first_ubatch_slice,
+                            padded_first_ubatch_slice)
+        ubatch_slices[1] = (padded_second_ubatch_slice,
+                            padded_second_ubatch_slice)

     # This is where the second ubatch is adjusted to account for the padding.
     # Should be called after attention metadata creation. This just pads
@@ -1653,10 +1655,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices,
                                     num_total_tokens: int):
         # TODO Add asserts to make sure stage one ran
-        padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start,
+        padded_second_ubatch_slice = slice(ubatch_slices[1][1].start,
                                            num_total_tokens)
-        ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice,
-                                       padded_second_ubatch_slice)
+        ubatch_slices[1] = (padded_second_ubatch_slice,
+                            padded_second_ubatch_slice)

     def should_ubatch(self, should_ubatch: bool) -> bool:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
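Taken together, the two padding stages work roughly as follows; a worked example with hypothetical numbers (10 real tokens, 2 pad tokens), assuming `num_total_tokens` in the second stage is the padded overall count:

original_num_tokens, num_pad_tokens = 10, 2  # hypothetical values

# First stage: re-slice both ubatches against the unpadded token count.
per_ubatch = (original_num_tokens + num_pad_tokens) // 2   # -> 6
first = slice(0, per_ubatch)                               # [0, 6)
second = slice(per_ubatch, original_num_tokens)            # [6, 10)

# Second stage (after attention metadata creation): extend the second
# ubatch to the padded total so both halves end up the same size.
second = slice(second.start, original_num_tokens + num_pad_tokens)  # [6, 12)
assert (first.stop - first.start) == (second.stop - second.start) == 6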
@@ -1751,9 +1753,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         # Create one forward context per ubatch
         forward_contexts = []
-        for i, ubatch_slice in enumerate(ubatch_slices):
-            num_tokens = (ubatch_slice.token_slice.stop -
-                          ubatch_slice.token_slice.start)
+        for i, (_, tokens_slice) in enumerate(ubatch_slices):
+            num_tokens = (tokens_slice.stop - tokens_slice.start)
             forward_contexts.append(
                 create_forward_context(
                     attn_metadata[i] if attn_metadata is not None else None,
@@ -1771,18 +1772,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             enable_async_comms=self.parallel_config.enable_async_comms)

         ubatch_metadata: list[UbatchMetadata] = []
-        for i, ubatch_slice in enumerate(ubatch_slices):
+        for i, (_, tokens_slice) in enumerate(ubatch_slices):
             input_ids, positions, inputs_embeds, intermediate_tensors = \
-                self.model_inputs(
-                    ubatch_slice.token_slice, scheduler_output, is_dummy_run)
+                self.model_inputs(tokens_slice, scheduler_output, is_dummy_run)
             ubatch_metadata.append(
                 UbatchMetadata(context=ubatch_ctxs[i],
                                input_ids=input_ids,
                                positions=positions,
                                inputs_embeds=inputs_embeds,
                                intermediate_tensors=intermediate_tensors,
-                               num_tokens=ubatch_slice.token_slice.stop -
-                               ubatch_slice.token_slice.start))
+                               num_tokens=tokens_slice.stop -
+                               tokens_slice.start))

         return ubatch_metadata

@@ -1808,8 +1808,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         results: list[tuple[int, torch.Tensor]] = []
         compute_stream = ubatch_metadata[0].context.compute_stream
-        num_tokens = ubatch_metadata[0].num_tokens + \
-            ubatch_metadata[1].num_tokens
+        num_tokens = ubatch_metadata[0].num_tokens + ubatch_metadata[
+            1].num_tokens

         # Ubatches will manually manage the forward context, so we override
         # it to None here so we can have it restored correctly later
@@ -2704,12 +2704,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                   dp_size,
                                   device="cpu",
                                   dtype=torch.int32)
-        ubatch_slices = [
-            UbatchSlice(slice(0, num_reqs // 2), slice(0,
-                                                       num_tokens // 2)),
-            UbatchSlice(slice(num_reqs // 2, num_reqs),
-                        slice(num_tokens // 2, num_tokens))
-        ]
+        ubatch_slices = [(slice(0,
+                                num_reqs // 2), slice(0, num_tokens // 2)),
+                         (slice(num_reqs // 2, num_reqs),
+                          slice(num_tokens // 2, num_tokens))]

         # attn_metadata: Optional[dict[str, Any]] = None
         attn_metadata: Optional[PerLayerAttnMetadata] = None