[Bugfix][TPU] Fix tpu model runner testcase failure (#18810)
Signed-off-by: Carol Zheng <cazheng@google.com>
parent 4577fc9abb
commit fba02e3bd1
@@ -81,7 +81,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
                 mm_hashes=[],
                 mm_positions=[],
                 sampling_params=SamplingParams(),
-                block_ids=[0],
+                block_ids=[[0]],  # block_ids should be list[list[int]]
                 num_computed_tokens=0,
                 lora_request=None,
             ))
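The fixture change above tracks the new per-KV-cache-group layout: a single-group request now carries block_ids=[[0]] rather than [0]. Below is a minimal standalone sketch of that layout; the as_grouped_block_ids helper is hypothetical and not part of vLLM.

# Hypothetical helper illustrating the layout change; not vLLM API.
def as_grouped_block_ids(block_ids):
    """Return block IDs as list[list[int]], one inner list per KV cache group."""
    if block_ids and isinstance(block_ids[0], int):
        # Legacy flat layout, e.g. [0, 1, 2] -> single-group layout [[0, 1, 2]]
        return [list(block_ids)]
    return [list(group) for group in block_ids]

assert as_grouped_block_ids([0]) == [[0]]    # old fixture value -> new shape
assert as_grouped_block_ids([[0]]) == [[0]]  # new fixture value is already grouped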
@@ -112,14 +112,35 @@ def _is_req_added(model_runner, req_id: str) -> bool:
 
 
 def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
+    """Check if the request state block IDs match the block table.
+
+    This function handles both legacy BlockTable and new MultiGroupBlockTable
+    structures for backward compatibility.
+    """
+
     req_index = model_runner.input_batch.req_id_to_index[req_id]
-    block_table = model_runner.input_batch.block_table
+    multi_group_block_table = model_runner.input_batch.block_table
     req_state = model_runner.requests[req_id]
-    if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
+
+    # Access the first block table from MultiGroupBlockTable
+    # This is safe since we currently only use single KV cache groups
+    block_table = multi_group_block_table[0]
+
+    # req_state.block_ids is now list[list[int]] for MultiGroupBlockTable
+    # Extract the first group's block IDs
+    if isinstance(req_state.block_ids[0], list):
+        # New format: list[list[int]] - extract first group
+        req_block_ids = req_state.block_ids[0]
+    else:
+        # Legacy format: list[int] - use directly
+        req_block_ids = req_state.block_ids
+
+    if block_table.num_blocks_per_row[req_index] != len(req_block_ids):
         return False
+
     num_blocks = block_table.num_blocks_per_row[req_index]
-    return (block_table.block_table_np[req_index, :num_blocks] ==
-            req_state.block_ids).all()
+    block_table_values = block_table.block_table_np[req_index, :num_blocks]
+    return (block_table_values == req_block_ids).all()
 
 
 def test_update_states_new_request(model_runner):
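For reference, a self-contained sketch of the comparison the updated helper performs, using a plain numpy array in place of vLLM's MultiGroupBlockTable; block_row_matches and its arguments are illustrative names, not vLLM API.

import numpy as np

def block_row_matches(block_table_np, num_blocks_per_row, req_index, req_block_ids):
    """Return True if the request's block IDs equal the populated part of its row."""
    # Accept both list[int] and list[list[int]] (first group only), as the helper does.
    if req_block_ids and isinstance(req_block_ids[0], list):
        req_block_ids = req_block_ids[0]
    num_blocks = num_blocks_per_row[req_index]
    if num_blocks != len(req_block_ids):
        return False
    return bool((block_table_np[req_index, :num_blocks] == req_block_ids).all())

table = np.zeros((4, 8), dtype=np.int32)
table[1, :3] = [5, 6, 7]
assert block_row_matches(table, [0, 3, 0, 0], 1, [[5, 6, 7]])
assert not block_row_matches(table, [0, 3, 0, 0], 1, [5, 6])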
@@ -175,11 +175,21 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.kv_caches: list[torch.Tensor] = []
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
-        # self.input_batch: InputBatch  # Persistent batch.
 
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
 
+        # Initialize input batch early to avoid AttributeError in _update_states
+        self.input_batch = InputBatch(
+            max_num_reqs=self.max_num_reqs,
+            max_model_len=self.max_model_len,
+            max_num_batched_tokens=self.max_num_tokens,
+            device=self.device,
+            pin_memory=self.pin_memory,
+            vocab_size=self.model_config.get_vocab_size(),
+            block_size=self.block_size,
+        )
+
         # Cached torch/numpy tensor
         # The pytorch tensor and numpy array share the same buffer.
         # Sometimes the numpy op is faster so we create both.
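A reduced sketch (toy classes, not vLLM code) of why building InputBatch in __init__ matters: _update_states can run before KV-cache initialization, and a lazily created attribute raises AttributeError at that point.

class LazyRunner:
    def update_states(self):
        return self.input_batch  # AttributeError until KV-cache init creates it


class EagerRunner:
    def __init__(self):
        self.input_batch = {"block_size": 16}  # default batch built up front

    def update_states(self):
        return self.input_batch  # always present, even before KV-cache init


try:
    LazyRunner().update_states()
except AttributeError:
    pass  # the failure mode the TPU model runner test hit

assert EagerRunner().update_states()["block_size"] == 16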
@@ -1286,16 +1296,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                 "Hybrid models with more than one KV cache type are not "
                 "supported yet.")
 
-        self.input_batch = InputBatch(
-            max_num_reqs=self.max_num_reqs,
-            max_model_len=self.max_model_len,
-            max_num_batched_tokens=self.max_num_tokens,
-            device=self.device,
-            pin_memory=self.pin_memory,
-            vocab_size=self.model_config.get_vocab_size(),
-            block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
-            block_size,
-        )
+        if kv_cache_config.kv_cache_groups[
+                0].kv_cache_spec.block_size != self.block_size:
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.max_model_len,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=self.pin_memory,
+                vocab_size=self.model_config.get_vocab_size(),
+                block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
+                block_size,
+            )
+        # Verify dtype compatibility between block_table_cpu and input_batch
         assert self.block_table_cpu.dtype == self.input_batch.block_table[
             0].get_cpu_tensor().dtype
 
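The new branch rebuilds InputBatch only when the KV cache group's block size differs from the one used in __init__. A toy rendering of that rebuild-only-if-different pattern follows; make_input_batch and the dict-based state are stand-ins, not vLLM API.

def make_input_batch(block_size):
    # Stand-in for constructing vLLM's InputBatch with a given block size.
    return {"block_size": block_size}


def initialize_kv_cache(state, kv_cache_block_size):
    # Keep the batch built at construction time unless the KV cache config
    # asks for a different block size; only then pay for a rebuild.
    if kv_cache_block_size != state["input_batch"]["block_size"]:
        state["input_batch"] = make_input_batch(kv_cache_block_size)


state = {"input_batch": make_input_batch(16)}
initialize_kv_cache(state, 16)  # same block size -> no rebuild
initialize_kv_cache(state, 32)  # block size differs -> rebuild
assert state["input_batch"]["block_size"] == 32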