AFD: run ubatches (micro-batches) sequentially, without a separate thread

Signed-off-by: jiangkuaixue123 <jiangxiaozhou111@163.com>
This commit is contained in:
jiangkuaixue123 2025-12-15 14:38:03 +08:00
parent bd8fe276f5
commit 28cba040c7
3 changed files with 188 additions and 13 deletions

View File

@ -4,7 +4,7 @@
import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, NamedTuple
import torch
@ -14,6 +14,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.distributed.afd_transfer import AFDConnectorBase
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ubatch_utils import UBatchSlices
@ -232,6 +233,13 @@ class AFDMetadata:
afd_tokens_lens: list[int] # padded lengths for tensor slicing
num_of_stages: int
input_ids_list: list[torch.Tensor] = field(default_factory=list)
positions_list: list[torch.Tensor] = field(default_factory=list)
inputs_embeds_list: list[torch.Tensor] = field(default_factory=list)
intermediate_tensors_list: list[IntermediateTensors] = field(default_factory=list)
attn_metadata_list: list[AttentionMetadata] = field(default_factory=list)
dp_metadata_list: list[DPMetadata] = field(default_factory=list)
@dataclass
class ForwardContext:

View File

@ -1361,7 +1361,9 @@ class DeepseekV2Model(nn.Module):
if recv_handle is not None:
for work in recv_handle:
work.wait()
current_hidden, residual = layer(positions, hidden_states, residual, llama_4_scaling)
current_hidden, residual = layer(
positions, hidden_states, residual, llama_4_scaling
)
metadata = AFDConnectorMetadata.create_attention_metadata(
layer_idx=layer.layer_idx,
stage_idx=afd_metadata.afd_stage_idx,
@ -1385,6 +1387,96 @@ class DeepseekV2Model(nn.Module):
return hidden_states, residual
def forward_with_afd_v2(
    self,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    afd_metadata: AFDMetadata,
    llama_4_scaling: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the decoder layers with AFD (attention/FFN disaggregation),
    processing micro-batches (ubatches) stage-by-stage on a single thread.

    The flat input batch is split along the token dimension according to
    ``afd_metadata.positions_list``. For every (layer, stage) pair the
    attention output is pushed through the AFD connector and the previous
    layer's FFN output for that stage is received back before the layer
    consumes it.

    Args:
        hidden_states: Concatenated hidden states for all ubatches.
        residual: Concatenated residual tensor, or ``None``.
        positions: Full-batch positions (per-stage positions are taken
            from ``afd_metadata.positions_list`` instead).
        afd_metadata: Per-stage attn/DP metadata, connector and ubatch
            token lengths.
        llama_4_scaling: Optional scaling tensor forwarded to each layer.

    Returns:
        Tuple of ``(hidden_states, residual)`` re-assembled across all
        ubatches.
    """
    forward_context = get_forward_context()
    # Loop invariants hoisted out of the per-layer/per-stage loop.
    afd_connector = afd_metadata.afd_connector
    num_of_stages = afd_metadata.num_of_stages
    recv_handle = None

    # Split the flat batch into per-ubatch views. The per-stage token
    # count is derived from each stage's positions tensor: MROPE-style
    # positions have shape (3, num_tokens), plain ones (num_tokens,).
    ubatch_hidden_states: list[torch.Tensor] = []
    ubatch_residual: list[torch.Tensor | None] = []
    start_idx = 0
    for pos in afd_metadata.positions_list:
        num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0]
        end_idx = start_idx + num_tokens
        ubatch_hidden_states.append(hidden_states[start_idx:end_idx])
        ubatch_residual.append(
            residual[start_idx:end_idx] if residual is not None else None
        )
        start_idx = end_idx

    for layer in islice(self.layers, self.start_layer, self.end_layer):
        for stage_i in range(num_of_stages):
            logger.debug(
                "forward_with_afd_v2 layer_idx: %s, stage_i: %s",
                layer.layer_idx,
                stage_i,
            )
            # Point the shared forward context at this stage's metadata so
            # attention and DP collectives see the per-ubatch shapes.
            forward_context.attn_metadata = afd_metadata.attn_metadata_list[stage_i]
            forward_context.dp_metadata = afd_metadata.dp_metadata_list[stage_i]
            residual = ubatch_residual[stage_i]
            if layer.layer_idx > 0:
                # The previous layer's FFN output for this stage arrives
                # asynchronously; keep the handles and wait just before use.
                hidden_states, recv_metadata = afd_connector.recv_ffn_output()
                if recv_metadata.recv_handle_list is not None:
                    recv_handle = recv_metadata.recv_handle_list
            else:
                hidden_states = ubatch_hidden_states[stage_i]
            if recv_handle is not None:
                for work in recv_handle:
                    work.wait()
            current_positions = afd_metadata.positions_list[stage_i]
            hidden_states, residual = layer(
                current_positions, hidden_states, residual, llama_4_scaling
            )
            ubatch_hidden_states[stage_i] = hidden_states
            ubatch_residual[stage_i] = residual
            metadata = AFDConnectorMetadata.create_attention_metadata(
                layer_idx=layer.layer_idx,
                stage_idx=stage_i,
                seq_len=hidden_states.shape[0],
                dtype=hidden_states.dtype,
                device=hidden_states.device,
                num_of_stages=num_of_stages,
                afd_tokens_lens=afd_metadata.afd_tokens_lens,
            )
            afd_connector.send_attn_output(hidden_states, metadata)

    # Recv last layer and last stage FFN output.
    # NOTE(review): only the final stage's FFN output is received here;
    # confirm the connector delivers (or batches) the remaining stages'
    # last-layer outputs elsewhere.
    ubatch_hidden_states[num_of_stages - 1], recv_metadata = (
        afd_connector.recv_ffn_output()
    )
    if recv_metadata.recv_handle_list is not None:
        recv_handle = recv_metadata.recv_handle_list
    if recv_handle is not None:
        for work in recv_handle:
            work.wait()

    # Re-assemble the full batch from the per-stage pieces.
    hidden_states = torch.cat(ubatch_hidden_states, dim=0)
    if any(r is not None for r in ubatch_residual):
        residual = torch.cat(ubatch_residual, dim=0)
    else:
        residual = None
    return hidden_states, residual
def forward(
self,
input_ids: torch.Tensor,
@ -1421,12 +1513,14 @@ class DeepseekV2Model(nn.Module):
afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None
if afd_metadata != None:
hidden_states, residual = self.forward_with_afd(
hidden_states, residual = self.forward_with_afd_v2(
hidden_states, residual, positions, afd_metadata, llama_4_scaling
)
else:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual, llama_4_scaling)
hidden_states, residual = layer(
positions, hidden_states, residual, llama_4_scaling
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(

View File

@ -14,6 +14,7 @@ from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import (
AFDMetadata,
DPMetadata,
create_forward_context,
get_forward_context,
@ -358,6 +359,59 @@ class UBatchWrapper:
return ubatch_metadata
def _make_afd_ubatch_metadata(
    self,
    ubatch_slices,
    attn_metadata,
    input_ids,
    positions,
    inputs_embeds,
    intermediate_tensors,
    dp_metadata,
    afd_metadata,
) -> AFDMetadata:
    """Populate ``afd_metadata``'s per-stage lists with model inputs.

    When ``ubatch_slices`` is ``None`` the full, unsliced inputs are
    appended as a single stage. Otherwise each ubatch slice gets its own
    sliced inputs plus a freshly built per-ubatch ``DPMetadata`` sized to
    that slice's token count.

    Note: mutates ``afd_metadata`` in place and also returns it.

    Args:
        ubatch_slices: Per-ubatch token slices, or ``None`` when no
            ubatching is in effect.
        attn_metadata: Attention metadata; when ubatching, presumably a
            per-ubatch sequence indexed by slice position — TODO confirm.
        input_ids / positions / inputs_embeds / intermediate_tensors:
            Full-batch model inputs to be sliced per ubatch.
        dp_metadata: Full-batch DP metadata (used only in the unsliced
            path).
        afd_metadata: The AFDMetadata instance whose lists are appended.

    Returns:
        The same ``afd_metadata`` instance, for call-site convenience.
    """
    if ubatch_slices is None:
        # Single stage: pass the whole batch through unchanged.
        afd_metadata.input_ids_list.append(input_ids)
        afd_metadata.positions_list.append(positions)
        afd_metadata.inputs_embeds_list.append(inputs_embeds)
        afd_metadata.intermediate_tensors_list.append(intermediate_tensors)
        afd_metadata.attn_metadata_list.append(attn_metadata)
        afd_metadata.dp_metadata_list.append(dp_metadata)
    else:
        # Loop invariant: the DP world size does not change per slice.
        dp_size = self.vllm_config.parallel_config.data_parallel_size
        for i, ubatch_slice in enumerate(ubatch_slices):
            (
                sliced_input_ids,
                sliced_positions,
                sliced_inputs_embeds,
                sliced_intermediate_tensors,
            ) = self._slice_model_inputs(
                ubatch_slice.token_slice,
                input_ids,
                positions,
                inputs_embeds,
                intermediate_tensors,
            )
            # Every DP rank is assumed to run the same ubatch token count
            # here — TODO confirm against the coordination logic.
            ubatch_num_tokens_across_dp = torch.tensor(
                [ubatch_slice.num_tokens] * dp_size,
                device="cpu",
                dtype=torch.int32,
            )
            ubatch_dp_metadata = DPMetadata.make(
                self.vllm_config.parallel_config,
                ubatch_slice.num_tokens,
                ubatch_num_tokens_across_dp,
            )
            afd_metadata.input_ids_list.append(sliced_input_ids)
            afd_metadata.positions_list.append(sliced_positions)
            afd_metadata.inputs_embeds_list.append(sliced_inputs_embeds)
            afd_metadata.intermediate_tensors_list.append(
                sliced_intermediate_tensors
            )
            afd_metadata.attn_metadata_list.append(
                attn_metadata[i] if attn_metadata is not None else None
            )
            afd_metadata.dp_metadata_list.append(ubatch_dp_metadata)
    return afd_metadata
def _slice_model_inputs(
self,
tokens_slice: slice,
@ -392,6 +446,33 @@ class UBatchWrapper:
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
afd_metadata = forward_context.afd_metadata
attn_metadata = forward_context.attn_metadata
input_ids = kwargs["input_ids"]
positions = kwargs["positions"]
intermediate_tensors = kwargs["intermediate_tensors"]
inputs_embeds = kwargs["inputs_embeds"]
compute_stream = torch.cuda.current_stream()
dp_metadata = forward_context.dp_metadata
if self.vllm_config.afd_config:
afd_metadata = self._make_afd_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
dp_metadata=dp_metadata,
afd_metadata=afd_metadata,
)
forward_context.afd_metadata = afd_metadata
if cudagraph_runtime_mode is CUDAGraphMode.NONE:
return self.runnable(*args, **kwargs)
else:
assert self.cudagraph_wrapper is not None
return self.cudagraph_wrapper(*args, **kwargs)
# If there's no ubatching, just run the runnable object
if ubatch_slices is None:
# This is to account for the case where ubatching was aborted.
@ -400,6 +481,7 @@ class UBatchWrapper:
# num_tokens, we don't have a non-ubatched one. Without this
# check, the cudagraph wrapper will try to capture a cudagraph
# for this shape during a normal run.
if cudagraph_runtime_mode is CUDAGraphMode.FULL:
assert batch_descriptor is not None
if batch_descriptor.num_tokens in self.cudagraphs:
@ -411,18 +493,9 @@ class UBatchWrapper:
assert self.cudagraph_wrapper is not None
return self.cudagraph_wrapper(*args, **kwargs)
attn_metadata = forward_context.attn_metadata
num_tokens = (
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
) * 2
input_ids = kwargs["input_ids"]
positions = kwargs["positions"]
intermediate_tensors = kwargs["intermediate_tensors"]
inputs_embeds = kwargs["inputs_embeds"]
compute_stream = torch.cuda.current_stream()
dp_metadata = forward_context.dp_metadata
# We shouldn't be here unless we are running with multiple DP ranks
assert dp_metadata is not None
ubatch_dp_metadata = []