diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index a715d70d1d5f4..a2f365cf21eb3 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -4,7 +4,7 @@
 import time
 from collections import defaultdict
 from contextlib import contextmanager
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, NamedTuple
 
 import torch
@@ -14,6 +14,7 @@
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.distributed.afd_transfer import AFDConnectorBase
 from vllm.logger import init_logger
+from vllm.sequence import IntermediateTensors
 from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
 from vllm.v1.worker.ubatch_utils import UBatchSlices
@@ -232,6 +233,13 @@ class AFDMetadata:
     afd_tokens_lens: list[int]  # padded lengths for tensor slicing
     num_of_stages: int
 
+    input_ids_list: list[torch.Tensor] = field(default_factory=list)
+    positions_list: list[torch.Tensor] = field(default_factory=list)
+    inputs_embeds_list: list[torch.Tensor] = field(default_factory=list)
+    intermediate_tensors_list: list[IntermediateTensors] = field(default_factory=list)
+    attn_metadata_list: list[AttentionMetadata] = field(default_factory=list)
+    dp_metadata_list: list[DPMetadata] = field(default_factory=list)
+
 
 @dataclass
 class ForwardContext:
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index cd3829b9220c9..eb9d85cfc55fd 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -1361,7 +1361,9 @@ class DeepseekV2Model(nn.Module):
             if recv_handle is not None:
                 for work in recv_handle:
                     work.wait()
-            current_hidden, residual = layer(positions, hidden_states, residual, llama_4_scaling)
+            current_hidden, residual = layer(
+                positions, hidden_states, residual, llama_4_scaling
+            )
             metadata = AFDConnectorMetadata.create_attention_metadata(
                 layer_idx=layer.layer_idx,
                 stage_idx=afd_metadata.afd_stage_idx,
@@ -1385,6 +1387,96 @@ class DeepseekV2Model(nn.Module):
 
         return hidden_states, residual
 
+    def forward_with_afd_v2(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        positions: torch.Tensor,
+        afd_metadata: AFDMetadata,
+        llama_4_scaling: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        forward_context = get_forward_context()
+        recv_handle = None
+
+        ubatch_hidden_states = []
+        ubatch_residual = []
+
+        # Split the flattened batch into per-stage chunks, sized by the number
+        # of tokens in each stage's positions tensor.
+        start_idx = 0
+        for pos in afd_metadata.positions_list:
+            # Positions may be M-RoPE style with shape (3, num_tokens) or
+            # 1-D with shape (num_tokens,); pick the token dim accordingly.
+            num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0]
+            end_idx = start_idx + num_tokens
+            ubatch_hidden_states.append(hidden_states[start_idx:end_idx])
+            ubatch_residual.append(
+                residual[start_idx:end_idx] if residual is not None else None
+            )
+            start_idx = end_idx
+
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
+            for stage_i in range(forward_context.afd_metadata.num_of_stages):
+                afd_connector = afd_metadata.afd_connector
+                forward_context.attn_metadata = afd_metadata.attn_metadata_list[stage_i]
+                forward_context.dp_metadata = afd_metadata.dp_metadata_list[stage_i]
+
+                residual = ubatch_residual[stage_i]
+
+                if layer.layer_idx > 0:
+                    hidden_states, recv_metadata = afd_connector.recv_ffn_output()
+                    if recv_metadata.recv_handle_list is not None:
+                        recv_handle = recv_metadata.recv_handle_list
+                else:
+                    hidden_states = ubatch_hidden_states[stage_i]
+
+                if recv_handle is not None:
+                    for work in recv_handle:
+                        work.wait()
+
+                current_positions = afd_metadata.positions_list[stage_i]
+                hidden_states, residual = layer(
+                    current_positions, hidden_states, residual, llama_4_scaling
+                )
+
+                ubatch_hidden_states[stage_i] = hidden_states
+                ubatch_residual[stage_i] = residual
+
+                metadata = AFDConnectorMetadata.create_attention_metadata(
+                    layer_idx=layer.layer_idx,
+                    stage_idx=stage_i,
+                    seq_len=hidden_states.shape[0],
+                    dtype=hidden_states.dtype,
+                    device=hidden_states.device,
+                    num_of_stages=afd_metadata.num_of_stages,
+                    afd_tokens_lens=afd_metadata.afd_tokens_lens,
+                )
+                afd_connector.send_attn_output(hidden_states, metadata)
+
+        # Receive the FFN output of the last layer's final stage.
+        ubatch_hidden_states[afd_metadata.num_of_stages - 1], recv_metadata = (
+            afd_connector.recv_ffn_output()
+        )
+        if recv_metadata.recv_handle_list is not None:
+            recv_handle = recv_metadata.recv_handle_list
+        if recv_handle is not None:
+            for work in recv_handle:
+                work.wait()
+
+        # Re-assemble the full batch from the per-stage chunks.
+        hidden_states = torch.cat(ubatch_hidden_states, dim=0)
+        if any(r is not None for r in ubatch_residual):
+            residual = torch.cat(ubatch_residual, dim=0)
+        else:
+            residual = None
+
+        return hidden_states, residual
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1421,12 +1513,14 @@ class DeepseekV2Model(nn.Module):
         afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None
 
         if afd_metadata != None:
-            hidden_states, residual = self.forward_with_afd(
+            hidden_states, residual = self.forward_with_afd_v2(
                 hidden_states, residual, positions, afd_metadata, llama_4_scaling
            )
         else:
             for layer in islice(self.layers, self.start_layer, self.end_layer):
-                hidden_states, residual = layer(positions, hidden_states, residual, llama_4_scaling)
+                hidden_states, residual = layer(
+                    positions, hidden_states, residual, llama_4_scaling
+                )
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py
index 21694b367c326..9e17c718c5513 100644
--- a/vllm/v1/worker/gpu_ubatch_wrapper.py
+++ b/vllm/v1/worker/gpu_ubatch_wrapper.py
@@ -14,6 +14,7 @@
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.distributed import get_ep_group
 from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
 from vllm.forward_context import (
+    AFDMetadata,
     DPMetadata,
     create_forward_context,
     get_forward_context,
@@ -358,6 +359,59 @@ class UBatchWrapper:
 
         return ubatch_metadata
 
+    def _make_afd_ubatch_metadata(
+        self,
+        ubatch_slices,
+        attn_metadata,
+        input_ids,
+        positions,
+        inputs_embeds,
+        intermediate_tensors,
+        dp_metadata,
+        afd_metadata,
+    ) -> AFDMetadata:
+        if ubatch_slices is None:
+            afd_metadata.input_ids_list.append(input_ids)
+            afd_metadata.positions_list.append(positions)
+            afd_metadata.inputs_embeds_list.append(inputs_embeds)
+            afd_metadata.intermediate_tensors_list.append(intermediate_tensors)
+            afd_metadata.attn_metadata_list.append(attn_metadata)
+            afd_metadata.dp_metadata_list.append(dp_metadata)
+        else:
+            for i, ubatch_slice in enumerate(ubatch_slices):
+                (
+                    sliced_input_ids,
+                    sliced_positions,
+                    sliced_inputs_embeds,
+                    sliced_intermediate_tensors,
+                ) = self._slice_model_inputs(
+                    ubatch_slice.token_slice,
+                    input_ids,
+                    positions,
+                    inputs_embeds,
+                    intermediate_tensors,
+                )
+
+                dp_size = self.vllm_config.parallel_config.data_parallel_size
+                ubatch_num_tokens_across_dp = torch.tensor(
+                    [ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32
+                )
+                ubatch_dp_metadata = DPMetadata.make(
+                    self.vllm_config.parallel_config,
+                    ubatch_slice.num_tokens,
+                    ubatch_num_tokens_across_dp,
+                )
+
+                afd_metadata.input_ids_list.append(sliced_input_ids)
+                afd_metadata.positions_list.append(sliced_positions)
+                afd_metadata.inputs_embeds_list.append(sliced_inputs_embeds)
+                afd_metadata.intermediate_tensors_list.append(sliced_intermediate_tensors)
+                afd_metadata.attn_metadata_list.append(
+                    attn_metadata[i] if attn_metadata is not None else None
+                )
+                afd_metadata.dp_metadata_list.append(ubatch_dp_metadata)
+
+        return afd_metadata
+
     def _slice_model_inputs(
         self,
         tokens_slice: slice,
@@ -392,6 +446,33 @@ class UBatchWrapper:
         cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
         afd_metadata = forward_context.afd_metadata
 
+        attn_metadata = forward_context.attn_metadata
+        input_ids = kwargs["input_ids"]
+        positions = kwargs["positions"]
+        intermediate_tensors = kwargs["intermediate_tensors"]
+        inputs_embeds = kwargs["inputs_embeds"]
+        compute_stream = torch.cuda.current_stream()
+
+        dp_metadata = forward_context.dp_metadata
+
+        if self.vllm_config.afd_config:
+            afd_metadata = self._make_afd_ubatch_metadata(
+                ubatch_slices=ubatch_slices,
+                attn_metadata=attn_metadata,
+                input_ids=input_ids,
+                positions=positions,
+                inputs_embeds=inputs_embeds,
+                intermediate_tensors=intermediate_tensors,
+                dp_metadata=dp_metadata,
+                afd_metadata=afd_metadata,
+            )
+            forward_context.afd_metadata = afd_metadata
+            if cudagraph_runtime_mode is CUDAGraphMode.NONE:
+                return self.runnable(*args, **kwargs)
+            else:
+                assert self.cudagraph_wrapper is not None
+                return self.cudagraph_wrapper(*args, **kwargs)
+
         # If there's no ubatching, just run the runnable object
         if ubatch_slices is None:
             # This is to account for the case where ubatching was aborted.
@@ -400,6 +481,7 @@
             # num_tokens, we don't have a non-ubatched one. Without this
             # check, the cudagraph wrapper will try to capture a cudagraph
            # for this shape during a normal run.
+
             if cudagraph_runtime_mode is CUDAGraphMode.FULL:
                 assert batch_descriptor is not None
                 if batch_descriptor.num_tokens in self.cudagraphs:
@@ -411,18 +493,9 @@
                 assert self.cudagraph_wrapper is not None
                 return self.cudagraph_wrapper(*args, **kwargs)
 
-        attn_metadata = forward_context.attn_metadata
         num_tokens = (
             ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
         ) * 2
-        input_ids = kwargs["input_ids"]
-        positions = kwargs["positions"]
-        intermediate_tensors = kwargs["intermediate_tensors"]
-        inputs_embeds = kwargs["inputs_embeds"]
-        compute_stream = torch.cuda.current_stream()
-
-        dp_metadata = forward_context.dp_metadata
-
         # We shouldn't be here unless we are running with multiple DP ranks
         assert dp_metadata is not None
         ubatch_dp_metadata = []
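For reference, the per-stage splitting and re-assembly that `forward_with_afd_v2` performs can be illustrated with a minimal, self-contained sketch in plain PyTorch. The stage sizes and tensors below are made up for illustration and none of the vLLM/AFD objects (connector, metadata, layers) are involved; it only mirrors the positions-driven slicing and the final `torch.cat`.

```python
import torch

# Hypothetical sizes: two AFD stages with 5 and 3 tokens, hidden size 4.
positions_list = [torch.arange(5), torch.arange(3)]
hidden_states = torch.randn(8, 4)

# Split the flattened batch into per-stage chunks, as forward_with_afd_v2
# does using afd_metadata.positions_list.
chunks, start = [], 0
for pos in positions_list:
    # M-RoPE positions are (3, num_tokens); 1-D positions are (num_tokens,).
    num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0]
    chunks.append(hidden_states[start:start + num_tokens])
    start += num_tokens

# ... each chunk would be run through a layer stage-by-stage here ...

# Re-assemble the batch in stage order, as the method does before returning.
out = torch.cat(chunks, dim=0)
assert out.shape == hidden_states.shape
```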