mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 06:47:03 +08:00
minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
efcb786d52
commit
e696f78e05
@ -27,7 +27,7 @@ class InputBatch:
|
|||||||
# batch_idx -> num_scheduled_tokens
|
# batch_idx -> num_scheduled_tokens
|
||||||
num_scheduled_tokens: np.ndarray
|
num_scheduled_tokens: np.ndarray
|
||||||
total_num_tokens: int
|
total_num_tokens: int
|
||||||
max_num_tokens: int
|
max_query_len: int
|
||||||
num_reqs: int
|
num_reqs: int
|
||||||
|
|
||||||
attn_metadata: dict[str, Any]
|
attn_metadata: dict[str, Any]
|
||||||
|
|||||||
@ -577,7 +577,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# Get the number of scheduled tokens for each request.
|
# Get the number of scheduled tokens for each request.
|
||||||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
||||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||||
max_num_scheduled_tokens = max(tokens)
|
max_num_scheduled_tokens = int(num_scheduled_tokens.max())
|
||||||
|
|
||||||
prepare_inputs(
|
prepare_inputs(
|
||||||
idx_mapping=idx_mapping_np,
|
idx_mapping=idx_mapping_np,
|
||||||
@ -751,7 +751,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
idx_mapping_np=idx_mapping_np,
|
idx_mapping_np=idx_mapping_np,
|
||||||
num_reqs=num_reqs,
|
num_reqs=num_reqs,
|
||||||
total_num_tokens=total_num_scheduled_tokens,
|
total_num_tokens=total_num_scheduled_tokens,
|
||||||
max_num_tokens=max_num_scheduled_tokens,
|
max_query_len=max_num_scheduled_tokens,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
spec_decode_metadata=spec_decode_metadata,
|
spec_decode_metadata=spec_decode_metadata,
|
||||||
spec_decode_common_attn_metadata=spec_decode_common_attn_metadata,
|
spec_decode_common_attn_metadata=spec_decode_common_attn_metadata,
|
||||||
@ -1412,10 +1412,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
||||||
num_input_tokens, intermediate_tensors, True)
|
num_input_tokens, intermediate_tensors, True)
|
||||||
|
|
||||||
uniform_decode = (input_batch.max_num_tokens
|
uniform_decode = (input_batch.max_query_len
|
||||||
== self.uniform_decode_query_len
|
== self.uniform_decode_query_len
|
||||||
and num_scheduled_tokens
|
and num_scheduled_tokens
|
||||||
== input_batch.num_reqs * input_batch.max_num_tokens)
|
== input_batch.num_reqs * input_batch.max_query_len)
|
||||||
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||||||
uniform_decode=uniform_decode)
|
uniform_decode=uniform_decode)
|
||||||
cudagraph_runtime_mode, batch_descriptor = \
|
cudagraph_runtime_mode, batch_descriptor = \
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user