[V1] Remove num_input_tokens from attn_metadata (#17193)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent 2ef5d106bb
commit 24e6ad3f16
@@ -74,15 +74,13 @@ def set_forward_context(attn_metadata: Any,
     if vllm_config.parallel_config.data_parallel_size > 1:
         dp_size = vllm_config.parallel_config.data_parallel_size
         dp_rank = vllm_config.parallel_config.data_parallel_rank
-        if attn_metadata is not None:
-            if hasattr(attn_metadata, "num_prefill_tokens"):
-                # for v0 attention backends
-                batchsize = attn_metadata.num_prefill_tokens + \
-                    attn_metadata.num_decode_tokens
-            else:
-                # for v1 attention backends
-                batchsize = attn_metadata.num_input_tokens
-        else:
-            batchsize = num_tokens
+        if attn_metadata is not None and hasattr(attn_metadata,
+                                                 "num_prefill_tokens"):
+            # for v0 attention backends
+            batchsize = attn_metadata.num_prefill_tokens + \
+                attn_metadata.num_decode_tokens
+        else:
+            # for v1 attention backends or no attn_metadata
+            batchsize = num_tokens
         num_tokens_across_dp = [0] * dp_size
         num_tokens_across_dp[dp_rank] = batchsize
@@ -124,7 +122,7 @@ def set_forward_context(attn_metadata: Any,
                     attn_metadata.num_decode_tokens
             else:
                 # for v1 attention backends
-                batchsize = attn_metadata.num_input_tokens
+                batchsize = num_tokens
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
             # scheduling of the next batch
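A minimal standalone sketch of the batch-size selection above, assuming only what the two hunks show (the helper name _select_batchsize and its docstring-free framing are illustrative, not part of the commit):

from typing import Any


def _select_batchsize(attn_metadata: Any, num_tokens: int) -> int:
    # v0 backends still expose num_prefill_tokens / num_decode_tokens on
    # their metadata; v1 backends (and dummy runs with no metadata at all)
    # now rely on the explicit num_tokens argument instead of a
    # num_input_tokens field.
    if attn_metadata is not None and hasattr(attn_metadata,
                                             "num_prefill_tokens"):
        # for v0 attention backends
        return (attn_metadata.num_prefill_tokens +
                attn_metadata.num_decode_tokens)
    # for v1 attention backends or no attn_metadata
    return num_tokens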
@@ -94,9 +94,6 @@ class FlashAttentionMetadata:
     scheduler_metadata: Optional[torch.Tensor] = None
     prefix_scheduler_metadata: Optional[torch.Tensor] = None

-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     # for local attention
     @dataclass
     class LocalAttentionMetadata:
@@ -183,9 +183,6 @@ class FlashInferMetadata:
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None

-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     @property
     def query_start_loc(self):
         # The GPUModelRunner expects to be able to access this property.
@@ -312,9 +312,6 @@ class MLACommonMetadata(Generic[D]):
     num_decode_tokens: int
     num_prefills: int

-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     # The dimension of the attention heads
     head_dim: Optional[int] = None

@@ -1036,7 +1036,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             num_input_tokens = round_up(num_scheduled_tokens, tp_size)
         else:
             num_input_tokens = num_scheduled_tokens
-        attn_metadata.num_input_tokens = num_input_tokens

         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
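The runner still computes the padded token count; it just stops writing it onto attn_metadata. A rough sketch of that computation, assuming round_up simply rounds to the next multiple (the real helper lives elsewhere in vLLM, and the condition guarding the padding is not visible in this hunk, so it is modelled as a plain flag here):

def round_up(x: int, multiple: int) -> int:
    # Illustrative reimplementation: smallest multiple of `multiple` >= x.
    return ((x + multiple - 1) // multiple) * multiple


def padded_num_input_tokens(num_scheduled_tokens: int, tp_size: int,
                            pad_to_tp_multiple: bool) -> int:
    # When padding is required, round the scheduled token count up to a
    # multiple of the tensor-parallel size; otherwise pass it through.
    # The result is handed to set_forward_context(..., num_tokens=...)
    # rather than stored on attn_metadata.num_input_tokens.
    if pad_to_tp_multiple:
        return round_up(num_scheduled_tokens, tp_size)
    return num_scheduled_tokens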
@@ -1088,7 +1087,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):

         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 num_tokens=num_input_tokens):
             output = self.model(
                 input_ids=input_ids,
                 positions=positions,
@@ -769,7 +769,10 @@ class TPUModelRunner:
         xm.mark_step()
         num_reqs = self.input_batch.num_reqs
         # Run the decoder
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(
+                attn_metadata,
+                self.vllm_config,
+                num_tokens=scheduler_output.total_num_scheduled_tokens):
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=self.position_ids,
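Both call sites now pass the token count explicitly instead of letting set_forward_context read it back from the metadata. A hypothetical caller, assuming only the signature visible in this diff (positional attn_metadata and vllm_config plus a num_tokens keyword):

from vllm.forward_context import set_forward_context


def run_decoder(model, vllm_config, attn_metadata, input_ids, positions,
                num_input_tokens):
    # num_input_tokens includes any padding the runner applied; after this
    # commit, attn_metadata no longer carries a num_input_tokens field.
    with set_forward_context(attn_metadata,
                             vllm_config,
                             num_tokens=num_input_tokens):
        return model(input_ids=input_ids, positions=positions)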