[V1] Remove num_input_tokens from attn_metadata (#17193)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Author: Chen Zhang
Date: 2025-04-30 00:28:41 +08:00 (committed by GitHub)
parent 2ef5d106bb
commit 24e6ad3f16
6 changed files with 14 additions and 21 deletions
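
At a glance, the change replaces the num_input_tokens field that the v1 attention metadata classes used to carry with an explicit num_tokens argument to set_forward_context. Below is a minimal, self-contained sketch of the new contract, not the real implementation: the actual set_forward_context takes the full vLLM config, coordinates data-parallel ranks, and does not yield the batch size; the dummy metadata class and the yield are for illustration only.

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any


@dataclass
class _V0StyleMetadata:
    # Only v0-style metadata still exposes these counts; after this commit
    # the v1 metadata classes no longer carry a num_input_tokens field.
    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0


@contextmanager
def set_forward_context(attn_metadata: Any, vllm_config: Any,
                        num_tokens: int = 0):
    if attn_metadata is not None and hasattr(attn_metadata,
                                             "num_prefill_tokens"):
        # for v0 attention backends: derive the batch size from the metadata
        batchsize = attn_metadata.num_prefill_tokens + \
            attn_metadata.num_decode_tokens
    else:
        # for v1 attention backends or no attn_metadata: trust the caller
        batchsize = num_tokens
    yield batchsize


# Callers (the GPU/TPU model runners in this diff) now pass the padded token
# count explicitly instead of stashing it on attn_metadata:
with set_forward_context(None, vllm_config=None, num_tokens=128) as bs:
    assert bs == 128
with set_forward_context(_V0StyleMetadata(5, 3), vllm_config=None) as bs:
    assert bs == 8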


@@ -74,15 +74,13 @@ def set_forward_context(attn_metadata: Any,
     if vllm_config.parallel_config.data_parallel_size > 1:
         dp_size = vllm_config.parallel_config.data_parallel_size
         dp_rank = vllm_config.parallel_config.data_parallel_rank
-        if attn_metadata is not None:
-            if hasattr(attn_metadata, "num_prefill_tokens"):
-                # for v0 attention backends
-                batchsize = attn_metadata.num_prefill_tokens + \
-                    attn_metadata.num_decode_tokens
-            else:
-                # for v1 attention backends
-                batchsize = attn_metadata.num_input_tokens
+        if attn_metadata is not None and hasattr(attn_metadata,
+                                                 "num_prefill_tokens"):
+            # for v0 attention backends
+            batchsize = attn_metadata.num_prefill_tokens + \
+                attn_metadata.num_decode_tokens
+        else:
+            # for v1 attention backends or no attn_metadata
+            batchsize = num_tokens
         num_tokens_across_dp = [0] * dp_size
         num_tokens_across_dp[dp_rank] = batchsize
@@ -124,7 +122,7 @@ def set_forward_context(attn_metadata: Any,
                 attn_metadata.num_decode_tokens
         else:
             # for v1 attention backends
-            batchsize = attn_metadata.num_input_tokens
+            batchsize = num_tokens
         # we use synchronous scheduling right now,
         # adding a sync point here should not affect
         # scheduling of the next batch
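
The batchsize chosen above is what each data-parallel rank records for itself before the ranks synchronize their sizes; a toy illustration of the bookkeeping at the tail of the first hunk (pure Python, with the cross-rank synchronization that the real code performs omitted and the concrete numbers made up):

dp_size = 4      # illustrative data-parallel world size
dp_rank = 2      # this rank's index
batchsize = 128  # the value picked by the logic above (num_tokens on v1)

# Each rank fills only its own slot; the real code then shares the table
# across ranks so every rank sees everyone's (padded) batch size.
num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = batchsize
assert num_tokens_across_dp == [0, 0, 128, 0]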


@@ -94,9 +94,6 @@ class FlashAttentionMetadata:
     scheduler_metadata: Optional[torch.Tensor] = None
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
 
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     # for local attention
     @dataclass
     class LocalAttentionMetadata:


@@ -183,9 +183,6 @@ class FlashInferMetadata:
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
 
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     @property
     def query_start_loc(self):
         # The GPUModelRunner expects to be able to access this property.


@@ -312,9 +312,6 @@ class MLACommonMetadata(Generic[D]):
     num_decode_tokens: int
     num_prefills: int
 
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     # The dimension of the attention heads
     head_dim: Optional[int] = None


@@ -1036,7 +1036,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             num_input_tokens = round_up(num_scheduled_tokens, tp_size)
         else:
             num_input_tokens = num_scheduled_tokens
-        attn_metadata.num_input_tokens = num_input_tokens
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
@@ -1088,7 +1087,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 num_tokens=num_input_tokens):
             output = self.model(
                 input_ids=input_ids,
                 positions=positions,
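
The num_input_tokens that the GPU runner now forwards is the scheduled token count, rounded up in the branch visible above. A small sketch of that rounding, with round_up written out here because its definition is not part of this diff (vLLM has its own utility for this; the numbers are illustrative):

def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x.
    return ((x + multiple - 1) // multiple) * multiple


tp_size = 8
num_scheduled_tokens = 13
num_input_tokens = round_up(num_scheduled_tokens, tp_size)  # -> 16
assert num_input_tokens == 16

# After this commit the padded count is handed straight to
# set_forward_context(..., num_tokens=num_input_tokens) instead of being
# stored on attn_metadata.num_input_tokens.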


@@ -769,7 +769,10 @@ class TPUModelRunner:
         xm.mark_step()
         num_reqs = self.input_batch.num_reqs
         # Run the decoder
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(
+                attn_metadata,
+                self.vllm_config,
+                num_tokens=scheduler_output.total_num_scheduled_tokens):
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=self.position_ids,