[Cleanup] Remove unused ModelRunner V1 InputBatch.num_tokens field (#30218)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill, 2025-12-18 09:17:00 -08:00, committed by GitHub
parent f4ee2c3d90
commit 686cbaac64
4 changed files with 12 additions and 36 deletions
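
The removed `num_tokens` array was redundant bookkeeping: everywhere it was written, its value equaled `num_tokens_no_spec` plus the number of pending speculative-decode (draft) token ids for that request, so the total can be derived on demand. A minimal sketch of the invariant this relies on (an illustrative stand-in, not vLLM's actual InputBatch class):

import numpy as np

class MiniInputBatch:
    def __init__(self, max_num_reqs: int):
        # Per-request token counts, excluding pending draft tokens.
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
        # Pending draft token ids per request index.
        self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]

    def num_tokens(self, req_index: int) -> int:
        # The deleted array always held exactly this quantity, so it can be
        # recomputed where needed instead of being kept in sync at every write.
        return int(self.num_tokens_no_spec[req_index]) + len(
            self.spec_token_ids[req_index]
        )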

vllm/v1/worker/gpu_input_batch.py

@@ -128,7 +128,6 @@ class InputBatch:
         # allocation if max_model_len is big.
         # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
         self.req_prompt_embeds: dict[int, torch.Tensor] = {}
-        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_computed_tokens_cpu_tensor = torch.zeros(
@@ -340,9 +339,6 @@ class InputBatch:
             self.req_prompt_embeds[req_index] = request.prompt_embeds
         self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
         self.is_token_ids[req_index, start_idx:end_idx] = True
-        # Number of token ids in prompt (token_ids_cpu or prompt_embeds).
-        # NOTE(woosuk): This may include spec decode tokens.
-        self.num_tokens[req_index] = request.num_tokens
         # Number of tokens without spec decode tokens.
         self.num_tokens_no_spec[req_index] = request.num_tokens
@@ -522,10 +518,6 @@ class InputBatch:
                 self.req_id_to_index[old_id_i2],
                 self.req_id_to_index[old_id_i1],
             )
-            self.num_tokens[i1], self.num_tokens[i2] = (
-                self.num_tokens[i2],
-                self.num_tokens[i1],
-            )
             self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
                 self.num_tokens_no_spec[i2],
                 self.num_tokens_no_spec[i1],
@@ -661,17 +653,16 @@ class InputBatch:
             self.req_output_token_ids[last_req_index] = None
             self.req_id_to_index[req_id] = empty_index

             if last_req_index != empty_index:
-                (
-                    self.spec_token_ids[last_req_index],
-                    self.spec_token_ids[empty_index],
-                ) = (
-                    self.spec_token_ids[empty_index],
-                    self.spec_token_ids[last_req_index],
-                )
-                self.spec_token_ids[last_req_index].clear()
-                num_tokens = self.num_tokens[last_req_index]
+                num_tokens = self.num_tokens_no_spec[last_req_index] + len(
+                    self.spec_token_ids[last_req_index]
+                )
+                (self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
+                    self.spec_token_ids[empty_index],
+                    self.spec_token_ids[last_req_index],
+                )
+                self.spec_token_ids[last_req_index].clear()
                 self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                     last_req_index, :num_tokens
                 ]
@@ -682,7 +673,6 @@ class InputBatch:
                     self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
                         last_req_index
                     )
-                self.num_tokens[empty_index] = num_tokens
                 self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                     last_req_index
                 ]
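
The one place this file still needed the total was the row copy in condense(); it is now recomputed just before the spec-token lists are swapped and cleared. A rough self-contained sketch of that step (the flat helper signature is hypothetical; the real logic is inline in InputBatch.condense):

import numpy as np

def compact_row(
    token_ids_cpu: np.ndarray,       # shape (max_num_reqs, max_model_len)
    num_tokens_no_spec: np.ndarray,  # shape (max_num_reqs,)
    spec_token_ids: list[list[int]],
    last_req_index: int,
    empty_index: int,
) -> None:
    # Derive the total before the source row's spec list is swapped away.
    num_tokens = int(num_tokens_no_spec[last_req_index]) + len(
        spec_token_ids[last_req_index]
    )
    # Swap the list objects (avoids reallocation) and clear the vacated one.
    spec_token_ids[last_req_index], spec_token_ids[empty_index] = (
        spec_token_ids[empty_index],
        spec_token_ids[last_req_index],
    )
    spec_token_ids[last_req_index].clear()
    # Copy exactly num_tokens ids: prompt + output + pending drafts.
    token_ids_cpu[empty_index, :num_tokens] = token_ids_cpu[
        last_req_index, :num_tokens
    ]
    num_tokens_no_spec[empty_index] = num_tokens_no_spec[last_req_index]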

vllm/v1/worker/gpu_model_runner.py

@@ -923,7 +923,6 @@ class GPUModelRunner(
                 self.input_batch.num_prompt_tokens[req_index]
                 + num_output_tokens
             )
-            self.input_batch.num_tokens[req_index] = end_idx
             self.input_batch.num_tokens_no_spec[req_index] = end_idx

             # Update the block IDs.
@@ -968,7 +967,6 @@
                 req_index, start_token_index:end_token_index
             ] = new_token_ids
             self.input_batch.num_tokens_no_spec[req_index] = end_token_index
-            self.input_batch.num_tokens[req_index] = end_token_index

             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
@@ -984,8 +982,6 @@
                 self.input_batch.token_ids_cpu[
                     req_index, start_index:end_token_index
                 ] = spec_token_ids
-                # NOTE(woosuk): `num_tokens` here may include spec tokens.
-                self.input_batch.num_tokens[req_index] += num_spec_tokens

             # When speculative decoding is used with structured output,
             # the scheduler can drop draft tokens that do not
@@ -2702,7 +2698,6 @@
             self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
             self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
             self.input_batch.num_tokens_no_spec[req_idx] = end_idx
-            self.input_batch.num_tokens[req_idx] = end_idx

             req_id = req_ids[req_idx]
             req_state = self.requests[req_id]
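
On the runner side, every deleted write was paired with an equivalent update: either the same end index written to `num_tokens_no_spec`, or draft ids appended to the spec list. Dropping the duplicate therefore loses no information. A simplified sketch of the sampled-token path (hypothetical helper; the real code is inline in GPUModelRunner):

import numpy as np

def append_sampled_ids(
    token_ids_cpu: np.ndarray,
    is_token_ids: np.ndarray,
    num_tokens_no_spec: np.ndarray,
    req_idx: int,
    sampled_ids: list[int],
) -> None:
    start_idx = int(num_tokens_no_spec[req_idx])
    end_idx = start_idx + len(sampled_ids)
    token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
    is_token_ids[req_idx, start_idx:end_idx] = True
    # Previously num_tokens[req_idx] was also set to end_idx here;
    # the single counter now suffices.
    num_tokens_no_spec[req_idx] = end_idx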

vllm/v1/worker/tpu_input_batch.py

@@ -51,7 +51,6 @@ class InputBatch:
             pin_memory=False,
         )
         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
-        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_computed_tokens_cpu_tensor = torch.zeros(
@@ -200,9 +199,6 @@
         start_idx = num_prompt_tokens
         end_idx = start_idx + len(request.output_token_ids)
         self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
-        # Number of token ids in token_ids_cpu.
-        # NOTE(woosuk): This may include spec decode tokens.
-        self.num_tokens[req_index] = request.num_tokens
         # Number of tokens without spec decode tokens.
         self.num_tokens_no_spec[req_index] = request.num_tokens
@@ -344,10 +340,6 @@
                 self.req_id_to_index[old_id_i2],
                 self.req_id_to_index[old_id_i1],
             )
-            self.num_tokens[i1], self.num_tokens[i2] = (
-                self.num_tokens[i2],
-                self.num_tokens[i1],
-            )
             self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
                 self.num_tokens_no_spec[i2],
                 self.num_tokens_no_spec[i1],
@@ -448,11 +440,10 @@
             self.req_output_token_ids[last_req_index] = None
             self.req_id_to_index[req_id] = empty_index

-            num_tokens = self.num_tokens[last_req_index]
+            num_tokens = self.num_tokens_no_spec[last_req_index]
             self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                 last_req_index, :num_tokens
             ]
-            self.num_tokens[empty_index] = num_tokens
             self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                 last_req_index
             ]
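
Note that in this file the substitution is direct, with no len(spec_token_ids[...]) correction as in the GPU batch above, which suggests this InputBatch variant carries no pending draft tokens when condense() runs.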

vllm/v1/worker/tpu_model_runner.py

@@ -1283,7 +1283,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     token_id = valid_sampled_token_ids[i][0]
                     self.input_batch.token_ids_cpu[i, seq_len] = token_id
                     req_state.output_token_ids.append(token_id)
-                    self.input_batch.num_tokens[i] += 1
+                    self.input_batch.num_tokens_no_spec[i] += 1
         else:
             valid_mask = selected_token_ids != INVALID_TOKEN_ID
@@ -1291,7 +1291,7 @@
             valid_sampled_token_ids = [
                 seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens)
             ]
-            self.input_batch.num_tokens[:num_reqs] += gen_lens
+            self.input_batch.num_tokens_no_spec[:num_reqs] += gen_lens
             for i, req_state, seq_len in request_seq_lens:
                 target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
                 self.input_batch.token_ids_cpu[i, target_slice] = (
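
Likewise in the TPU runner, the counter is advanced by exactly the number of newly appended sampled tokens (one per request in the scalar branch, gen_lens in the vectorized one), so num_tokens_no_spec tracks the same value the removed num_tokens did on this path.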