[V1] Remove num_input_tokens from attn_metadata (#17193)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent 2ef5d106bb
commit 24e6ad3f16
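
In short, the padded token count is no longer stashed on the attention metadata and read back inside set_forward_context; callers now pass it explicitly via the num_tokens argument. The following is a minimal, illustrative sketch of the resulting batch-size selection, using a stand-in metadata class rather than the real vLLM types (the helper names _V0Metadata and _select_batchsize are hypothetical, not part of the vLLM API):

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class _V0Metadata:
    # Stand-in for a v0 attention backend's metadata (illustrative only).
    num_prefill_tokens: int
    num_decode_tokens: int


def _select_batchsize(attn_metadata: Optional[Any], num_tokens: int) -> int:
    # Mirrors the post-change branch in set_forward_context: v0 metadata
    # still reports its own token counts; v1 backends (or no metadata at
    # all) fall back to the num_tokens value supplied by the model runner.
    if attn_metadata is not None and hasattr(attn_metadata,
                                             "num_prefill_tokens"):
        # for v0 attention backends
        return attn_metadata.num_prefill_tokens + \
            attn_metadata.num_decode_tokens
    # for v1 attention backends or no attn_metadata
    return num_tokens


print(_select_batchsize(_V0Metadata(5, 3), num_tokens=16))  # -> 8  (v0 path)
print(_select_batchsize(None, num_tokens=16))               # -> 16 (v1 path / no metadata)
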
@@ -74,15 +74,13 @@ def set_forward_context(attn_metadata: Any,
     if vllm_config.parallel_config.data_parallel_size > 1:
         dp_size = vllm_config.parallel_config.data_parallel_size
         dp_rank = vllm_config.parallel_config.data_parallel_rank
-        if attn_metadata is not None:
-            if hasattr(attn_metadata, "num_prefill_tokens"):
-                # for v0 attention backends
-                batchsize = attn_metadata.num_prefill_tokens + \
-                    attn_metadata.num_decode_tokens
-            else:
-                # for v1 attention backends
-                batchsize = attn_metadata.num_input_tokens
+        if attn_metadata is not None and hasattr(attn_metadata,
+                                                 "num_prefill_tokens"):
+            # for v0 attention backends
+            batchsize = attn_metadata.num_prefill_tokens + \
+                attn_metadata.num_decode_tokens
         else:
+            # for v1 attention backends or no attn_metadata
             batchsize = num_tokens
         num_tokens_across_dp = [0] * dp_size
         num_tokens_across_dp[dp_rank] = batchsize
@@ -124,7 +122,7 @@ def set_forward_context(attn_metadata: Any,
                 attn_metadata.num_decode_tokens
         else:
             # for v1 attention backends
-            batchsize = attn_metadata.num_input_tokens
+            batchsize = num_tokens
         # we use synchronous scheduling right now,
         # adding a sync point here should not affect
         # scheduling of the next batch
@@ -94,9 +94,6 @@ class FlashAttentionMetadata:
     scheduler_metadata: Optional[torch.Tensor] = None
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
 
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     # for local attention
     @dataclass
     class LocalAttentionMetadata:
@@ -183,9 +183,6 @@ class FlashInferMetadata:
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
 
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     @property
     def query_start_loc(self):
         # The GPUModelRunner expects to be able to access this property.
@@ -312,9 +312,6 @@ class MLACommonMetadata(Generic[D]):
     num_decode_tokens: int
     num_prefills: int
 
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
-
     # The dimension of the attention heads
     head_dim: Optional[int] = None
 
@@ -1036,7 +1036,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             num_input_tokens = round_up(num_scheduled_tokens, tp_size)
         else:
            num_input_tokens = num_scheduled_tokens
-        attn_metadata.num_input_tokens = num_input_tokens
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
@@ -1088,7 +1087,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 num_tokens=num_input_tokens):
             output = self.model(
                 input_ids=input_ids,
                 positions=positions,
@@ -769,7 +769,10 @@ class TPUModelRunner:
         xm.mark_step()
         num_reqs = self.input_batch.num_reqs
         # Run the decoder
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(
+                attn_metadata,
+                self.vllm_config,
+                num_tokens=scheduler_output.total_num_scheduled_tokens):
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=self.position_ids,
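
Both model runners now thread the token count through set_forward_context themselves, as the last two hunks show. Below is a condensed, illustrative sketch of that call pattern; the context manager here is a placeholder standing in for vLLM's real forward_context implementation, and the token counts are arbitrary example values:

from contextlib import contextmanager
from typing import Any, Optional


@contextmanager
def set_forward_context(attn_metadata: Optional[Any],
                        vllm_config: Any,
                        num_tokens: int = 0):
    # Illustrative stand-in only: the point is that the padded token count
    # now arrives as an explicit num_tokens argument instead of being read
    # from attn_metadata.num_input_tokens.
    print(f"forward context prepared for {num_tokens} tokens")
    yield


# GPU runner style: pass the (possibly padded) input token count.
with set_forward_context(None, vllm_config=object(), num_tokens=128):
    pass

# TPU runner style: pass the scheduler's total scheduled token count.
with set_forward_context(None, vllm_config=object(), num_tokens=96):
    pass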