From bd8fe276f5abc630c7ca659151952cfe51304939 Mon Sep 17 00:00:00 2001 From: jiangkuaixue123 Date: Wed, 10 Dec 2025 10:08:40 +0800 Subject: [PATCH 01/19] 1.add afd 2.support afd with DBO. 3.support AFDP2PConnector 4.support afd with deepseekv2 Signed-off-by: jiangkuaixue123 --- .../online_serving/afd_deepseek_v2/README.md | 19 + vllm/config/__init__.py | 3 + vllm/config/afd.py | 79 ++++ vllm/config/vllm.py | 28 +- vllm/distributed/afd_transfer/__init__.py | 12 + .../afd_transfer/afd_connector/__init__.py | 13 + .../afd_transfer/afd_connector/base.py | 139 ++++++ .../afd_connector/dummy_connector.py | 211 +++++++++ .../afd_transfer/afd_connector/factory.py | 95 ++++ .../afd_transfer/afd_connector/metadata.py | 175 +++++++ .../afd_connector/p2p_connector.py | 304 ++++++++++++ vllm/distributed/parallel_state.py | 66 ++- vllm/engine/arg_utils.py | 8 +- vllm/entrypoints/afd_ffn_server.py | 91 ++++ vllm/entrypoints/cli/fserver.py | 51 ++ vllm/entrypoints/cli/main.py | 2 + vllm/forward_context.py | 56 +++ vllm/model_executor/models/deepseek_v2.py | 246 ++++++++-- vllm/v1/worker/gpu_ffn_model_runner.py | 441 ++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 83 +++- vllm/v1/worker/gpu_ubatch_wrapper.py | 10 +- vllm/v1/worker/gpu_worker.py | 68 ++- 22 files changed, 2142 insertions(+), 58 deletions(-) create mode 100644 examples/online_serving/afd_deepseek_v2/README.md create mode 100644 vllm/config/afd.py create mode 100644 vllm/distributed/afd_transfer/__init__.py create mode 100644 vllm/distributed/afd_transfer/afd_connector/__init__.py create mode 100644 vllm/distributed/afd_transfer/afd_connector/base.py create mode 100644 vllm/distributed/afd_transfer/afd_connector/dummy_connector.py create mode 100644 vllm/distributed/afd_transfer/afd_connector/factory.py create mode 100644 vllm/distributed/afd_transfer/afd_connector/metadata.py create mode 100644 vllm/distributed/afd_transfer/afd_connector/p2p_connector.py create mode 100644 vllm/entrypoints/afd_ffn_server.py 
create mode 100644 vllm/entrypoints/cli/fserver.py create mode 100644 vllm/v1/worker/gpu_ffn_model_runner.py diff --git a/examples/online_serving/afd_deepseek_v2/README.md b/examples/online_serving/afd_deepseek_v2/README.md new file mode 100644 index 0000000000000..4f8cf4a4e40db --- /dev/null +++ b/examples/online_serving/afd_deepseek_v2/README.md @@ -0,0 +1,19 @@ +# P2P Connector +P2P connector is used for testing the afd implementation for deepseek-v2-lite models. It uses torch.distributed to send/recv intermediate tensors between attn and ffn instances. + +When the --enable-dbo flag is currently enabled, the num_stage parameter becomes ineffective, and the actual number of microbatches is 2. + +Currently, the P2PConnector only supports scenarios where the number of dies of A equals that of F. Asymmetric configurations will be supported in future updates. + +1. Attn + +``` +vllm serve "/path/to/DeepSeek-V2-Lite" --data_parallel_size=2 --enable_expert_parallel --enforce_eager --enable-dbo --dbo-prefill-token-threshold 12 --dbo-decode-token-threshold 2 --afd-config '{"afd_connector":"p2pconnector", "afd_role": "attention", "afd_host":"127.0.0.1", "afd_port":"29500","num_afd_stages":"2","afd_extra_config":{"afd_size":"2A2F"}}' + +``` + +2. 
FFN + +``` +vllm fserver "/path/to/DeepSeek-V2-Lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --afd-config '{"afd_connector":"p2pconnector", "num_afd_stages":"2", "afd_role": "ffn", "afd_host":"127.0.0.1", "afd_port":"29500", "afd_extra_config":{"afd_size":"2A2F"}}' +``` \ No newline at end of file diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 0e91dd57420a8..719f78880e211 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.config.attention import AttentionConfig +from vllm.config.afd import AFDConfig from vllm.config.cache import CacheConfig from vllm.config.compilation import ( CompilationConfig, @@ -65,6 +66,8 @@ __all__ = [ "KVEventsConfig", # From vllm.config.kv_transfer "KVTransferConfig", + # AFD (Attention FFN Disaggregation) configuration + "AFDConfig", # From vllm.config.load "LoadConfig", # From vllm.config.lora diff --git a/vllm/config/afd.py b/vllm/config/afd.py new file mode 100644 index 0000000000000..7277b568f72bd --- /dev/null +++ b/vllm/config/afd.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from dataclasses import field +from typing import Any, Literal + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +@config +@dataclass +class AFDConfig: + """Configuration for AFD (Attention FFN Disaggregation) distributed + computation.""" + + afd_connector: str = "dummy" + """The AFD connector for vLLM to communicate between attention and FFN + nodes. Available connectors: 'dummy', 'p2pconnector'""" + + afd_role: Literal["attention", "ffn"] = "attention" + """Role of this vLLM instance in AFD. 
'attention' for attention workers, + 'ffn' for FFN servers.""" + + afd_port: int = 1239 + """Port number for stepmesh parameter server communication.""" + + afd_host: str = "127.0.0.1" + """Host address for stepmesh parameter server communication.""" + + num_afd_stages: int = 3 + """Number of pipeline stages for stage parallelism.""" + + num_attention_servers: int = 1 + """Number of attention servers.""" + + num_ffn_servers: int = 1 + """Number of FFN servers.""" + + afd_server_rank: int = 0 + """Rank of this AFD server.""" + + afd_extra_config: dict[str, Any] = field(default_factory=dict) + """Extra configuration for specific AFD connectors.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + # AFD configuration affects the computation graph structure + # as it changes how FFN computation is performed + factors: list[Any] = [ + self.afd_connector, + self.afd_role, + self.num_afd_stages, + self.num_attention_servers, + self.num_ffn_servers, + ] + return hashlib.sha256(str(factors).encode()).hexdigest() + + @property + def is_attention_server(self) -> bool: + """Check if this instance is configured as an attention server.""" + return self.afd_role == "attention" + + @property + def is_ffn_server(self) -> bool: + """Check if this instance is configured as an FFN server.""" + return self.afd_role == "ffn" diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 0439dc52e7e6f..610473e9a8e6d 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -28,6 +28,7 @@ from vllm.utils import random_uuid from vllm.utils.hashing import safe_hash from .attention import AttentionConfig +from .afd import AFDConfig from .cache import CacheConfig from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode from .device import DeviceConfig @@ -227,6 +228,8 @@ class VllmConfig: """The configurations for event publishing.""" ec_transfer_config: ECTransferConfig | None = None """The configurations for distributed EC cache transfer.""" + afd_config: AFDConfig | None = None + """AFD (Attention FFN Disaggregation) configuration.""" # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. 
@@ -318,6 +321,10 @@ class VllmConfig: vllm_factors.append(self.ec_transfer_config.compute_hash()) else: vllm_factors.append("None") + if self.afd_config: + vllm_factors.append(self.afd_config.compute_hash()) + else: + vllm_factors.append("None") if self.additional_config: if isinstance(additional_config := self.additional_config, dict): additional_config_hash = safe_hash( @@ -872,16 +879,17 @@ class VllmConfig: if self.parallel_config.use_ubatching: a2a_backend = self.parallel_config.all2all_backend - assert a2a_backend in [ - "deepep_low_latency", - "deepep_high_throughput", - ], ( - "Microbatching currently only supports the deepep_low_latency and " - f"deepep_high_throughput all2all backend. {a2a_backend} is not " - "supported. To fix use --all2all-backend=deepep_low_latency or " - "--all2all-backend=deepep_high_throughput and install the DeepEP" - " kernels." - ) + if self.afd_config is None: + assert a2a_backend in [ + "deepep_low_latency", + "deepep_high_throughput", + ], ( + "Microbatching currently only supports the deepep_low_latency and " + f"deepep_high_throughput all2all backend. {a2a_backend} is not " + "supported. To fix use --all2all-backend=deepep_low_latency or " + "--all2all-backend=deepep_high_throughput and install the DeepEP" + " kernels." + ) if not self.model_config.disable_cascade_attn: self.model_config.disable_cascade_attn = True diff --git a/vllm/distributed/afd_transfer/__init__.py b/vllm/distributed/afd_transfer/__init__.py new file mode 100644 index 0000000000000..fc90a145a44fa --- /dev/null +++ b/vllm/distributed/afd_transfer/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""AFD (Attention-FFN Disaggregation) transfer components. + +This module provides the distributed infrastructure for AFD, enabling +disaggregated FFN computation across different machines while keeping +attention computation local. 
+""" + +from .afd_connector import AFDConnectorBase, AFDConnectorFactory, AFDConnectorMetadata + +__all__ = ["AFDConnectorBase", "AFDConnectorMetadata", "AFDConnectorFactory"] diff --git a/vllm/distributed/afd_transfer/afd_connector/__init__.py b/vllm/distributed/afd_transfer/afd_connector/__init__.py new file mode 100644 index 0000000000000..56f6e91160af0 --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""AFD connector implementations for different transport backends.""" + +from .base import AFDConnectorBase +from .factory import AFDConnectorFactory +from .metadata import AFDConnectorMetadata + +__all__ = [ + "AFDConnectorBase", + "AFDConnectorFactory", + "AFDConnectorMetadata", +] diff --git a/vllm/distributed/afd_transfer/afd_connector/base.py b/vllm/distributed/afd_transfer/afd_connector/base.py new file mode 100644 index 0000000000000..97ebaeacaf86e --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/base.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +AFDConnectorBase Class for Distributed AFD FFN computation + +The class provides the four core AFD communication interfaces: +1. send_attn_output(): Send attention output to FFN servers (Attention Worker) +2. recv_ffn_output(): Receive FFN computation result (Attention Worker) +3. recv_attn_output(): Receive attention output from workers (FFN Server) +4. send_ffn_output(): Send FFN computation result back (FFN Server) +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +import torch + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + from .metadata import AFDConnectorMetadata + + +class AFDConnectorBase(ABC): + """ + Abstract base class for AFD connectors. 
+ + This provides the four core interfaces for AFD communication between + attention workers and FFN servers. + """ + + @abstractmethod + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + """Initialize the AFD connector. + + Args: + rank: Global rank of this process + local_rank: Local rank within the node + config: VllmConfig containing AFDConfig + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the connector and release resources.""" + raise NotImplementedError + + @abstractmethod + def init_afd_connector(self) -> None: + """Initialize the AFD connector.""" + raise NotImplementedError + + @property + @abstractmethod + def is_initialized(self) -> bool: + """Check if the connector is initialized and ready to use. + + Returns: + bool: True if the connector is initialized, False otherwise. + """ + raise NotImplementedError + + def get_connector_rank(self) -> int: + """Get the rank of this connector.""" + return getattr(self, "rank", 0) + + def get_connector_local_rank(self) -> int: + """Get the local rank of this connector.""" + return getattr(self, "local_rank", 0) + + @abstractmethod + def send_attn_output( + self, + hidden_states: torch.Tensor, + metadata: "AFDConnectorMetadata", + ) -> Any: + """Send attention output to FFN servers. + + Args: + hidden_states: Attention output tensor + metadata: AFD metadata containing layer_idx, stage_idx, seq_len info + + Returns: + Any: Handle for tracking this request (backend-specific) + """ + raise NotImplementedError + + @abstractmethod + def recv_ffn_output( + self, + handle: Any, + ) -> torch.Tensor: + """Wait for and receive FFN computation result. 
+ + Args: + handle: Handle returned by send_attn_output() + + Returns: + torch.Tensor: FFN computation result + """ + raise NotImplementedError + + @abstractmethod + def recv_attn_output( + self, + timeout_ms: int | None = None, + ) -> tuple[torch.Tensor, "AFDConnectorMetadata"]: + """Receive attention output from attention workers. + + Args: + timeout_ms: Optional timeout in milliseconds + + Returns: + tuple: (hidden_states, metadata) + - hidden_states: Concatenated attention outputs + - metadata: Inferred AFD metadata containing + seq_lens and other info + """ + raise NotImplementedError + + @abstractmethod + def send_ffn_output( + self, + ffn_output: torch.Tensor, + metadata: "AFDConnectorMetadata", + ) -> None: + """Send FFN computation result back to attention workers. + + Args: + ffn_output: Computed FFN result + metadata: AFD metadata containing seq_lens + for splitting and routing info + """ + raise NotImplementedError diff --git a/vllm/distributed/afd_transfer/afd_connector/dummy_connector.py b/vllm/distributed/afd_transfer/afd_connector/dummy_connector.py new file mode 100644 index 0000000000000..c462646f9d0e9 --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/dummy_connector.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Dummy AFD Connector for testing and local development. + +This connector provides a no-op AFDConnectorBase interface, +useful for testing and development scenarios where actual +distributed FFN computation is not needed. +""" + +import time +from collections import deque +from typing import TYPE_CHECKING, Any + +import torch + +from vllm.logger import init_logger + +from .base import AFDConnectorBase +from .metadata import AFDConnectorMetadata + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class DummyAFDConnector(AFDConnectorBase): + """Dummy AFD connector that returns zero tensors. 
+ + This connector is useful for: + 1. Testing AFD infrastructure without actual remote computation + 2. Development scenarios where FFN computation should be disabled + 3. Fallback behavior when remote FFN servers are unavailable + """ + + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + """Initialize the dummy AFD connector. + + Args: + rank: Global rank of this process + local_rank: Local rank within the node + config: VllmConfig containing AFDConfig + """ + self.afd_config = config.afd_config + self.rank = rank + self.local_rank = local_rank + self._is_initialized = False + self.hidden_size = config.model_config.hf_config.hidden_size + self.num_stages = config.afd_config.num_afd_stages + + self.events: deque = deque(maxlen=self.num_stages) + + logger.info("DummyAFDConnector initialized for rank %s", rank) + + self.init_afd_connector() + + def init_afd_connector(self) -> None: + """Initialize the dummy connector. + + This is a no-op for the dummy connector. + """ + if self._is_initialized: + return + + logger.info("Initializing DummyAFDConnector (no-op)") + self._is_initialized = True + + def close(self) -> None: + """Close the dummy connector. + + This is a no-op for the dummy connector. + """ + if not self._is_initialized: + return + + logger.info("Closing DummyAFDConnector (no-op)") + self._is_initialized = False + + @property + def is_initialized(self) -> bool: + """Check if the connector is initialized. + + Returns: + bool: True if initialized, False otherwise + """ + return self._is_initialized + + def send_attn_output( + self, + hidden_states: torch.Tensor, + metadata: AFDConnectorMetadata, + ) -> Any: + """ + Send attention output to FFN servers (dummy implementation). 
+ """ + logger.debug( + "DummyAFDConnector: send_attn_output layer=%s, stage=%s", + metadata.layer_idx, + metadata.stage_idx, + ) + + # Validate metadata consistency + if not metadata.validate_tensor_shape(hidden_states.shape): + raise ValueError( + "Tensor shape %s doesn't match metadata %s", + hidden_states.shape, + metadata, + ) + + if not metadata.is_single_sequence: + raise ValueError("Attention side should have single sequence") + + self.events.append((None, metadata)) + + return None + + def recv_ffn_output( + self, + timeout_ms: float | None = None, + ) -> torch.Tensor: + """Receive FFN computation result (dummy implementation).""" + logger.debug("DummyAFDConnector: recv_ffn_output timeout_ms=%s", timeout_ms) + + _, metadata = self.events.popleft() + seq_len = metadata.seq_lens[0] # Single sequence for attention side + return torch.zeros( + seq_len, + self.hidden_size, + dtype=metadata.dtype, + device=metadata.device, + ) + + def recv_attn_output( + self, + timeout_ms: int | None = None, + ) -> tuple[torch.Tensor, AFDConnectorMetadata]: + """ + Receive attention output from attention workers (dummy implementation). 
+ """ + logger.debug("DummyAFDConnector: recv_attn_output timeout_ms=%s", timeout_ms) + + # Generate dummy data that simulates multiple attention workers + dummy_seq_lens = [ + 2, + 2, + 2, + ] # Variable sequence lengths from different workers + total_tokens = sum(dummy_seq_lens) + + dummy_tensor = torch.zeros( + total_tokens, self.hidden_size, dtype=torch.bfloat16, device="cuda" + ) + + # Create dummy metadata + dummy_metadata = AFDConnectorMetadata.create_ffn_metadata( + layer_idx=0, # Dummy layer + stage_idx=0, # Dummy stage + dtype=torch.bfloat16, + device=torch.device("cuda"), + seq_lens=dummy_seq_lens, + request_id=f"dummy_ffn_batch_{time.time()}", + ) + + # Cache metadata for send_ffn_output + self._current_metadata = dummy_metadata + time.sleep(1) + + return dummy_tensor, dummy_metadata + + def send_ffn_output( + self, + ffn_output: torch.Tensor, + metadata: AFDConnectorMetadata, + ) -> None: + """Send FFN computation result back (dummy implementation).""" + logger.debug( + "DummyAFDConnector: send_ffn_output layer=%s, stage=%s", + metadata.layer_idx, + metadata.stage_idx, + ) + + # Validate that ffn_output shape matches metadata + if not metadata.validate_tensor_shape(ffn_output.shape): + logger.warning( + "FFN output shape %s doesn't match metadata %s", + ffn_output.shape, + metadata, + ) + + # Log the splitting information for debugging + logger.debug( + "DummyAFDConnector: Split FFN output into %s parts with lengths %s", + metadata.num_sequences, + metadata.seq_lens, + ) + + # Simulate splitting (for logging purposes) + if metadata.get_split_indices(): + split_outputs = torch.split(ffn_output, metadata.seq_lens, dim=0) + logger.debug( + "DummyAFDConnector: Split shapes: %s", + [s.shape for s in split_outputs], + ) + + time.sleep(1) + # No-op for dummy connector - just log the operation diff --git a/vllm/distributed/afd_transfer/afd_connector/factory.py b/vllm/distributed/afd_transfer/afd_connector/factory.py new file mode 100644 index 
0000000000000..757c0ed9fc661 --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/factory.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Factory for creating AFD connectors based on configuration.""" + +import importlib +from collections.abc import Callable +from typing import TYPE_CHECKING + +from vllm.logger import init_logger + +from .base import AFDConnectorBase + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class AFDConnectorFactory: + _registry: dict[str, Callable[[], type[AFDConnectorBase]]] = {} + + @classmethod + def register_connector(cls, name: str, module_path: str, class_name: str) -> None: + """Register a connector with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> type[AFDConnectorBase]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_connector( + cls, rank: int, local_rank: int, config: "VllmConfig" + ) -> AFDConnectorBase: + """Create an AFD connector based on the configuration. 
+ + Args: + rank: Global rank of this process + local_rank: Local rank within the node + config: VllmConfig containing AFDConfig + + Returns: + AFDConnectorBase: The created connector instance + + Raises: + ValueError: If the transport backend is not supported + ImportError: If required dependencies are not available + """ + afd_config = config.afd_config + connector_name = afd_config.afd_connector + + if connector_name not in cls._registry: + raise ValueError(f"Unsupported connector type: {connector_name}") + + connector_cls = cls._registry[connector_name]() + assert issubclass(connector_cls, AFDConnectorBase) + return connector_cls(rank, local_rank, config) + + @classmethod + def get_connector_class(cls, connector_name: str) -> type[AFDConnectorBase]: + """Get the connector class for a given connector name. + + Args: + connector_name: The connector name + + Returns: + type[AFDConnectorBase]: The connector class + + Raises: + ValueError: If the connector name is not supported + """ + if connector_name not in cls._registry: + raise ValueError(f"Unsupported connector type: {connector_name}") + + return cls._registry[connector_name]() + + +# Register various connectors here. +# The registration should not be done in each individual file, as we want to +# only load the files corresponding to the current connector. 
+ +AFDConnectorFactory.register_connector( + "dummy", + "vllm.distributed.afd_transfer.afd_connector.dummy_connector", + "DummyAFDConnector", +) + +AFDConnectorFactory.register_connector( + "p2pconnector", + "vllm.distributed.afd_transfer.afd_connector.p2p_connector", + "P2PAFDConnector", +) diff --git a/vllm/distributed/afd_transfer/afd_connector/metadata.py b/vllm/distributed/afd_transfer/afd_connector/metadata.py new file mode 100644 index 0000000000000..af259ebd6691c --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/metadata.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""AFD metadata definitions for communication between attention and +FFN workers.""" + +import time +import typing +from dataclasses import dataclass, field +from typing import Any + +import torch + + +class FFNNeedForwardData: + def __init__( + self, + moe_comm_method: typing.Any, + num_input_tokens: int, + with_prefill: bool, + total_num_scheduled_tokens: int | None, + is_dummy_run: bool = False, + ): + self.moe_comm_method = moe_comm_method + self.num_input_tokens = num_input_tokens + self.with_prefill = with_prefill + self.total_num_scheduled_tokens = total_num_scheduled_tokens + self.is_dummy_run = is_dummy_run + + +@dataclass +class AFDConnectorMetadata: + """Lightweight AFD metadata containing core information needed for + communication.""" + + layer_idx: int + stage_idx: int + seq_lens: list[int] # Length of each sequence, supports variable length and + # multiple sequences + dtype: torch.dtype + device: torch.device + topk_idx: torch.Tensor | None = None # indices token which expert to be sended + topk_weights: torch.Tensor | None = None # the expert weights + moe_expert_num: int | None = None # number of moe experts + shared_expert_num: int | None = None # number of share experts + scale: torch.Tensor | None = None # quant scale + expertTokenNumsOut: torch.Tensor | None = ( + None # The 
number of tokens received by each expert is used as input for the subsequent GMM. + ) + recv_handle_list: list[Any] | None = ( + None # the communication handles (list of Work objects returned by torch.distributed.irecv) + ) + + # Optional fields for debugging and extensibility + request_id: str | None = None + timestamp: float | None = None + """ffn need forward data""" + ffn_need_forward_data: FFNNeedForwardData | None = None + num_of_stages: int = 1 + afd_tokens_lens: list = field(default_factory=list) + + def __post_init__(self): + """Validate data consistency.""" + if not self.seq_lens: + raise ValueError("seq_lens cannot be empty") + if any(length <= 0 for length in self.seq_lens): + raise ValueError("All sequence lengths must be positive") + + @property + def total_tokens(self) -> int: + """Total number of tokens.""" + return sum(self.seq_lens) + + @property + def num_sequences(self) -> int: + """Number of sequences.""" + return len(self.seq_lens) + + @property + def is_single_sequence(self) -> bool: + """Whether this is a single sequence (attention side characteristic).""" + return len(self.seq_lens) == 1 + + @property + def is_multi_sequence(self) -> bool: + """Whether this is multiple sequences (FFN side characteristic).""" + return len(self.seq_lens) > 1 + + @classmethod + def create_attention_metadata( + cls, + layer_idx: int, + stage_idx: int, + seq_len: int, + dtype: torch.dtype, + device: torch.device, + request_id: str | None = None, + ffn_need_forward_data: FFNNeedForwardData | None = None, + num_of_stages: int = 1, + afd_tokens_lens: list[int] = [], + ) -> "AFDConnectorMetadata": + """Create metadata for attention side (single sequence).""" + return cls( + layer_idx=layer_idx, + stage_idx=stage_idx, + seq_lens=[seq_len], + dtype=dtype, + device=device, + request_id=request_id, + ffn_need_forward_data=ffn_need_forward_data, + timestamp=time.time(), + num_of_stages=num_of_stages, + afd_tokens_lens=afd_tokens_lens, + ) + + @classmethod + def 
create_ffn_metadata( + cls, + layer_idx: int, + stage_idx: int, + seq_lens: list[int], + dtype: torch.dtype, + device: torch.device, + request_id: str | None = None, + ) -> "AFDConnectorMetadata": + """Create metadata for FFN side (multiple sequences).""" + return cls( + layer_idx=layer_idx, + stage_idx=stage_idx, + seq_lens=seq_lens.copy(), # Prevent external modification + dtype=dtype, + device=device, + request_id=request_id, + timestamp=time.time(), + ) + + def get_split_indices(self) -> list[int]: + """Get tensor split indices for FFN side output splitting.""" + if len(self.seq_lens) <= 1: + return [] + + indices = [] + cumsum = 0 + for length in self.seq_lens[:-1]: # Exclude the last one + cumsum += length + indices.append(cumsum) + return indices + + def validate_tensor_shape(self, tensor_shape: tuple[int, ...]) -> bool: + """Validate if tensor shape is consistent with metadata.""" + if len(tensor_shape) < 1: + return False + return tensor_shape[0] == self.total_tokens + + def to_dict(self) -> dict: + """Convert to dictionary format for serialization and debugging.""" + return { + "layer_idx": self.layer_idx, + "stage_idx": self.stage_idx, + "seq_lens": self.seq_lens, + "dtype": self.dtype, + "device": self.device, + "total_tokens": self.total_tokens, + "num_sequences": self.num_sequences, + "request_id": self.request_id, + "timestamp": self.timestamp, + } + + def __repr__(self) -> str: + """Friendly string representation.""" + return ( + f"AFDConnectorMetadata(layer={self.layer_idx}, " + f"stage={self.stage_idx}, seq_lens={self.seq_lens}, " + f"total_tokens={self.total_tokens}, dtype={self.dtype}, " + f"device={self.device}, request_id={self.request_id})" + ) diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py new file mode 100644 index 0000000000000..d9d9730b55365 --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -0,0 +1,304 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import re +from datetime import timedelta + +import torch +from torch.distributed.distributed_c10d import _get_default_group, _update_default_pg + +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import ( + GroupCoordinator, + TensorMetadata, + init_afd_process_group, + init_model_parallel_group, +) +from vllm.logger import init_logger + +from .base import AFDConnectorBase +from .metadata import AFDConnectorMetadata + +logger = init_logger(__name__) + + +class DefaultProcessGroupSwitcher: + def __init__(self, default_group, new_default_group): + self.default_group = default_group + self.new_default_group = new_default_group + + def __enter__(self): + _update_default_pg(self.new_default_group) + + def __exit__(self, exc_type, exc_value, traceback): + _update_default_pg(self.default_group) + + +class P2PAFDConnector(AFDConnectorBase): + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ) -> None: + self.rank = rank + self.local_rank = local_rank + self.config = config + self._initialized: bool = False + self._need_recv_metadata: bool = True + self._tensor_metadata_list: dict[int, TensorMetadata] = {} + self._current_afd_connector_metadata: AFDConnectorMetadata | None = None + self.num_hidden_layers: int = ( + self.config.model_config.hf_config.num_hidden_layers + ) + self.recv_attn_output_counter: int = 0 + self.recv_ffn_output_counter: int = 0 + + def close(self) -> None: + """Close the connector and release resources.""" + # TODO: Implement proper resource clean up if needed. 
+ pass + + def init_afd_connector(self) -> None: + """Initialize the AFD connector.""" + afd_size = self.config.afd_config.afd_extra_config.get("afd_size") + role = self.config.afd_config.afd_role + attn_size, ffn_size = map(int, re.match(r"(\d+)\D+(\d+)", afd_size).groups()) + world_rank = self.rank if role == "attention" else self.rank + attn_size + afd_pg = init_afd_process_group( + backend="nccl", + init_method=( + f"tcp://{self.config.afd_config.afd_host}" + f":{self.config.afd_config.afd_port}" + ), + world_size=ffn_size + attn_size, + rank=world_rank, + group_name="afd", + timeout=timedelta(minutes=2), + ) + + # Construct rank lists for sub groups. + # Each group contains one attention and one ffn rank. + ffn_ranks = [i for i in range(ffn_size, ffn_size + attn_size)] + attn_ranks = [i for i in range(attn_size)] + assert len(ffn_ranks) == len(attn_ranks), ( + "ffn_ranks and attn_ranks must have the same length" + ) + default_pg_switcher = DefaultProcessGroupSwitcher(_get_default_group(), afd_pg) + with default_pg_switcher: + sub_group_ranks = [] + for i in range(len(ffn_ranks)): + ranks = [attn_ranks[i], ffn_ranks[i]] + sub_group_ranks.append(ranks) + # Create two independent groups: + # a2e_group: for attention -> expert/ffn communication (send_attn, recv_attn) + # e2a_group: for expert/ffn -> attention communication (send_ffn, recv_ffn) + # The communication domain (rank range) is the same, but different group_name + # creates independent groups. + self.a2e_group = init_model_parallel_group( + sub_group_ranks, + self.local_rank, + backend="nccl", + group_name="a2e", + ) + self.e2a_group = init_model_parallel_group( + sub_group_ranks, + self.local_rank, + backend="nccl", + group_name="e2a", + ) + + self._initialized = True + + def is_initialized(self) -> bool: + """Check if the connector is initialized and ready to use. + + Returns: + bool: True if the connector is initialized, False otherwise. 
+ """ + return self._initialized + + def _build_tensor_metadata_list( + self, + tensor_metadata: TensorMetadata, + connector_metadata: AFDConnectorMetadata, + ) -> dict[int, TensorMetadata]: + tensor_metadata_list = {} + num_of_stages = connector_metadata.num_of_stages + for idx in range(num_of_stages): + if idx == 0: + tensor_metadata_list[0] = tensor_metadata + else: + new_size = list(tensor_metadata.size) + new_size[0] = connector_metadata.afd_tokens_lens[idx] + tensor_metadata_list[idx] = TensorMetadata( + tensor_metadata.device, + tensor_metadata.dtype, + torch.Size(new_size), + ) + return tensor_metadata_list + + def _send_metadata( + self, + metadata: AFDConnectorMetadata, + hidden_states: torch.Tensor, + dst: int, + process_group: GroupCoordinator, + ) -> None: + if not torch.distributed.is_initialized() or process_group.world_size == 1: + return [] + assert dst < process_group.world_size, f"Invalid dst rank ({dst})" + + tensor_metadata = TensorMetadata( + hidden_states.device.type, hidden_states.dtype, hidden_states.size() + ) + metadata_tuple = (metadata, tensor_metadata) + process_group.send_object(metadata_tuple, dst=dst) + self._tensor_metadata_list = self._build_tensor_metadata_list( + tensor_metadata, metadata + ) + + def _recv_metadata( + self, + src: int, + process_group: GroupCoordinator, + ) -> None: + (self._current_afd_connector_metadata, tensor_metadata) = ( + process_group.recv_object(src=src) + ) + self._tensor_metadata_list = self._build_tensor_metadata_list( + tensor_metadata, self._current_afd_connector_metadata + ) + + def _send_hidden_states( + self, + hidden_states: torch.Tensor, + dst: int, + process_group: GroupCoordinator, + ) -> None: + if not torch.distributed.is_initialized() or process_group.world_size == 1: + return [] + assert dst < process_group.world_size, f"Invalid dst rank ({dst})" + assert not hidden_states.is_cpu, "Hidden states must be on GPU" + torch.distributed.send( + hidden_states, + dst=process_group.ranks[dst], + 
group=process_group.device_group, + ) + + def _recv_hidden_states( + self, + src: int, + process_group: GroupCoordinator, + tensor_metadata: TensorMetadata, + ) -> tuple[torch.Tensor, list]: + if not torch.distributed.is_initialized() or process_group.world_size == 1: + return {}, [] + assert src < process_group.world_size, f"Invalid src rank ({src})" + + hidden_states = torch.empty( + tensor_metadata.size, + dtype=tensor_metadata.dtype, + device=tensor_metadata.device, + ) + torch.distributed.recv( + hidden_states, + src=process_group.ranks[src], + group=process_group.device_group, + ) + return hidden_states, [] + + # ------------------------------------------------------------------------- + # attn -> ffn + # ------------------------------------------------------------------------- + + def send_attn_output( + self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata + ) -> None: + """ + Called by ATTN side to send intermediate tensors + generated by ATTN instances to FFN. + """ + try: + dst = (self.a2e_group.rank_in_group + 1) % self.a2e_group.world_size + if metadata.layer_idx == 0 and metadata.stage_idx == 0: + self._send_metadata(metadata, hidden_states, dst, self.a2e_group) + self._current_afd_connector_metadata = metadata + self._send_hidden_states(hidden_states, dst, self.a2e_group) + except Exception as e: + raise RuntimeError(f"Communication error: {e}") + + def recv_ffn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: + """ + Called by the ATTN side to receive MOE output intermediate tensors, + possibly dispatching from the receiver to other GPUs. 
+ """ + src = (self.e2a_group.rank_in_group - 1) % self.e2a_group.world_size + stage_idx = ( + self.recv_ffn_output_counter + % self._current_afd_connector_metadata.num_of_stages + ) + hidden_states, work_list = self._recv_hidden_states( + src, + self.e2a_group, + self._tensor_metadata_list[stage_idx], + ) + self._current_afd_connector_metadata.recv_handle_list = work_list + self.recv_ffn_output_counter = ( + self.recv_ffn_output_counter + 1 + ) % self._current_afd_connector_metadata.num_of_stages + return hidden_states, self._current_afd_connector_metadata + + # ------------------------------------------------------------------------- + # ffn -> attn + # ------------------------------------------------------------------------- + + def send_ffn_output( + self, + hidden_states: torch.Tensor, + metadata: AFDConnectorMetadata, + ) -> None: + """ + Called by FFN side to send intermediate tensors generated by FFN + instances back to the sender (should be the same GPU as source). + """ + dst = (self.e2a_group.rank_in_group + 1) % self.e2a_group.world_size + self._send_hidden_states(hidden_states, dst, self.e2a_group) + self.recv_attn_output_counter += 1 + if ( + self.recv_attn_output_counter + % ( + self._current_afd_connector_metadata.num_of_stages + * self.num_hidden_layers + ) + == 0 + ): + self._need_recv_metadata = True + self.recv_attn_output_counter = 0 + + def recv_attn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: + """ + Called by the FFN side to receive intermediate tensors from ATTN. + Handles receiving and possibly dispatching tensors. 
+ """ + src = (self.a2e_group.rank_in_group - 1) % self.a2e_group.world_size + if self._need_recv_metadata: + self._recv_metadata(src, self.a2e_group) + self._need_recv_metadata = False + + stage_idx = ( + self.recv_attn_output_counter + % self._current_afd_connector_metadata.num_of_stages + ) + layer_idx = ( + self.recv_attn_output_counter + // self._current_afd_connector_metadata.num_of_stages + ) + hidden_states, work_list = self._recv_hidden_states( + src, + self.a2e_group, + self._tensor_metadata_list[stage_idx], + ) + self._current_afd_connector_metadata.recv_handle_list = work_list + self._current_afd_connector_metadata.layer_idx = layer_idx + return hidden_states, self._current_afd_connector_metadata diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 338cb1f1814b5..3bda4f7b548a9 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -41,6 +41,14 @@ import torch.distributed import torch.distributed._functional_collectives as funcol import torch.distributed._symmetric_memory from torch.distributed import Backend, ProcessGroup +from torch.distributed.distributed_c10d import ( + PrefixStore, + Store, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, +) import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( @@ -322,7 +330,6 @@ class GroupCoordinator: self_device_group = None self_cpu_group = None - for ranks in group_ranks: device_group = torch.distributed.new_group( ranks, backend=torch_distributed_backend @@ -1029,6 +1036,63 @@ _INNER_DP_WORLD: GroupCoordinator | None = None _NODE_COUNT: int | None = None +def init_afd_process_group( + backend: str | Backend = None, + init_method: str | None = None, + timeout: timedelta | None = None, + world_size: int = -1, + rank: int = -1, + store: Store | None = None, + group_name: str = None, + pg_options: Any | None = None, +): + assert (store is None) or (init_method is 
None), (
+        "Cannot specify both init_method and store."
+    )
+
+    if store is not None:
+        assert world_size > 0, "world_size must be positive if using store"
+        assert rank >= 0, "rank must be non-negative if using store"
+    elif init_method is None:
+        init_method = "env://"
+
+    if backend:
+        backend = Backend(backend)
+    else:
+        backend = Backend("undefined")
+
+    if timeout is None:
+        timeout = default_pg_timeout
+
+    if store is None:
+        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
+        store, rank, world_size = next(rendezvous_iterator)
+        store.set_timeout(timeout)
+        store = PrefixStore(group_name, store)
+
+    # torch renamed the `pg_options` kwarg to `backend_options` in 2.6.
+    # NOTE: compare against a numeric tuple, not a string — a lexicographic
+    # string compare misorders versions such as "2.10" < "2.6".
+    pg_options_param_name = (
+        "backend_options" if torch.__version__ >= (2, 6) else "pg_options"
+    )
+    pg, _ = _new_process_group_helper(
+        world_size,
+        rank,
+        [],
+        backend,
+        store,
+        group_name=group_name,
+        **{pg_options_param_name: pg_options},
+        timeout=timeout,
+    )
+    global _AFD
+    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
+    _AFD = pg
+    return pg
+
+
+_WORLD: GroupCoordinator | None = None
+_NODE_COUNT: int | None = None
+
+
 def get_world_group() -> GroupCoordinator:
     assert _WORLD is not None, "world group is not initialized"
     return _WORLD
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index ca19e468914c7..f911ba8d0ca8e 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -35,6 +35,7 @@ import vllm.envs as envs
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import (
     AttentionConfig,
+    AFDConfig,
     CacheConfig,
     CompilationConfig,
     ConfigType,
@@ -73,7 +74,7 @@ from vllm.config.model import (
     RunnerOption,
     TokenizerMode,
 )
-from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
+from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
 from vllm.config.observability import DetailedTraceModules
 from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
 from vllm.config.scheduler
import SchedulerPolicy @@ -573,6 +574,8 @@ class EngineArgs: CacheConfig.kv_offloading_backend ) tokens_only: bool = False + # AFD config + afd_config: AFDConfig | None = None def __post_init__(self): # support `EngineArgs(compilation_config={...})` @@ -1148,6 +1151,8 @@ class EngineArgs: "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"] ) vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"]) + vllm_group.add_argument("--afd-config", **vllm_kwargs["afd_config"]) + vllm_group.add_argument( "--optimization-level", **vllm_kwargs["optimization_level"] ) @@ -1721,6 +1726,7 @@ class EngineArgs: profiler_config=self.profiler_config, additional_config=self.additional_config, optimization_level=self.optimization_level, + afd_config=self.afd_config, ) return config diff --git a/vllm/entrypoints/afd_ffn_server.py b/vllm/entrypoints/afd_ffn_server.py new file mode 100644 index 0000000000000..27aa6c7ddd344 --- /dev/null +++ b/vllm/entrypoints/afd_ffn_server.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""vLLM AFD FFN Server Entry Point + +This script provides a standalone entry point for running FFN servers in an AFD +(Attention-FFN Disaggregation) setup. FFN servers handle remote FFN computation +for attention workers. 
+ +Usage: + python -m vllm.entrypoints.afd_ffn_server /path/to/model \ + --tensor-parallel-size 8 \ + --afd-config '{"afd_connector": "dummy", "afd_role": "ffn"}' \ +""" + +import threading +from typing import Any + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.logger import init_logger +from vllm.utils.argparse_utils import FlexibleArgumentParser + +logger = init_logger(__name__) + + +class AFDFFNServer: + """AFD FFN Server main class.""" + + def __init__(self, args: Any): + engine_args = AsyncEngineArgs.from_cli_args(args) + self.vllm_config = engine_args.create_engine_config() + logger.info("Start AFD FFN Server with vllm_config: %s", self.vllm_config) + + def start(self) -> None: + """Start the AFD FFN server.""" + try: + # Import here to avoid circular imports + from vllm.v1.executor.abstract import Executor + + # Create configurations + executor_class = Executor.get_class(self.vllm_config) + self.model_executor = executor_class(vllm_config=self.vllm_config) + # Start the FFN server loop + self._run_server_loop() + + except Exception as e: + logger.error("Failed to start AFD FFN server: %s", e) + raise + + def _run_server_loop(self) -> None: + """Start FFN workers and wait for completion""" + logger.info("AFD FFN Server started, workers running...") + try: + # Tell workers to start FFN server loops (one-time call) + self.model_executor.collective_rpc("start_ffn_server_loop") + + # Main thread waits without busy polling + shutdown_event = threading.Event() + shutdown_event.wait() # Block until interrupted + + except KeyboardInterrupt: + logger.info("Server shutting down...") + self.model_executor.collective_rpc("stop_ffn_server_loop") + except Exception as e: + logger.error("Server error: %s", e) + raise + + +def main(args: Any) -> None: + """Main entry point for AFD FFN server.""" + try: + server = AFDFFNServer(args) + server.start() + except KeyboardInterrupt: + logger.info("Interrupted by user") + except Exception as e: + logger.error("Server 
error: %s", e)
+        raise
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser()
+    # Add model as positional argument (like vllm serve)
+    parser.add_argument("model", type=str, help="Model name or path")
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    # The positional "model" argument is already stored on args by argparse;
+    # no extra assignment is needed here.
+
+    main(args)
diff --git a/vllm/entrypoints/cli/fserver.py b/vllm/entrypoints/cli/fserver.py
new file mode 100644
index 0000000000000..087cc9d280d7f
--- /dev/null
+++ b/vllm/entrypoints/cli/fserver.py
@@ -0,0 +1,51 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""vLLM AFD FFN Server CLI command."""
+
+import argparse
+
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.entrypoints.afd_ffn_server import main
+from vllm.entrypoints.cli.types import CLISubcommand
+
+
+class FServerCommand(CLISubcommand):
+    """Command for running vLLM AFD FFN Server."""
+
+    def __init__(self):
+        self.name = "fserver"
+        super().__init__()
+
+    def subparser_init(self, subparsers):
+        """Initialize the fserver subparser."""
+        parser = subparsers.add_parser(
+            self.name,
+            help="Start vLLM AFD FFN Server",
+            description="Start vLLM AFD FFN Server for Attention-FFN Disaggregation",
+            usage="vllm fserver MODEL --afd-config CONFIG [options]",
+        )
+
+        # Add model as positional argument (like vllm serve)
+        parser.add_argument("model", type=str, help="Model name or path")
+
+        # Use AsyncEngineArgs to add all vLLM engine arguments
+        parser = AsyncEngineArgs.add_cli_args(parser)
+
+        return parser
+
+    def validate(self, args: argparse.Namespace) -> None:
+        """Validate arguments for fserver command."""
+        # Validate that afd_config is provided
+        if not hasattr(args, "afd_config") or not args.afd_config:
+            raise ValueError("--afd-config is required for FFN server")
+
+    @staticmethod
+    def cmd(args: argparse.Namespace) -> None:
+        """Run the fserver command."""
+        # Call
the main function from afd_ffn_server directly with parsed args + main(args) + + +def cmd_init() -> list[CLISubcommand]: + """Initialize fserver command.""" + return [FServerCommand()] diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index a3e73eb7a4c9d..65bb51b2fd950 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -16,6 +16,7 @@ logger = init_logger(__name__) def main(): import vllm.entrypoints.cli.benchmark.main import vllm.entrypoints.cli.collect_env + import vllm.entrypoints.cli.fserver import vllm.entrypoints.cli.openai import vllm.entrypoints.cli.run_batch import vllm.entrypoints.cli.serve @@ -25,6 +26,7 @@ def main(): CMD_MODULES = [ vllm.entrypoints.cli.openai, vllm.entrypoints.cli.serve, + vllm.entrypoints.cli.fserver, vllm.entrypoints.cli.benchmark.main, vllm.entrypoints.cli.collect_env, vllm.entrypoints.cli.run_batch, diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 033cc1f544b3b..a715d70d1d5f4 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -12,6 +12,7 @@ import torch import vllm.envs as envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig +from vllm.distributed.afd_transfer import AFDConnectorBase from vllm.logger import init_logger from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.ubatch_utils import UBatchSlices @@ -94,6 +95,46 @@ class DPMetadata: # NOTE: local_sizes should only be set by the chunked_sizes context manager local_sizes: list[int] | None = None + @staticmethod + def num_stage_tokens_across_dp( + num_stage_tokens: list[int], dp_size: int, dp_rank: int + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Gather the stage token counts across all DP ranks. 
+ Args: + num_stage_tokens: list of token counts per stage for current rank + dp_size: number of DP ranks + dp_rank: current DP rank + Returns: + stage_tokens_across_dp_cpu: [num_stages, dp_size] tensor + max_stage_tokens_across_dp_cpu: [num_stages] tensor with max + tokens per stage + """ + import torch.distributed as dist + + from vllm.distributed.parallel_state import get_dp_group + from vllm.platforms import current_platform + + device = current_platform.device_type + group = get_dp_group().device_group + + num_stages = len(num_stage_tokens) + stage_tokens_across_dp = torch.zeros( + (num_stages, dp_size), device=device, dtype=torch.int32 + ) + stage_tokens_across_dp[:, dp_rank] = torch.tensor( + num_stage_tokens, device=device, dtype=torch.int32 + ) + + # AllReduce to gather from all ranks + dist.all_reduce(stage_tokens_across_dp, group=group) + stage_tokens_across_dp_cpu = stage_tokens_across_dp.cpu() + + # Compute max tokens per stage + max_stage_tokens_across_dp_cpu = torch.max(stage_tokens_across_dp_cpu, dim=1)[0] + + return stage_tokens_across_dp_cpu, max_stage_tokens_across_dp_cpu + @staticmethod def make( parallel_config: ParallelConfig, @@ -182,6 +223,16 @@ class DPMetadata: return torch.cumsum(num_tokens_across_sp_cpu, dim=0) +@dataclass +class AFDMetadata: + afd_tokens_start_loc: list[int] + afd_reqs_start_loc: list[int] + afd_stage_idx: int + afd_connector: "AFDConnectorBase" + afd_tokens_lens: list[int] # padded lengths for tensor slicing + num_of_stages: int + + @dataclass class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context @@ -198,6 +249,7 @@ class ForwardContext: virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass dp_metadata: DPMetadata | None = None + afd_metadata: AFDMetadata | None = None # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE. # by default NONE, no cudagraph is used. 
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE @@ -235,6 +287,7 @@ def create_forward_context( cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor: BatchDescriptor | None = None, ubatch_slices: UBatchSlices | None = None, + afd_metadata: AFDMetadata | None = None, ): return ForwardContext( no_compile_layers=vllm_config.compilation_config.static_forward_context, @@ -244,6 +297,7 @@ def create_forward_context( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, + afd_metadata=afd_metadata, ) @@ -272,6 +326,7 @@ def set_forward_context( cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor: BatchDescriptor | None = None, ubatch_slices: UBatchSlices | None = None, + afd_metadata: AFDMetadata | None = None, ): """A context manager that stores the current forward context, can be attention metadata, etc. @@ -317,6 +372,7 @@ def set_forward_context( cudagraph_runtime_mode, batch_descriptor, ubatch_slices, + afd_metadata=afd_metadata, ) try: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 146124153c79d..cd3829b9220c9 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -45,6 +45,7 @@ from vllm.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) +from vllm.distributed.afd_transfer.afd_connector.metadata import AFDConnectorMetadata from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul @@ -101,6 +102,9 @@ elif current_platform.is_xpu(): logger = init_logger(__name__) +from vllm.forward_context import AFDMetadata +from vllm.v1.worker.ubatching import dbo_current_ubatch_id, dbo_enabled, dbo_yield + class DeepseekAttention(nn.Module): """Normal MHA implementation used by Deepseek v1.""" @@ -1113,6 +1117,8 @@ class 
DeepseekV2DecoderLayer(nn.Module): quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config + afd_config = vllm_config.afd_config + self.afd_role = afd_config.afd_role if afd_config is not None else None self.hidden_size = config.hidden_size max_position_embeddings = getattr(config, "max_position_embeddings", 8192) moe_layer_freq = getattr(config, "moe_layer_freq", 1) @@ -1138,42 +1144,47 @@ class DeepseekV2DecoderLayer(nn.Module): attn_cls = DeepseekV2MLAAttention else: attn_cls = DeepseekV2Attention - self.self_attn = attn_cls( - vllm_config=vllm_config, - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=qk_nope_head_dim, - qk_rope_head_dim=qk_rope_head_dim, - v_head_dim=v_head_dim, - q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None, - kv_lora_rank=kv_lora_rank, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - topk_indices_buffer=topk_indices_buffer, - ) - - if ( - config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % moe_layer_freq == 0 - ): - self.mlp = DeepseekV2MoE( + if self.afd_role is None or self.afd_role == "attention": + self.self_attn = attn_cls( + vllm_config=vllm_config, config=config, - parallel_config=parallel_config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") + else None, + kv_lora_rank=kv_lora_rank, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - else: - self.mlp = DeepseekV2MLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - 
quant_config=quant_config, - prefix=f"{prefix}.mlp", + prefix=f"{prefix}.self_attn", + topk_indices_buffer=topk_indices_buffer, ) + + if self.afd_role is None or self.afd_role == "ffn": + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % moe_layer_freq == 0 + ): + self.mlp = DeepseekV2MoE( + config=config, + parallel_config=parallel_config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps @@ -1217,6 +1228,9 @@ class DeepseekV2DecoderLayer(nn.Module): # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + if self.afd_role == "attention": + return hidden_states, residual + hidden_states = self.mlp(hidden_states) if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: @@ -1229,6 +1243,49 @@ class DeepseekV2DecoderLayer(nn.Module): return hidden_states, residual + def compute_attn_output( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> torch.Tensor: # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. 
+ hidden_states *= 1.0 / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1.0 / self.routed_scaling_factor + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + + return hidden_states, residual + + def compute_ffn_output(self, hidden_states): + assert self.afd_role == "ffn" + hidden_states = self.mlp(hidden_states) + if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1.0 / self.routed_scaling_factor + return hidden_states + @support_torch_compile class DeepseekV2Model(nn.Module): @@ -1283,6 +1340,51 @@ class DeepseekV2Model(nn.Module): def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) + def forward_with_afd( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + positions: torch.Tensor, + afd_metadata: AFDMetadata, + llama_4_scaling: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + recv_handle = None + for layer in islice(self.layers, self.start_layer, self.end_layer): + afd_connector = afd_metadata.afd_connector + afd_metadata.afd_stage_idx = dbo_current_ubatch_id() + + if layer.layer_idx > 0: + hidden_states, recv_metadata = afd_connector.recv_ffn_output() + if recv_metadata.recv_handle_list is not None: + recv_handle = recv_metadata.recv_handle_list + + if recv_handle is not None: + for work in recv_handle: + work.wait() + current_hidden, residual = layer(positions, hidden_states, residual, llama_4_scaling) + metadata = AFDConnectorMetadata.create_attention_metadata( + layer_idx=layer.layer_idx, + stage_idx=afd_metadata.afd_stage_idx, + seq_len=current_hidden.shape[0], + 
dtype=current_hidden.dtype, + device=current_hidden.device, + num_of_stages=afd_metadata.num_of_stages, + afd_tokens_lens=afd_metadata.afd_tokens_lens, + ) + afd_connector.send_attn_output(current_hidden, metadata) + + if dbo_enabled(): + dbo_yield() + + hidden_states, recv_metadata = afd_connector.recv_ffn_output() + if recv_metadata.recv_handle_list is not None: + recv_handle = recv_metadata.recv_handle_list + if recv_handle is not None: + for work in recv_handle: + work.wait() + + return hidden_states, residual + def forward( self, input_ids: torch.Tensor, @@ -1315,10 +1417,16 @@ class DeepseekV2Model(nn.Module): else: llama_4_scaling = None - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer( - positions, hidden_states, residual, llama_4_scaling + forward_ctx = get_forward_context() + afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None + + if afd_metadata != None: + hidden_states, residual = self.forward_with_afd( + hidden_states, residual, positions, afd_metadata, llama_4_scaling ) + else: + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer(positions, hidden_states, residual, llama_4_scaling) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -1328,6 +1436,12 @@ class DeepseekV2Model(nn.Module): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def compute_ffn_output( + self, hidden_states, layer_idx + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.layers[layer_idx].compute_ffn_output(hidden_states) + return hidden_states + class DeepseekV2MixtureOfExperts(MixtureOfExperts): moe_mlp_layers: list[DeepseekV2MoE] @@ -1394,6 +1508,10 @@ class DeepseekV2ForCausalLM( if self.use_mha: self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"] + self.afd_config = vllm_config.afd_config + self.afd_role = ( + self.afd_config.afd_role if self.afd_config is not None else 
None + ) # `packed_modules_mapping` needs to be modified before # initializing DeepseekV2Model, as it is passed inplace to # quantization config init and may be used to select the @@ -1442,12 +1560,16 @@ class DeepseekV2ForCausalLM( continue assert isinstance(layer, DeepseekV2DecoderLayer) - if isinstance(layer.mlp, DeepseekV2MoE): + if (self.afd_role is None or self.afd_role == "ffn") and isinstance( + layer.mlp, DeepseekV2MoE + ): # Pick last one layer since the first ones may be dense layers. example_moe = layer.mlp self.moe_mlp_layers.append(layer.mlp) self.moe_layers.append(layer.mlp.experts) + if self.afd_role == "attention": + return self.extract_moe_parameters(example_moe) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -1465,6 +1587,12 @@ class DeepseekV2ForCausalLM( ) return hidden_states + def compute_ffn_output( + self, current_layer_idx, hidden_states + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx) + return hidden_states + def compute_logits( self, hidden_states: torch.Tensor, @@ -1508,6 +1636,13 @@ class DeepseekV2ForCausalLM( # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) + if self.afd_role == "attention": + vllm_config = get_current_vllm_config() + num_redundant_experts = ( + vllm_config.parallel_config.eplb_config.num_redundant_experts + ) + else: + num_redundant_experts = self.num_redundant_experts expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", @@ -1518,7 +1653,7 @@ class DeepseekV2ForCausalLM( if rocm_aiter_moe_shared_expert_enabled else 0 ), - num_redundant_experts=self.num_redundant_experts, + num_redundant_experts=num_redundant_experts, ) params_dict = dict(self.named_parameters()) @@ -1527,6 +1662,9 @@ class DeepseekV2ForCausalLM( if "rotary_emb.inv_freq" in name: continue + if self.afd_role == 
"attention" and self.is_moe_weight(name): + continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) if spec_layer is not None: continue # skip spec decode layers for main model @@ -1627,7 +1765,8 @@ class DeepseekV2ForCausalLM( # Anyway, this is an expert weight and should not be # attempted to load as other weights later is_expert_weight = True - + if self.afd_role is not None and self.afd_role == "attention": + continue # Do not modify `name` since the loop may continue here # Instead, create a new variable name_mapped = chunk_name.replace(weight_name, param_name) @@ -1657,6 +1796,12 @@ class DeepseekV2ForCausalLM( loaded_params.add(name_mapped) break else: + if ( + self.afd_role == "ffn" + and not self.is_moe_weight(name) + and not self.is_common_weight(name) + ): + continue if is_expert_weight: # We've checked that this is an expert weight # However it's not mapped locally to this rank @@ -1685,6 +1830,29 @@ class DeepseekV2ForCausalLM( return loaded_params + def is_moe_weight(self, name): + if ( + "shared_experts" in name + or "experts" in name + or "gate" in name + or "up" in name + or "down" in name + ): + return True + return False + + def is_common_weight(self, name): + if ( + "lm_head" in name + or "model.norm.weight" in name + or "embed_tokens" in name + or "input_layernorm" in name + or "post_attention_layernorm" in name + ): + # or "model.layers.0.self_attn.o_proj.weight" in name:# for init kv cache + return True + return False + class DeepseekForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/vllm/v1/worker/gpu_ffn_model_runner.py b/vllm/v1/worker/gpu_ffn_model_runner.py new file mode 100644 index 0000000000000..e269442e74bdc --- /dev/null +++ b/vllm/v1/worker/gpu_ffn_model_runner.py @@ -0,0 +1,441 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time +from typing import TYPE_CHECKING, Any + +import torch +import torch.nn as nn + +from vllm.config import 
VllmConfig +from vllm.distributed.afd_transfer.afd_connector.factory import AFDConnectorFactory +from vllm.distributed.afd_transfer.afd_connector.metadata import AFDConnectorMetadata +from vllm.distributed.communication_op import tensor_model_parallel_all_gather +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_world_group, + graph_capture, +) +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model_loader +from vllm.utils.mem_utils import DeviceMemoryProfiler, GiB_bytes +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin + +if TYPE_CHECKING: + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec + +logger = init_logger(__name__) + + +class GPUFFNModelRunner(LoRAModelRunnerMixin): + def __init__(self, vllm_config: VllmConfig, device: torch.device): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + self.dtype = self.model_config.dtype + self.load_config = vllm_config.load_config + + self.afd_config = vllm_config.afd_config + if not self.afd_config or not self.afd_config.is_ffn_server: + raise ValueError( + "AFD config must be provided with afd_role='ffn' for FFN server" + ) + + self._counter = 0 + + # Initialize torch.profile for performance monitoring + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=6000 * 27 + 4000 * 27 * 2, warmup=1, active=30, repeat=1 + ), + on_trace_ready=torch.profiler.tensorboard_trace_handler( + "./profiler_logs/ffn" + ), + record_shapes=True, + profile_memory=False, + with_stack=False, + ) + + # Initialize CUDA graph support + self.use_cuda_graph = not 
self.model_config.enforce_eager + + # self.cudagraph_batch_sizes sorts in ascending order. + # The batch sizes in the config are in descending order. + self.cudagraph_batch_sizes = list( + reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes) + ) + + # Storage for captured graphs + self._cuda_graphs: dict[ + tuple[int, int], torch.cuda.CUDAGraph + ] = {} # {(layer_idx, num_tokens): CUDAGraph} + self._graph_memory_pool = None + + assert self.afd_config.is_ffn_server + self.connector = AFDConnectorFactory.create_connector( + get_world_group().rank, get_world_group().local_rank, self.vllm_config + ) + + if getattr(self.model_config.hf_config, "text_config", None) is not None: + self.num_layers = self.model_config.hf_config.text_config.num_hidden_layers + else: + self.num_layers = self.model_config.hf_config.num_hidden_layers + + def get_model(self) -> nn.Module: + return self.model + + def initialize_afd_connector(self) -> None: + self.connector.init_afd_connector() + + def load_model(self, **kwargs) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: # noqa: SIM117 + time_before_load = time.perf_counter() + model_loader = get_model_loader(self.load_config) + + if not hasattr(self, "model"): + logger.info("Loading model from scratch...") + self.model = model_loader.load_model( + vllm_config=self.vllm_config, model_config=self.model_config + ) + else: + logger.info("Model was already initialized. 
Loading weights inplace...") + model_loader.load_weights(self.model, model_config=self.model_config) + time_after_load = time.perf_counter() + self.model_memory_usage = m.consumed_memory + logger.info( + "Model loading took %.4f GiB and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load, + ) + + logger.info("AFD FFN Model loaded successfully") + + def _get_current_layer_idx(self) -> int: + return (self._counter // self.afd_config.num_afd_stages) % self.num_layers + + @torch.inference_mode() + def execute_model(self, scheduler_output=None, intermediate_tensors=None): + """Execute FFN computation for a single request""" + # scheduler_output and intermediate_tensors are unused in FFN server + # mode + self.profiler.step() + + try: + hidden_states, recv_metadata = self.connector.recv_attn_output() + current_layer_idx = recv_metadata.layer_idx + logger.info( + f"layer {current_layer_idx} moe recv hidden states type:{type(hidden_states)}, shape:{hidden_states.shape}" + ) + num_tokens = hidden_states.shape[0] + if recv_metadata is not None and recv_metadata.recv_handle_list is not None: + for work in recv_metadata.recv_handle_list: + work.wait() + # Try to use CUDA graph if available + cuda_graph_info = self._find_cuda_graph(current_layer_idx, num_tokens) + if cuda_graph_info is not None: + # Use captured CUDA graph for computation + with set_forward_context( + attn_metadata=None, vllm_config=self.vllm_config + ): + rank_ffn_output = self._execute_with_cuda_graph( + hidden_states, cuda_graph_info + ) + else: + # Fallback to eager mode + with set_forward_context( + attn_metadata=None, vllm_config=self.vllm_config + ): + rank_ffn_output = self._execute_eager_mode( + hidden_states, current_layer_idx + ) + + recv_metadata.recv_handle_list = None + self.connector.send_ffn_output(rank_ffn_output, recv_metadata) + except Exception as e: + raise ValueError(f"Error computing FFN: {e}") from e + finally: + self._counter += 1 + if 
self._counter == self.num_layers * self.afd_config.num_afd_stages: + self._counter = 0 + return None # FFN server doesn't return ModelRunnerOutput + + def _execute_with_cuda_graph( + self, hidden_states: torch.Tensor, cuda_graph_info: dict + ): + """Execute FFN computation using captured CUDA graph.""" + graph = cuda_graph_info["graph"] + input_tensor = cuda_graph_info["input_hidden_states"] + output_tensor = cuda_graph_info["output"] + + # Copy input data to graph's input tensor + # Handle padding if necessary + actual_tokens = hidden_states.shape[0] + graph_tokens = input_tensor.shape[0] + + if actual_tokens <= graph_tokens: + # Copy actual data and pad with zeros if needed + input_tensor[:actual_tokens].copy_(hidden_states) + if actual_tokens < graph_tokens: + input_tensor[actual_tokens:].zero_() + else: + raise ValueError( + f"Input size {actual_tokens} exceeds graph capacity {graph_tokens}" + ) + + # Replay the captured graph + graph.replay() + + # Return only the actual output (without padding) + return output_tensor[:actual_tokens].clone() + + def _execute_eager_mode( + self, + hidden_states: torch.Tensor, + current_layer_idx: int, + recv_metadata: AFDConnectorMetadata = None, + ): + """Execute FFN computation in eager mode (fallback).""" + # Step the profiler for performance monitoring + + # Handle TP case: all-gather tensors from all TP ranks + tp_world_size = get_tensor_model_parallel_world_size() + if tp_world_size > 1: + # All-gather hidden states from all TP ranks + gathered_hidden_states = tensor_model_parallel_all_gather( + hidden_states, dim=0 + ) + ffn_output = self.model.compute_ffn_output( + current_layer_idx, gathered_hidden_states + ) + # Extract the output corresponding to current rank + start_idx = hidden_states.shape[0] * get_tensor_model_parallel_rank() + end_idx = start_idx + hidden_states.shape[0] + rank_ffn_output = ffn_output[start_idx:end_idx, :] + else: + # Single TP case + rank_ffn_output = self.model.compute_ffn_output( + 
current_layer_idx, hidden_states + ) + + return rank_ffn_output + + # Methods required for interface compatibility with GPUModelRunner + def profile_run(self) -> None: + """FFN servers don't need profiling.""" + pass + + def get_kv_cache_spec(self) -> dict[str, "KVCacheSpec"]: + """FFN servers don't use KV cache.""" + return {} + + def initialize_kv_cache(self, kv_cache_config: "KVCacheConfig") -> None: + """FFN servers don't use KV cache.""" + pass + + def _dummy_run(self, num_tokens: int = 1, **kwargs) -> torch.Tensor: + """FFN servers don't need dummy runs.""" + # Return a dummy tensor for interface compatibility + return torch.zeros( + num_tokens, + self.model_config.hf_config.hidden_size, + dtype=self.dtype, + device=self.device, + ) + + def capture_model(self) -> int: + """Capture CUDA graphs for FFN operations.""" + if not self.use_cuda_graph: + logger.warning("Skipping CUDA graph capture.") + return 0 + + logger.info("Starting CUDA graph capture for FFN operations...") + start_time = time.perf_counter() + start_free_gpu_memory = torch.cuda.mem_get_info()[0] + + # Create memory pool for graphs + if self._graph_memory_pool is None: + self._graph_memory_pool = torch.cuda.graph_pool_handle() + + # Capture graphs for each layer and different batch sizes + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. 
+ with graph_capture(device=self.device): + for layer_idx in range(self.num_layers): + for num_tokens in reversed(self.cudagraph_batch_sizes): + with set_forward_context( + attn_metadata=None, vllm_config=self.vllm_config + ): + self._capture_graph_for_layer_and_size(layer_idx, num_tokens) + + end_time = time.perf_counter() + end_free_gpu_memory = torch.cuda.mem_get_info()[0] + elapsed_time = end_time - start_time + cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory + + logger.info( + "FFN CUDA graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, + cuda_graph_size / (1 << 30), + ) + return cuda_graph_size + + def _capture_graph_for_layer_and_size(self, layer_idx: int, num_tokens: int): + """Capture CUDA graph for specific layer and number of tokens.""" + # Create dummy hidden states + dummy_hidden_states = torch.randn( + num_tokens, + self.model_config.hf_config.hidden_size, + dtype=self.dtype, + device=self.device, + ) + + # Warm up the operations for this specific layer + for _ in range(self.vllm_config.compilation_config.cudagraph_num_of_warmups): + self._run_ffn_computation( + dummy_hidden_states, layer_idx=layer_idx, capture_mode=True + ) + + # Create and capture the graph + graph = torch.cuda.CUDAGraph() + + # Start graph capture + with torch.cuda.graph(graph, pool=self._graph_memory_pool): + output = self._run_ffn_computation( + dummy_hidden_states, layer_idx=layer_idx, capture_mode=True + ) + + # Store the captured graph with layer and token count as key + self._cuda_graphs[(layer_idx, num_tokens)] = { + "graph": graph, + "input_hidden_states": dummy_hidden_states, + "output": output, + } + + logger.debug( + "Captured CUDA graph for layer %s with %s tokens", layer_idx, num_tokens + ) + + def _run_ffn_computation( + self, + hidden_states: torch.Tensor, + layer_idx: int | None = None, + capture_mode: bool = False, + ): + """Run FFN computation for graph capture or replay.""" + if layer_idx is None: + current_layer_idx = 
self._get_current_layer_idx() if not capture_mode else 0 + else: + current_layer_idx = layer_idx + + tp_world_size = get_tensor_model_parallel_world_size() + if tp_world_size > 1: + # Handle TP case: all-gather tensors from all TP ranks + gathered_hidden_states = tensor_model_parallel_all_gather( + hidden_states, dim=0 + ) + ffn_output = self.model.compute_ffn_output( + current_layer_idx, gathered_hidden_states + ) + + # Extract the output corresponding to current rank + start_idx = hidden_states.shape[0] * get_tensor_model_parallel_rank() + end_idx = start_idx + hidden_states.shape[0] + rank_ffn_output = ffn_output[start_idx:end_idx, :] + else: + # Single TP case + rank_ffn_output = self.model.compute_ffn_output( + current_layer_idx, hidden_states + ) + + return rank_ffn_output + + def _find_cuda_graph(self, layer_idx: int, num_tokens: int): + """Find the smallest graph that can handle the given layer and + number of tokens.""" + if not self.use_cuda_graph: + return None + + # Find the minimum capture size that can handle num_tokens for this + # layer + for capture_size in self.cudagraph_batch_sizes: + if num_tokens <= capture_size: + return self._cuda_graphs.get((layer_idx, capture_size)) + return None + + def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None: + """FFN servers don't use samplers.""" + pass + + def update_config(self, overrides: dict[str, Any]) -> None: + """Update configuration for FFN model runner.""" + allowed_config_names = {"load_config", "model_config"} + for config_name, config_overrides in overrides.items(): + assert config_name in allowed_config_names, ( + f"Config `{config_name}` not supported. 
" + f"Allowed configs: {allowed_config_names}" + ) + config = getattr(self, config_name) + from vllm.config import update_config + + new_config = update_config(config, config_overrides) + setattr(self, config_name, new_config) + + def reload_weights(self) -> None: + """Reload model weights for FFN model runner.""" + assert getattr(self, "model", None) is not None, ( + "Cannot reload weights before model is loaded." + ) + model_loader = get_model_loader(self.load_config) + logger.info("Reloading weights inplace...") + model = self.get_model() + model_loader.load_weights(model, model_config=self.model_config) + + @property + def lora_config(self): + """FFN servers don't support LoRA.""" + return None + + @property + def is_pooling_model(self) -> bool: + """FFN servers are not pooling models.""" + return False + + def _dummy_pooler_run(self, hidden_states: torch.Tensor): + """FFN servers don't have poolers.""" + pass + + def get_supported_tasks(self): + """Get supported tasks for FFN model runner.""" + return [] + + def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: + """Get number of input tokens for FFN model runner.""" + return num_scheduled_tokens + + def take_draft_token_ids(self, **kwargs): + """FFN servers don't support draft tokens.""" + pass + + @property + def eplb_state(self): + """FFN servers don't have EPLB state.""" + return None + + def ensure_kv_transfer_shutdown(self): + """FFN servers don't need KV transfer shutdown.""" + pass + + def save_tensorized_model( + self, + tensorizer_config: "TensorizerConfig", + ) -> None: + """FFN servers don't support tensorized model saving.""" + raise NotImplementedError("FFN servers don't support tensorized model saving") diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1aa2ec6bb655c..82dd80d267182 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -37,6 +37,7 @@ from vllm.config import ( get_layers_from_vllm_config, 
update_config, ) +from vllm.distributed.afd_transfer import AFDConnectorFactory from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group @@ -45,11 +46,13 @@ from vllm.distributed.parallel_state import ( get_dcp_group, get_pp_group, get_tp_group, + get_world_group, graph_capture, is_global_first_rank, prepare_communication_buffer_for_model, ) from vllm.forward_context import ( + AFDMetadata, BatchDescriptor, set_forward_context, ) @@ -551,6 +554,16 @@ class GPUModelRunner( # means this layer will perform attention using the keys and values # from the KV cache of `shared_kv_cache_layers[layer_name]`. self.shared_kv_cache_layers: dict[str, str] = {} + + # init AFD config + self.afd_config = vllm_config.afd_config + if self.afd_config and self.afd_config.afd_role == "attention": + self.afd_connector = AFDConnectorFactory.create_connector( + get_world_group().rank, get_world_group().local_rank, vllm_config + ) + self.afd_connector.init_afd_connector() + self.num_stages = self.afd_config.num_afd_stages + self.kv_sharing_fast_prefill_eligible_layers: set[str] = set() self.kv_sharing_fast_prefill_logits_indices = None @@ -610,6 +623,25 @@ class GPUModelRunner( self.kv_connector_output: KVConnectorOutput | None = None self.layerwise_nvtx_hooks_registered = False + profile_dir = ( + "./profiler_logs/attn" + if self.afd_config is not None and self.afd_config.afd_role == "attention" + else "./profiler_logs/normal" + ) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=6000 + 4000, warmup=1, active=30, repeat=1 + ), + on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir), + record_shapes=True, + profile_memory=False, + with_stack=False, + ) + def reset_mm_cache(self) -> None: 
if self.mm_budget: self.mm_budget.reset_cache() @@ -1313,7 +1345,7 @@ class GPUModelRunner( assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - + logger.info(f"nums_reqs: {num_reqs}") # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. self.input_batch.block_table.commit_block_table(num_reqs) @@ -2823,6 +2855,7 @@ class GPUModelRunner( # decoder. allow_dp_padding = ( self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + or self.afd_config is not None ) should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = ( @@ -2904,6 +2937,38 @@ class GPUModelRunner( pyt_hooks.register_hooks(self.model, self.model.__class__.__name__) self.layerwise_nvtx_hooks_registered = True + def _build_afd_metadata( + self, ubatch_slices: UBatchSlices | None, num_tokens_unpadded: int + ): + afd_metadata = None + if self.afd_config: + # For prefill, compute tokens per stage based on actual token + # counts + afd_tokens_start_loc = [0] + afd_tokens_lens = [] + if ubatch_slices and len(ubatch_slices) > 1: + afd_tokens_start_loc = [ub.token_slice.start for ub in ubatch_slices] + afd_reqs_start_loc = [ub.request_slice.start for ub in ubatch_slices] + logger.info( + f"afd_tokens_start_loc: {afd_tokens_start_loc} " + f"afd_reqs_start_loc: {afd_reqs_start_loc} " + f"ubatch_slices: {ubatch_slices}" + ) + afd_tokens_lens = [ub.num_tokens for ub in ubatch_slices] + else: + afd_tokens_start_loc = [0] + afd_reqs_start_loc = [0] + afd_tokens_lens = [num_tokens_unpadded] + afd_metadata = AFDMetadata( + afd_tokens_start_loc=afd_tokens_start_loc, + afd_reqs_start_loc=afd_reqs_start_loc, + afd_stage_idx=0, + afd_connector=self.afd_connector, + afd_tokens_lens=afd_tokens_lens, + num_of_stages=len(ubatch_slices) if ubatch_slices else 1, + ) + return afd_metadata + @torch.inference_mode() def execute_model( self, @@ -3076,6 +3141,9 @@ class GPUModelRunner( # Mark KV scales as 
calculated after the first forward pass self.calculate_kv_scales = False + afd_metadata = self._build_afd_metadata(ubatch_slices_padded, num_tokens_unpadded) + + self.profiler.step() # Run the model. # Use persistent buffers for CUDA graphs. with ( @@ -3087,10 +3155,14 @@ class GPUModelRunner( cudagraph_runtime_mode=cudagraph_mode, batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, + afd_metadata=afd_metadata, ), record_function_or_nullcontext("gpu_model_runner: forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, ): + logger.info(f"input_ids: {input_ids.shape}") + if inputs_embeds: + logger.info(f"inputs_embeds: {inputs_embeds.shape}") model_output = self._model_forward( input_ids=input_ids, positions=positions, @@ -4202,6 +4274,8 @@ class GPUModelRunner( if num_tokens_across_dp is not None: num_tokens_across_dp[:] = num_tokens_padded + afd_metadata = self._build_afd_metadata(ubatch_slices_padded, num_tokens_unpadded) + with ( self.maybe_randomize_inputs(input_ids, inputs_embeds), set_forward_context( @@ -4212,6 +4286,7 @@ class GPUModelRunner( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, + afd_metadata=afd_metadata, ), ): outputs = self.model( @@ -4761,7 +4836,6 @@ class GPUModelRunner( kv_cache_spec, kv_cache_group_id, ) - attn_groups.append(attn_group) return attn_groups @@ -5438,6 +5512,11 @@ class GPUModelRunner( kv_transfer_group.register_kv_caches(kv_caches) kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks) + def initialize_afd_connector(self) -> None: + """Initialize AFD connector if available.""" + if hasattr(self, "afd_connector") and self.afd_connector: + self.afd_connector.init_afd_connector() + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. 
diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index af09129e67b1e..21694b367c326 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -126,7 +126,10 @@ class UBatchWrapper: comm_sms: int = envs.VLLM_DBO_COMM_SMS set_comm_sms = lambda sms: None - if vllm_config.parallel_config.enable_expert_parallel: + if ( + vllm_config.parallel_config.enable_expert_parallel + and not vllm_config.afd_config + ): # Currently only DeepEP highthroughput supports SM control so this # only affects that case. ep_group = get_ep_group() @@ -303,6 +306,7 @@ class UBatchWrapper: dp_metadata, batch_descriptor, cudagraph_runtime_mode, + afd_metadata, ) -> list[UbatchMetadata]: # Create one forward context per ubatch forward_contexts = [] @@ -314,6 +318,7 @@ class UBatchWrapper: dp_metadata=dp_metadata[i], batch_descriptor=batch_descriptor, cudagraph_runtime_mode=cudagraph_runtime_mode, + afd_metadata=afd_metadata, ) ) @@ -385,6 +390,7 @@ class UBatchWrapper: batch_descriptor = forward_context.batch_descriptor ubatch_slices = forward_context.ubatch_slices cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode + afd_metadata = forward_context.afd_metadata # If there's no ubatching, just run the runnable object if ubatch_slices is None: @@ -448,6 +454,7 @@ class UBatchWrapper: dp_metadata=ubatch_dp_metadata, batch_descriptor=batch_descriptor, cudagraph_runtime_mode=CUDAGraphMode.NONE, + afd_metadata=afd_metadata, ) with self.sm_control: return self._capture_ubatches(ubatch_metadata, self.model) @@ -470,6 +477,7 @@ class UBatchWrapper: dp_metadata=ubatch_dp_metadata, batch_descriptor=batch_descriptor, cudagraph_runtime_mode=CUDAGraphMode.NONE, + afd_metadata=afd_metadata, ) with self.sm_control: return self._run_ubatches(ubatch_metadata, self.model) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 1e13650cd083e..fef91b3413d5e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ 
b/vllm/v1/worker/gpu_worker.py @@ -6,7 +6,7 @@ import gc import os from contextlib import AbstractContextManager, nullcontext from types import NoneType -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Optional, cast import numpy as np import torch @@ -52,6 +52,8 @@ from vllm.v1.outputs import ( ModelRunnerOutput, ) from vllm.v1.utils import report_usage_stats +from vllm.v1.worker.gpu_ffn_model_runner import GPUFFNModelRunner +from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.workspace import init_workspace_manager @@ -260,8 +262,11 @@ class Worker(WorkerBase): num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1 init_workspace_manager(self.device, num_ubatches) + self.model_runner: GPUModelRunner | GPUFFNModelRunner | GPUModelRunnerV2 + if self.vllm_config.afd_config and self.vllm_config.afd_config.is_ffn_server: + self.model_runner = GPUFFNModelRunner(self.vllm_config, self.device) # Construct the model runner - if self.use_v2_model_runner: + elif self.use_v2_model_runner: from vllm.v1.worker.gpu.model_runner import ( GPUModelRunner as GPUModelRunnerV2, ) @@ -573,8 +578,16 @@ class Worker(WorkerBase): @torch.inference_mode() def execute_model( - self, scheduler_output: "SchedulerOutput" - ) -> ModelRunnerOutput | None: + self, scheduler_output: Optional["SchedulerOutput"] = None + ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None: + # FFN server mode: direct execution without pipeline parallelism + if self.vllm_config.afd_config and self.vllm_config.afd_config.is_ffn_server: + return self.model_runner.execute_model(scheduler_output) + + if scheduler_output is None: + raise ValueError("scheduler_output is required in normal inference mode") + + # Normal inference mode intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -676,6 +689,53 @@ class Worker(WorkerBase): # worker will always be healthy as long as it's running. return + def start_ffn_server_loop(self) -> None: + """Start FFN server loop for AFD FFN workers""" + if not ( + self.vllm_config.afd_config and self.vllm_config.afd_config.is_ffn_server + ): + return + + self.model_runner.capture_model() + self.model_runner.initialize_afd_connector() + + if self.profiler: + self.profiler.start() + for _ in range(1000): # FIXME: hardcoded profiler iterations + self.model_runner.execute_model(scheduler_output=None) + torch.cuda.synchronize() # Ensure GPU operations complete + self.profiler.stop() + print(self.profiler.key_averages().table(sort_by="self_cuda_time_total")) + + import threading + + self._ffn_shutdown_event = threading.Event() + + def ffn_worker_loop(): + # Set CUDA device for this thread (thread-local context) + torch.cuda.set_device(self.device) + logger.info("FFN worker loop started") + + try: + while not self._ffn_shutdown_event.is_set(): + # Execute FFN computation + self.model_runner.execute_model(scheduler_output=None) + except Exception as e: + logger.error("FFN worker loop error: %s", e) + raise + + self._ffn_thread = threading.Thread(target=ffn_worker_loop, daemon=True) + self._ffn_thread.start() + logger.info("FFN server loop started in worker") + + def stop_ffn_server_loop(self) -> None: + """Stop FFN server loop""" + if hasattr(self, "_ffn_shutdown_event"): + self._ffn_shutdown_event.set() + if hasattr(self, "_ffn_thread"): + self._ffn_thread.join(timeout=5) + logger.info("FFN server loop stopped") + def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None: from vllm.distributed.parallel_state import get_ep_group From 28cba040c791d926d9337ea1b687717eba4761c2 Mon Sep 17 00:00:00 2001 From: jiangkuaixue123 Date: Mon, 15 Dec 2025 14:38:03 +0800 Subject: [PATCH 02/19] afd use ubatch without thread Signed-off-by: 
jiangkuaixue123 --- vllm/forward_context.py | 10 ++- vllm/model_executor/models/deepseek_v2.py | 100 +++++++++++++++++++++- vllm/v1/worker/gpu_ubatch_wrapper.py | 91 ++++++++++++++++++-- 3 files changed, 188 insertions(+), 13 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index a715d70d1d5f4..a2f365cf21eb3 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -4,7 +4,7 @@ import time from collections import defaultdict from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, NamedTuple import torch @@ -14,6 +14,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.distributed.afd_transfer import AFDConnectorBase from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.ubatch_utils import UBatchSlices @@ -232,6 +233,13 @@ class AFDMetadata: afd_tokens_lens: list[int] # padded lengths for tensor slicing num_of_stages: int + input_ids_list: list[torch.Tensor] = field(default_factory=list) + positions_list: list[torch.Tensor] = field(default_factory=list) + inputs_embeds_list: list[torch.Tensor] = field(default_factory=list) + intermediate_tensors_list: list[IntermediateTensors] = field(default_factory=list) + attn_metadata_list: list[AttentionMetadata] = field(default_factory=list) + dp_metadata_list: list[DPMetadata] = field(default_factory=list) + @dataclass class ForwardContext: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index cd3829b9220c9..eb9d85cfc55fd 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -1361,7 +1361,9 @@ class DeepseekV2Model(nn.Module): if recv_handle is not None: for work in recv_handle: work.wait() - 
current_hidden, residual = layer(positions, hidden_states, residual, llama_4_scaling) + current_hidden, residual = layer( + positions, hidden_states, residual, llama_4_scaling + ) metadata = AFDConnectorMetadata.create_attention_metadata( layer_idx=layer.layer_idx, stage_idx=afd_metadata.afd_stage_idx, @@ -1385,6 +1387,96 @@ class DeepseekV2Model(nn.Module): return hidden_states, residual + def forward_with_afd_v2( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + positions: torch.Tensor, + afd_metadata: AFDMetadata, + llama_4_scaling: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + forward_conext = get_forward_context() + recv_handle = None + + ubatch_hidden_states = [] + ubatch_residual = [] + + start_idx = 0 + for pos in afd_metadata.positions_list: + # DeepSeekV2 uses MROPE with shape (3, num_tokens), so use shape[1] if ndim==2 + # Otherwise use shape[0] as requested + num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0] + end_idx = start_idx + num_tokens + ubatch_hidden_states.append(hidden_states[start_idx:end_idx]) + ubatch_residual.append( + residual[start_idx:end_idx] if residual is not None else None + ) + start_idx = end_idx + + for layer in islice(self.layers, self.start_layer, self.end_layer): + for stage_i in range(forward_conext.afd_metadata.num_of_stages): + logger.info( + f"jcz deepseekv2 forward_with_afd_v2 layer_idx: {layer.layer_idx}, stage_i: {stage_i}" + ) + afd_connector = afd_metadata.afd_connector + forward_conext.attn_metadata = afd_metadata.attn_metadata_list[stage_i] + forward_conext.dp_metadata = afd_metadata.dp_metadata_list[stage_i] + + residual = ubatch_residual[stage_i] + + if layer.layer_idx > 0: + hidden_states, recv_metadata = afd_connector.recv_ffn_output() + if recv_metadata.recv_handle_list is not None: + recv_handle = recv_metadata.recv_handle_list + else: + hidden_states = ubatch_hidden_states[stage_i] + + if recv_handle is not None: + for work in recv_handle: + work.wait() + 
+ current_positions = afd_metadata.positions_list[stage_i] + logger.info( + f"jcz deepseekv2 forward_with_afd_v2 hidden_states: {hidden_states.shape}" + f" positions:{positions.shape}" + ) + hidden_states, residual = layer( + current_positions, hidden_states, residual, llama_4_scaling + ) + + ubatch_hidden_states[stage_i] = hidden_states + ubatch_residual[stage_i] = residual + + metadata = AFDConnectorMetadata.create_attention_metadata( + layer_idx=layer.layer_idx, + stage_idx=stage_i, + seq_len=hidden_states.shape[0], + dtype=hidden_states.dtype, + device=hidden_states.device, + num_of_stages=afd_metadata.num_of_stages, + afd_tokens_lens=afd_metadata.afd_tokens_lens, + ) + afd_connector.send_attn_output(hidden_states, metadata) + + # Recv last layer and last stage FFN output. + ubatch_hidden_states[afd_metadata.num_of_stages - 1], recv_metadata = ( + afd_connector.recv_ffn_output() + ) + if recv_metadata.recv_handle_list is not None: + recv_handle = recv_metadata.recv_handle_list + if recv_handle is not None: + for work in recv_handle: + work.wait() + + # Re-assemble the batch + hidden_states = torch.cat(ubatch_hidden_states, dim=0) + if any(r is not None for r in ubatch_residual): + residual = torch.cat(ubatch_residual, dim=0) + else: + residual = None + + return hidden_states, residual + def forward( self, input_ids: torch.Tensor, @@ -1421,12 +1513,14 @@ class DeepseekV2Model(nn.Module): afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None if afd_metadata != None: - hidden_states, residual = self.forward_with_afd( + hidden_states, residual = self.forward_with_afd_v2( hidden_states, residual, positions, afd_metadata, llama_4_scaling ) else: for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual, llama_4_scaling) + hidden_states, residual = layer( + positions, hidden_states, residual, llama_4_scaling + ) if not get_pp_group().is_last_rank: return 
IntermediateTensors( diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 21694b367c326..9e17c718c5513 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -14,6 +14,7 @@ from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed import get_ep_group from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id from vllm.forward_context import ( + AFDMetadata, DPMetadata, create_forward_context, get_forward_context, @@ -358,6 +359,59 @@ class UBatchWrapper: return ubatch_metadata + def _make_afd_ubatch_metadata( + self, + ubatch_slices, + attn_metadata, + input_ids, + positions, + inputs_embeds, + intermediate_tensors, + dp_metadata, + afd_metadata, + ) -> AFDMetadata: + if ubatch_slices is None: + afd_metadata.input_ids_list.append(input_ids) + afd_metadata.positions_list.append(positions) + afd_metadata.inputs_embeds_list.append(inputs_embeds) + afd_metadata.intermediate_tensors_list.append(intermediate_tensors) + afd_metadata.attn_metadata_list.append(attn_metadata) + afd_metadata.dp_metadata_list.append(dp_metadata) + else: + for i, ubatch_slice in enumerate(ubatch_slices): + ( + sliced_input_ids, + sliced_positions, + sliced_inputs_embeds, + sliced_intermediate_tensors, + ) = self._slice_model_inputs( + ubatch_slice.token_slice, + input_ids, + positions, + inputs_embeds, + intermediate_tensors, + ) + + dp_size = self.vllm_config.parallel_config.data_parallel_size + ubatch_num_tokens_across_dp = torch.tensor( + [ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32 + ) + ubatch_dp_metadata = DPMetadata.make( + self.vllm_config.parallel_config, + ubatch_slice.num_tokens, + ubatch_num_tokens_across_dp, + ) + + afd_metadata.input_ids_list.append(sliced_input_ids) + afd_metadata.positions_list.append(sliced_positions) + afd_metadata.inputs_embeds_list.append(sliced_inputs_embeds) + 
afd_metadata.intermediate_tensors_list.append(sliced_intermediate_tensors) + afd_metadata.attn_metadata_list.append( + attn_metadata[i] if attn_metadata is not None else None) + afd_metadata.dp_metadata_list.append(ubatch_dp_metadata) + + return afd_metadata + def _slice_model_inputs( self, tokens_slice: slice, @@ -392,6 +446,33 @@ class UBatchWrapper: cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode afd_metadata = forward_context.afd_metadata + attn_metadata = forward_context.attn_metadata + input_ids = kwargs["input_ids"] + positions = kwargs["positions"] + intermediate_tensors = kwargs["intermediate_tensors"] + inputs_embeds = kwargs["inputs_embeds"] + compute_stream = torch.cuda.current_stream() + + dp_metadata = forward_context.dp_metadata + + if self.vllm_config.afd_config: + afd_metadata = self._make_afd_ubatch_metadata( + ubatch_slices=ubatch_slices, + attn_metadata=attn_metadata, + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + dp_metadata=dp_metadata, + afd_metadata=afd_metadata, + ) + forward_context.afd_metadata = afd_metadata + if cudagraph_runtime_mode is CUDAGraphMode.NONE: + return self.runnable(*args, **kwargs) + else: + assert self.cudagraph_wrapper is not None + return self.cudagraph_wrapper(*args, **kwargs) + # If there's no ubatching, just run the runnable object if ubatch_slices is None: # This is to account for the case where ubatching was aborted. @@ -400,6 +481,7 @@ class UBatchWrapper: # num_tokens, we don't have a non-ubatched one. Without this # check, the cudagraph wrapper will try to capture a cudagraph # for this shape during a normal run. 
+ if cudagraph_runtime_mode is CUDAGraphMode.FULL: assert batch_descriptor is not None if batch_descriptor.num_tokens in self.cudagraphs: @@ -411,18 +493,9 @@ class UBatchWrapper: assert self.cudagraph_wrapper is not None return self.cudagraph_wrapper(*args, **kwargs) - attn_metadata = forward_context.attn_metadata num_tokens = ( ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start ) * 2 - input_ids = kwargs["input_ids"] - positions = kwargs["positions"] - intermediate_tensors = kwargs["intermediate_tensors"] - inputs_embeds = kwargs["inputs_embeds"] - compute_stream = torch.cuda.current_stream() - - dp_metadata = forward_context.dp_metadata - # We shouldn't be here unless we are running with multiple DP ranks assert dp_metadata is not None ubatch_dp_metadata = [] From eb2355c600dfb2fe7fca1ae9d8c8f2f158878820 Mon Sep 17 00:00:00 2001 From: jiangkuaixue123 Date: Sat, 13 Dec 2025 10:36:03 +0800 Subject: [PATCH 03/19] ffn server use vllm serve and dp Signed-off-by: jiangkuaixue123 --- vllm/config/parallel.py | 1 - vllm/entrypoints/cli/serve.py | 1 - vllm/entrypoints/openai/api_server.py | 1 - vllm/v1/engine/core.py | 41 +++++++++++++++++++++++++- vllm/v1/engine/utils.py | 11 ++++++- vllm/v1/executor/multiproc_executor.py | 6 ++++ 6 files changed, 56 insertions(+), 5 deletions(-) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 3fe066ec32505..49328aa2818ca 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -557,7 +557,6 @@ class ParallelConfig: if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. 
- from vllm.v1.executor import ray_utils backend: DistributedExecutorBackend = "mp" diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 96608f360e17b..a7eaa5e8be749 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -191,7 +191,6 @@ def run_multi_api_server(args: argparse.Namespace): assert external_dp_lb or hybrid_dp_lb or dp_rank == 0 api_server_manager: APIServerProcessManager | None = None - with launch_core_engines( vllm_config, executor_class, log_stats, num_api_servers ) as (local_engine_manager, coordinator, addresses): diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 5d0eacae34dd7..f5c0fc82b3c21 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1402,7 +1402,6 @@ async def run_server_worker( listen_address, sock, args, client_config=None, **uvicorn_kwargs ) -> None: """Run a single API server worker.""" - if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: ToolParserManager.import_tool_parser(args.tool_parser_plugin) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0045b8c1dd3e7..257e097401b10 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -103,6 +103,11 @@ class EngineCore: if executor_fail_callback is not None: self.model_executor.register_failure_callback(executor_fail_callback) + self.afd_config = vllm_config.afd_config + if self.afd_config and self.afd_config.afd_role == "ffn": + logger.info("jcz EngineCore ffn role") + return + self.available_gpu_memory_for_kv_cache = -1 # Setup KV Caches and update CacheConfig after profiling. 
@@ -601,6 +606,7 @@ class EngineCoreProc(EngineCore): executor_fail_callback = lambda: self.input_queue.put_nowait( (EngineCoreRequestType.EXECUTOR_FAILED, b"") ) + self.afd_config = vllm_config.afd_config self.engine_index = engine_index identity = self.engine_index.to_bytes(length=2, byteorder="little") @@ -855,7 +861,6 @@ class EngineCoreProc(EngineCore): set_process_title("EngineCore") decorate_logs() engine_core = EngineCoreProc(*args, **kwargs) - engine_core.run_busy_loop() except SystemExit: @@ -878,6 +883,23 @@ class EngineCoreProc(EngineCore): def run_busy_loop(self): """Core busy loop of the EngineCore.""" + if self.afd_config and self.afd_config.afd_role == "ffn": + logger.info("AFD FFN Server started, workers running...") + try: + # Tell workers to start FFN server loops (one-time call) + self.model_executor.collective_rpc("start_ffn_server_loop") + + # Main thread waits without busy polling + shutdown_event = threading.Event() + shutdown_event.wait() # Block until interrupted + + except KeyboardInterrupt: + logger.info("Server shutting down...") + self.model_executor.collective_rpc("stop_ffn_server_loop") + except Exception as e: + logger.error("Server error: %s", e) + raise + # Loop until process is sent a SIGINT or SIGTERM while True: # 1) Poll the input queue until there is work to do. @@ -1156,6 +1178,7 @@ class DPEngineCoreProc(EngineCoreProc): # Initialize the engine. 
dp_rank = vllm_config.parallel_config.data_parallel_rank + self.afd_config = vllm_config.afd_config super().__init__( vllm_config, local_client, @@ -1238,6 +1261,22 @@ class DPEngineCoreProc(EngineCoreProc): def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" + if self.afd_config and self.afd_config.afd_role == "ffn": + logger.info("AFD FFN Server started, workers running...") + try: + # Tell workers to start FFN server loops (one-time call) + self.model_executor.collective_rpc("start_ffn_server_loop") + + # Main thread waits without busy polling + shutdown_event = threading.Event() + shutdown_event.wait() # Block until interrupted + + except KeyboardInterrupt: + logger.info("Server shutting down...") + self.model_executor.collective_rpc("stop_ffn_server_loop") + except Exception as e: + logger.error("Server error: %s", e) + raise # Loop until process is sent a SIGINT or SIGTERM while True: diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 24bf66c42f312..ec0805102d2ac 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -16,7 +16,7 @@ import msgspec import zmq from vllm import envs -from vllm.config import CacheConfig, ParallelConfig, VllmConfig +from vllm.config import AFDConfig, CacheConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy @@ -908,6 +908,7 @@ def launch_core_engines( vllm_config.cache_config, local_engine_manager, coordinator.proc if coordinator else None, + vllm_config.afd_config, ) @@ -919,6 +920,7 @@ def wait_for_engine_startup( cache_config: CacheConfig, proc_manager: CoreEngineProcManager | None, coord_process: Process | None, + afd_config: AFDConfig | None = None, ): # Wait for engine core process(es) to send ready messages. 
local_count = parallel_config.data_parallel_size_local @@ -1020,6 +1022,13 @@ def wait_for_engine_startup( conn_pending[0 if local else 1] -= 1 start_pending[0 if local else 1] += 1 engine.state = CoreEngineState.CONNECTED + elif ( + status == "READY" + and engine.state == CoreEngineState.CONNECTED + and afd_config + and afd_config.afd_role == "ffn" + ): + engine.state = CoreEngineState.READY elif status == "READY" and engine.state == CoreEngineState.CONNECTED: # Setup KV cache config with initialization state from # engine core process. Sum values from all engines in DP case. diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 649875fe8b7c1..c112980a95a3e 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -213,6 +213,9 @@ class MultiprocExecutor(Executor): self.output_rank = self._get_output_rank() + self.afd_config = self.vllm_config.afd_config + self.afd_role = self.afd_config.afd_role if self.afd_config else None + def start_worker_monitor(self, inline=False) -> None: workers = self.workers self_ref = weakref.ref(self) @@ -565,6 +568,9 @@ class WorkerProc: # environment variable overrides after this point) enable_envs_cache() + self.afd_config = vllm_config.afd_config + self.afd_role = self.afd_config.afd_role if self.afd_config else None + @staticmethod def make_worker_process( vllm_config: VllmConfig, From 00570c9fac2282b56b7880ed8866c27b6a1949d6 Mon Sep 17 00:00:00 2001 From: jiangkuaixue123 Date: Mon, 15 Dec 2025 16:15:28 +0800 Subject: [PATCH 04/19] ffn dp use all2all Signed-off-by: jiangkuaixue123 --- .../afd_transfer/afd_connector/p2p_connector.py | 16 ++++++++++++++++ vllm/v1/worker/gpu_ffn_model_runner.py | 8 +++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py index d9d9730b55365..0655db693bce4 100644 --- 
a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -15,6 +15,10 @@ from vllm.distributed.parallel_state import ( init_model_parallel_group, ) from vllm.logger import init_logger +from vllm.forward_context import ( + DPMetadata, + get_forward_context, +) from .base import AFDConnectorBase from .metadata import AFDConnectorMetadata @@ -53,6 +57,7 @@ class P2PAFDConnector(AFDConnectorBase): ) self.recv_attn_output_counter: int = 0 self.recv_ffn_output_counter: int = 0 + self.dp_metadata_list: dict[int, DPMetadata] = {} def close(self) -> None: """Close the connector and release resources.""" @@ -169,6 +174,17 @@ class P2PAFDConnector(AFDConnectorBase): self._tensor_metadata_list = self._build_tensor_metadata_list( tensor_metadata, self._current_afd_connector_metadata ) + if self.config.parallel_config.data_parallel_size > 1: + logger.info("jcz recv_metadata num_of_stages:{}".format(self._current_afd_connector_metadata.num_of_stages)) + for stage_idx in range(self._current_afd_connector_metadata.num_of_stages): + num_tokens_per_ubatch = self._tensor_metadata_list[stage_idx].size[0] + self.dp_metadata_list[stage_idx] = DPMetadata.make( + self.config.parallel_config, + num_tokens_per_ubatch, + torch.tensor([num_tokens_per_ubatch] * self.config.parallel_config.data_parallel_size, + device="cpu", dtype=torch.int32), + ) + logger.info("jcz recv_metadata self.dp_metadata_list:{}".format(self.dp_metadata_list)) def _send_hidden_states( self, diff --git a/vllm/v1/worker/gpu_ffn_model_runner.py b/vllm/v1/worker/gpu_ffn_model_runner.py index e269442e74bdc..51dd62255acc7 100644 --- a/vllm/v1/worker/gpu_ffn_model_runner.py +++ b/vllm/v1/worker/gpu_ffn_model_runner.py @@ -17,7 +17,7 @@ from vllm.distributed.parallel_state import ( get_world_group, graph_capture, ) -from vllm.forward_context import set_forward_context +from vllm.forward_context import set_forward_context, get_forward_context from 
vllm.logger import init_logger from vllm.model_executor.model_loader import get_model_loader from vllm.utils.mem_utils import DeviceMemoryProfiler, GiB_bytes @@ -130,6 +130,10 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): try: hidden_states, recv_metadata = self.connector.recv_attn_output() + if hasattr(self.connector, 'dp_metadata_list'): + dp_metadata = self.connector.dp_metadata_list.get(recv_metadata.stage_idx, None) + else: + dp_metadata = None current_layer_idx = recv_metadata.layer_idx logger.info( f"layer {current_layer_idx} moe recv hidden states type:{type(hidden_states)}, shape:{hidden_states.shape}" @@ -145,6 +149,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): with set_forward_context( attn_metadata=None, vllm_config=self.vllm_config ): + get_forward_context().dp_metadata = dp_metadata rank_ffn_output = self._execute_with_cuda_graph( hidden_states, cuda_graph_info ) @@ -153,6 +158,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): with set_forward_context( attn_metadata=None, vllm_config=self.vllm_config ): + get_forward_context().dp_metadata = dp_metadata rank_ffn_output = self._execute_eager_mode( hidden_states, current_layer_idx ) From 36f9c3d6b5b7ade536ad38c23daa33dba9131f92 Mon Sep 17 00:00:00 2001 From: jiangkuaixue123 Date: Mon, 15 Dec 2025 16:25:45 +0800 Subject: [PATCH 05/19] add log Signed-off-by: jiangkuaixue123 --- vllm/config/parallel.py | 1 + .../afd_connector/p2p_connector.py | 1 + vllm/model_executor/layers/linear.py | 1 - vllm/model_executor/models/deepseek_v2.py | 22 +++++-------------- vllm/v1/engine/core.py | 1 - vllm/v1/worker/gpu_ffn_model_runner.py | 1 + 6 files changed, 8 insertions(+), 19 deletions(-) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 49328aa2818ca..e8a69f29f999f 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -557,6 +557,7 @@ class ParallelConfig: if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if 
world_size fits on the # current node and we aren't in a ray placement group. + from vllm.v1.executor import ray_utils backend: DistributedExecutorBackend = "mp" diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py index 0655db693bce4..6dc1f317b87d5 100644 --- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -317,4 +317,5 @@ class P2PAFDConnector(AFDConnectorBase): ) self._current_afd_connector_metadata.recv_handle_list = work_list self._current_afd_connector_metadata.layer_idx = layer_idx + self._current_afd_connector_metadata.stage_idx = stage_idx return hidden_states, self._current_afd_connector_metadata diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index dfcc601a1c530..f4bc875cf10b3 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -563,7 +563,6 @@ class ColumnParallelLinear(LinearBase): # Matrix multiply. assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_, bias) - if self.gather_output and self.tp_size > 1: # All-gather across the partitions. 
output = tensor_model_parallel_all_gather(output_parallel) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index eb9d85cfc55fd..7938cff98c354 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -387,7 +387,6 @@ class DeepseekV2MoE(nn.Module): final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states ) - return final_hidden_states.view(num_tokens, hidden_dim) @@ -1415,9 +1414,6 @@ class DeepseekV2Model(nn.Module): for layer in islice(self.layers, self.start_layer, self.end_layer): for stage_i in range(forward_conext.afd_metadata.num_of_stages): - logger.info( - f"jcz deepseekv2 forward_with_afd_v2 layer_idx: {layer.layer_idx}, stage_i: {stage_i}" - ) afd_connector = afd_metadata.afd_connector forward_conext.attn_metadata = afd_metadata.attn_metadata_list[stage_i] forward_conext.dp_metadata = afd_metadata.dp_metadata_list[stage_i] @@ -1436,10 +1432,6 @@ class DeepseekV2Model(nn.Module): work.wait() current_positions = afd_metadata.positions_list[stage_i] - logger.info( - f"jcz deepseekv2 forward_with_afd_v2 hidden_states: {hidden_states.shape}" - f" positions:{positions.shape}" - ) hidden_states, residual = layer( current_positions, hidden_states, residual, llama_4_scaling ) @@ -1458,15 +1450,11 @@ class DeepseekV2Model(nn.Module): ) afd_connector.send_attn_output(hidden_states, metadata) - # Recv last layer and last stage FFN output. - ubatch_hidden_states[afd_metadata.num_of_stages - 1], recv_metadata = ( - afd_connector.recv_ffn_output() - ) - if recv_metadata.recv_handle_list is not None: - recv_handle = recv_metadata.recv_handle_list - if recv_handle is not None: - for work in recv_handle: - work.wait() + # Recv last layer FFN output. 
+ for stage_i in range(afd_metadata.num_of_stages): + ubatch_hidden_states[stage_i], recv_metadata = ( + afd_connector.recv_ffn_output() + ) # Re-assemble the batch hidden_states = torch.cat(ubatch_hidden_states, dim=0) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 257e097401b10..41b617843ac11 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -105,7 +105,6 @@ class EngineCore: self.afd_config = vllm_config.afd_config if self.afd_config and self.afd_config.afd_role == "ffn": - logger.info("jcz EngineCore ffn role") return self.available_gpu_memory_for_kv_cache = -1 diff --git a/vllm/v1/worker/gpu_ffn_model_runner.py b/vllm/v1/worker/gpu_ffn_model_runner.py index 51dd62255acc7..cb08c9c05ae58 100644 --- a/vllm/v1/worker/gpu_ffn_model_runner.py +++ b/vllm/v1/worker/gpu_ffn_model_runner.py @@ -137,6 +137,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): current_layer_idx = recv_metadata.layer_idx logger.info( f"layer {current_layer_idx} moe recv hidden states type:{type(hidden_states)}, shape:{hidden_states.shape}" + f" dp_metadata: {dp_metadata}" ) num_tokens = hidden_states.shape[0] if recv_metadata is not None and recv_metadata.recv_handle_list is not None: From d306d01dd776ca19744c83a0875fb05a0da1bace Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Thu, 18 Dec 2025 14:30:55 +0800 Subject: [PATCH 06/19] [Feat] adapt step3 text model --- .../afd_connector/p2p_connector.py | 310 ++++++++++++++++++ vllm/model_executor/models/deepseek_v2.py | 2 +- vllm/model_executor/models/step3_text.py | 242 +++++++++++--- vllm/model_executor/models/step3_vl.py | 10 + vllm/v1/worker/gpu_model_runner.py | 56 ++++ vllm/v1/worker/gpu_ubatch_wrapper.py | 4 + 6 files changed, 580 insertions(+), 44 deletions(-) diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py index 6dc1f317b87d5..3c3359fae96f5 100644 --- 
a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -7,6 +7,316 @@ from datetime import timedelta import torch from torch.distributed.distributed_c10d import _get_default_group, _update_default_pg +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import ( + GroupCoordinator, + TensorMetadata, + init_afd_process_group, + init_model_parallel_group, +) +from vllm.logger import init_logger + +from .base import AFDConnectorBase +from .metadata import AFDConnectorMetadata + +logger = init_logger(__name__) + + +class DefaultProcessGroupSwitcher: + def __init__(self, default_group, new_default_group): + self.default_group = default_group + self.new_default_group = new_default_group + + def __enter__(self): + _update_default_pg(self.new_default_group) + + def __exit__(self, exc_type, exc_value, traceback): + _update_default_pg(self.default_group) + + +class P2PAFDConnector(AFDConnectorBase): + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ) -> None: + self.rank = rank + self.local_rank = local_rank + self.config = config + self._initialized: bool = False + self._need_recv_metadata: bool = True + self._tensor_metadata_list: dict[int, TensorMetadata] = {} + self._current_afd_connector_metadata: AFDConnectorMetadata | None = None + if getattr(self.config.model_config.hf_config, "text_config", None) is not None: + self.num_hidden_layers: int = ( + self.config.model_config.hf_config.text_config.num_hidden_layers + ) + else: + self.num_hidden_layers: int = ( + self.config.model_config.hf_config.num_hidden_layers + ) + + self.recv_attn_output_counter: int = 0 + self.recv_ffn_output_counter: int = 0 + + def close(self) -> None: + """Close the connector and release resources.""" + # TODO: Implement proper resource clean up if needed. 
+ pass + + def init_afd_connector(self) -> None: + """Initialize the AFD connector.""" + afd_size = self.config.afd_config.afd_extra_config.get("afd_size") + role = self.config.afd_config.afd_role + attn_size, ffn_size = map(int, re.match(r"(\d+)\D+(\d+)", afd_size).groups()) + world_rank = self.rank if role == "attention" else self.rank + attn_size + afd_pg = init_afd_process_group( + backend="nccl", + init_method=( + f"tcp://{self.config.afd_config.afd_host}" + f":{self.config.afd_config.afd_port}" + ), + world_size=ffn_size + attn_size, + rank=world_rank, + group_name="afd", + timeout=timedelta(minutes=2), + ) + + # Construct rank lists for sub groups. + # Each group contains one attention and one ffn rank. + ffn_ranks = [i for i in range(ffn_size, ffn_size + attn_size)] + attn_ranks = [i for i in range(attn_size)] + assert len(ffn_ranks) == len(attn_ranks), ( + "ffn_ranks and attn_ranks must have the same length" + ) + default_pg_switcher = DefaultProcessGroupSwitcher(_get_default_group(), afd_pg) + with default_pg_switcher: + sub_group_ranks = [] + for i in range(len(ffn_ranks)): + ranks = [attn_ranks[i], ffn_ranks[i]] + sub_group_ranks.append(ranks) + # Create two independent groups: + # a2e_group: for attention -> expert/ffn communication (send_attn, recv_attn) + # e2a_group: for expert/ffn -> attention communication (send_ffn, recv_ffn) + # The communication domain (rank range) is the same, but different group_name + # creates independent groups. + self.a2e_group = init_model_parallel_group( + sub_group_ranks, + self.local_rank, + backend="nccl", + group_name="a2e", + ) + self.e2a_group = init_model_parallel_group( + sub_group_ranks, + self.local_rank, + backend="nccl", + group_name="e2a", + ) + + self._initialized = True + + def is_initialized(self) -> bool: + """Check if the connector is initialized and ready to use. + + Returns: + bool: True if the connector is initialized, False otherwise. 
+ """ + return self._initialized + + def _build_tensor_metadata_list( + self, + tensor_metadata: TensorMetadata, + connector_metadata: AFDConnectorMetadata, + ) -> dict[int, TensorMetadata]: + tensor_metadata_list = {} + num_of_stages = connector_metadata.num_of_stages + for idx in range(num_of_stages): + if idx == 0: + tensor_metadata_list[0] = tensor_metadata + else: + new_size = list(tensor_metadata.size) + new_size[0] = connector_metadata.afd_tokens_lens[idx] + tensor_metadata_list[idx] = TensorMetadata( + tensor_metadata.device, + tensor_metadata.dtype, + torch.Size(new_size), + ) + return tensor_metadata_list + + def _send_metadata( + self, + metadata: AFDConnectorMetadata, + hidden_states: torch.Tensor, + dst: int, + process_group: GroupCoordinator, + ) -> None: + if not torch.distributed.is_initialized() or process_group.world_size == 1: + return [] + assert dst < process_group.world_size, f"Invalid dst rank ({dst})" + + tensor_metadata = TensorMetadata( + hidden_states.device.type, hidden_states.dtype, hidden_states.size() + ) + metadata_tuple = (metadata, tensor_metadata) + process_group.send_object(metadata_tuple, dst=dst) + self._tensor_metadata_list = self._build_tensor_metadata_list( + tensor_metadata, metadata + ) + + def _recv_metadata( + self, + src: int, + process_group: GroupCoordinator, + ) -> None: + (self._current_afd_connector_metadata, tensor_metadata) = ( + process_group.recv_object(src=src) + ) + self._tensor_metadata_list = self._build_tensor_metadata_list( + tensor_metadata, self._current_afd_connector_metadata + ) + + def _send_hidden_states( + self, + hidden_states: torch.Tensor, + dst: int, + process_group: GroupCoordinator, + ) -> None: + if not torch.distributed.is_initialized() or process_group.world_size == 1: + return [] + assert dst < process_group.world_size, f"Invalid dst rank ({dst})" + assert not hidden_states.is_cpu, "Hidden states must be on GPU" + torch.distributed.send( + hidden_states, + dst=process_group.ranks[dst], + 
group=process_group.device_group, + ) + + def _recv_hidden_states( + self, + src: int, + process_group: GroupCoordinator, + tensor_metadata: TensorMetadata, + ) -> tuple[torch.Tensor, list]: + if not torch.distributed.is_initialized() or process_group.world_size == 1: + return {}, [] + assert src < process_group.world_size, f"Invalid src rank ({src})" + + hidden_states = torch.empty( + tensor_metadata.size, + dtype=tensor_metadata.dtype, + device=tensor_metadata.device, + ) + torch.distributed.recv( + hidden_states, + src=process_group.ranks[src], + group=process_group.device_group, + ) + return hidden_states, [] + + # ------------------------------------------------------------------------- + # attn -> ffn + # ------------------------------------------------------------------------- + + def send_attn_output( + self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata + ) -> None: + """ + Called by ATTN side to send intermediate tensors + generated by ATTN instances to FFN. + """ + try: + dst = (self.a2e_group.rank_in_group + 1) % self.a2e_group.world_size + if metadata.layer_idx == 0 and metadata.stage_idx == 0: + self._send_metadata(metadata, hidden_states, dst, self.a2e_group) + self._current_afd_connector_metadata = metadata + self._send_hidden_states(hidden_states, dst, self.a2e_group) + except Exception as e: + raise RuntimeError(f"Communication error: {e}") + + def recv_ffn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: + """ + Called by the ATTN side to receive MOE output intermediate tensors, + possibly dispatching from the receiver to other GPUs. 
+ """ + src = (self.e2a_group.rank_in_group - 1) % self.e2a_group.world_size + stage_idx = ( + self.recv_ffn_output_counter + % self._current_afd_connector_metadata.num_of_stages + ) + hidden_states, work_list = self._recv_hidden_states( + src, + self.e2a_group, + self._tensor_metadata_list[stage_idx], + ) + self._current_afd_connector_metadata.recv_handle_list = work_list + self.recv_ffn_output_counter = ( + self.recv_ffn_output_counter + 1 + ) % self._current_afd_connector_metadata.num_of_stages + return hidden_states, self._current_afd_connector_metadata + + # ------------------------------------------------------------------------- + # ffn -> attn + # ------------------------------------------------------------------------- + + def send_ffn_output( + self, + hidden_states: torch.Tensor, + metadata: AFDConnectorMetadata, + ) -> None: + """ + Called by FFN side to send intermediate tensors generated by FFN + instances back to the sender (should be the same GPU as source). + """ + dst = (self.e2a_group.rank_in_group + 1) % self.e2a_group.world_size + self._send_hidden_states(hidden_states, dst, self.e2a_group) + self.recv_attn_output_counter += 1 + if ( + self.recv_attn_output_counter + % ( + self._current_afd_connector_metadata.num_of_stages + * self.num_hidden_layers + ) + == 0 + ): + self._need_recv_metadata = True + self.recv_attn_output_counter = 0 + + def recv_attn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: + """ + Called by the FFN side to receive intermediate tensors from ATTN. + Handles receiving and possibly dispatching tensors. 
+ """ + src = (self.a2e_group.rank_in_group - 1) % self.a2e_group.world_size + if self._need_recv_metadata: + self._recv_metadata(src, self.a2e_group) + self._need_recv_metadata = False + + stage_idx = ( + self.recv_attn_output_counter + % self._current_afd_connector_metadata.num_of_stages + ) + layer_idx = ( + self.recv_attn_output_counter + // self._current_afd_connector_metadata.num_of_stages + ) + hidden_states, work_list = self._recv_hidden_states( + src, + self.a2e_group, + self._tensor_metadata_list[stage_idx], + ) + self._current_afd_connector_metadata.recv_handle_list = work_list + self._current_afd_connector_metadata.layer_idx = layer_idx + return hidden_states, self._current_afd_connector_metadata +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import re +from datetime import timedelta + +import torch +from torch.distributed.distributed_c10d import _get_default_group, _update_default_pg + from vllm.config import VllmConfig from vllm.distributed.parallel_state import ( GroupCoordinator, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 7938cff98c354..6b7842b00f54a 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -1670,7 +1670,7 @@ class DeepseekV2ForCausalLM( return hidden_states def compute_ffn_output( - self, current_layer_idx, hidden_states + self, hidden_states, current_layer_idx ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx) return hidden_states diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 7077f1a22e8d7..c012c26f83afa 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -11,12 +11,16 @@ from torch import nn from vllm.attention.layer import Attention from vllm.compilation.decorators import 
support_torch_compile -from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.config import AFDConfig, CacheConfig, ModelConfig, VllmConfig from vllm.distributed import ( get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) +from vllm.distributed.afd_transfer.afd_connector.metadata import ( + AFDConnectorMetadata, +) +from vllm.forward_context import AFDMetadata, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -37,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.step3_vl import Step3TextConfig +from vllm.v1.worker.ubatching import dbo_current_ubatch_id, dbo_enabled, dbo_yield from .interfaces import SupportsPP from .utils import ( @@ -228,54 +233,59 @@ class Step3TextDecoderLayer(nn.Module): config: Step3TextConfig, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, + afd_config: AFDConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size + self.afd_role = afd_config.afd_role if afd_config is not None else None - self.self_attn = Step3TextAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - norm_eps=config.rms_norm_eps, - max_position_embedding=config.max_position_embedding, - head_dim=config.head_dim, - share_q_dim=config.share_q_dim, - rope_parameters=config.rope_parameters, - prefix=f"{prefix}.self_attn", - ) - - layer_idx = int(prefix.split("layers.")[1].split(".")[0]) - moe_layers_enum = getattr(config, "moe_layers_enum", None) - if moe_layers_enum is not None: - moe_layers_idx = [int(i) for i in 
moe_layers_enum.strip().split(",")] - else: - # Default to 1dense. - moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] - - if layer_idx in moe_layers_idx: - self.moe = FusedMoEBlock( - config=config, quant_config=quant_config, prefix=f"{prefix}.moe" - ) - self.share_expert = Step3TextMLP( + if self.afd_role is None or self.afd_role == "attention": + self.self_attn = Step3TextAttention( hidden_size=self.hidden_size, - intermediate_size=config.share_expert_dim, - hidden_act="silu", + num_heads=config.num_attention_heads, + num_kv_heads=1, + cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.share_expert", + norm_eps=config.rms_norm_eps, + max_position_embedding=config.max_position_embedding, + head_dim=config.head_dim, + share_q_dim=config.share_q_dim, + rope_parameters=config.rope_parameters, + prefix=f"{prefix}.self_attn", ) - self.use_moe = True - else: - self.mlp = Step3TextMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act="silu", - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.use_moe = False + + self.layer_idx = int(prefix.split("layers.")[1].split(".")[0]) + + if self.afd_role is None or self.afd_role == "ffn": + moe_layers_enum = getattr(config, "moe_layers_enum", None) + if moe_layers_enum is not None: + moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")] + else: + # Default to 1dense. 
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] + + if self.layer_idx in moe_layers_idx: + self.moe = FusedMoEBlock( + config=config, quant_config=quant_config, prefix=f"{prefix}.moe" + ) + self.share_expert = Step3TextMLP( + hidden_size=self.hidden_size, + intermediate_size=config.share_expert_dim, + hidden_act="silu", + quant_config=quant_config, + prefix=f"{prefix}.share_expert", + ) + self.use_moe = True + else: + self.mlp = Step3TextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act="silu", + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.use_moe = False self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps @@ -300,6 +310,9 @@ class Step3TextDecoderLayer(nn.Module): hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + if self.afd_role == "attention": + return hidden_states, residual + if self.use_moe: share_output = self.share_expert(hidden_states) moe_output = self.moe(hidden_states) @@ -309,6 +322,25 @@ class Step3TextDecoderLayer(nn.Module): return hidden_states, residual + def compute_attn_output( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ): + pass + + def compute_ffn_output(self, hidden_states): + assert self.afd_role == "ffn" + if self.use_moe: + share_output = self.share_expert(hidden_states) + moe_output = self.moe(hidden_states) + hidden_states = share_output + moe_output + else: + hidden_states = self.mlp(hidden_states) + logger.info(f"{type(hidden_states)=}") + return hidden_states + @support_torch_compile class Step3TextModel(nn.Module): @@ -317,6 +349,8 @@ class Step3TextModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + logger.info(f"{quant_config=}") + afd_config = 
vllm_config.afd_config self.vocab_size = config.vocab_size self.config = config @@ -336,6 +370,7 @@ class Step3TextModel(nn.Module): config=config, cache_config=cache_config, quant_config=quant_config, + afd_config=afd_config, prefix=prefix, ), prefix=f"{prefix}.layers", @@ -352,6 +387,51 @@ class Step3TextModel(nn.Module): def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) + def forward_with_afd( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + positions: torch.Tensor, + afd_metadata: AFDMetadata, + ) -> tuple[torch.Tensor, torch.Tensor]: + recv_handle = None + logger.info(f"{__file__}: forward with afd called, may blocked here") + for layer in islice(self.layers, self.start_layer, self.end_layer): + afd_connector = afd_metadata.afd_connector + afd_metadata.afd_stage_idx = dbo_current_ubatch_id() + + if layer.layer_idx > 0: + hidden_states, recv_metadata = afd_connector.recv_ffn_output() + if recv_metadata.recv_handle_list is not None: + recv_handle = recv_metadata.recv_handle_list + + if recv_handle is not None: + for work in recv_handle: + work.wait() + current_hidden, residual = layer(positions, hidden_states, residual) + metadata = AFDConnectorMetadata.create_attention_metadata( + layer_idx=layer.layer_idx, + stage_idx=afd_metadata.afd_stage_idx, + seq_len=current_hidden.shape[0], + dtype=current_hidden.dtype, + device=current_hidden.device, + num_of_stages=afd_metadata.num_of_stages, + afd_tokens_lens=afd_metadata.afd_tokens_lens, + ) + afd_connector.send_attn_output(current_hidden, metadata) + + if dbo_enabled(): + dbo_yield() + + hidden_states, recv_metadata = afd_connector.recv_ffn_output() + if recv_metadata.recv_handle_list is not None: + recv_handle = recv_metadata.recv_handle_list + if recv_handle is not None: + for work in recv_handle: + work.wait() + + return hidden_states, residual + def forward( self, input_ids: torch.Tensor, @@ -370,8 +450,19 @@ class Step3TextModel(nn.Module): 
hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual) + forward_ctx = get_forward_context() + afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None + + if afd_metadata is not None: + hidden_states, residual = self.forward_with_afd( + hidden_states, + residual, + positions, + afd_metadata, + ) + else: + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -384,6 +475,15 @@ class Step3TextModel(nn.Module): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def compute_ffn_output( + self, + hidden_states, + layer_idx, + ) -> torch.Tensor | IntermediateTensors: + logger.info(f"{type(self.layers)=}, {type(layer_idx)=}") + hidden_states = self.layers[layer_idx].compute_ffn_output(hidden_states) + return hidden_states + class Step3TextForCausalLM(nn.Module, SupportsPP): def __init__( @@ -398,6 +498,11 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): self.config = config self.vllm_config = vllm_config + self.afd_config = vllm_config.afd_config + self.afd_role = ( + self.afd_config.afd_role if self.afd_config is not None else None + ) + self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix) if get_pp_group().is_last_rank: @@ -429,11 +534,20 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): ) return hidden_states + def compute_ffn_output( + self, + hidden_states, + current_layer_idx, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx) + return hidden_states + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits 
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + logger.info(f"{__file__}: load_weights!") qkv_params_mapping = [ # (param_name, shard_name, relative_start_idx, relative_end_idx) ( @@ -466,6 +580,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) + logger.info(f"{params_dict.keys()=}") loaded_params: set[str] = set() expert_params_mapping = [ @@ -477,9 +592,17 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): disable_moe_stacked_params = [data[1] for data in expert_params_mapping] for name, loaded_weight in weights: + logger.info( + f"{self.afd_role=}, {name=}, is_moe: {self.is_moe_weight(name)}, " + f"is_common: {self.is_common_weight(name)}" + ) + if self.afd_role == "attention" and self.is_moe_weight(name): + continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue + if any( disable_moe_stacked_param in name for disable_moe_stacked_param in disable_moe_stacked_params @@ -498,6 +621,10 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): param_name, weight_name, shard_id = mapping if weight_name not in name: continue + + if self.afd_role is not None and self.afd_role == "attention": + continue + name = name.replace(weight_name, param_name) # Skip layers on other devices. 
if is_pp_missing_parameter(name, self): @@ -521,12 +648,19 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): loaded_params.add(name) break else: + if ( + self.afd_role == "ffn" + and not self.is_moe_weight(name) + and not self.is_common_weight(name) + ): + continue for ( param_name, weight_name, start_idx, end_idx, ) in qkv_params_mapping: + logger.info(f"{weight_name=}, {name=}") if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -552,3 +686,25 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params + + def is_moe_weight(self, name): + if ( + "shared_expert" in name + or "experts" in name + or "gate" in name + or "up" in name + or "down" in name + ): + return True + return False + + def is_common_weight(self, name): + if ( + "lm_head" in name + or "model.norm.weight" in name + or "embed" in name + or "input_layernorm" in name + or "post_attention_layernorm" in name + ): + return True + return False diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index e5038e56a2708..e16cb53e9f194 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -1126,6 +1126,16 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) return hidden_states + def compute_ffn_output( + self, + hidden_states, + current_layer_idx, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.language_model.compute_ffn_output( + hidden_states, current_layer_idx + ) + return hidden_states + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 82dd80d267182..fff775a9f8241 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -642,6 +642,25 @@ class GPUModelRunner( with_stack=False, ) + profile_dir = ( + "./profiler_logs/attn" + if 
self.afd_config is not None and self.afd_config.afd_role == "attention" + else "./profiler_logs/normal" + ) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=6000 + 4000, warmup=1, active=30, repeat=1 + ), + on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir), + record_shapes=True, + profile_memory=False, + with_stack=False, + ) + def reset_mm_cache(self) -> None: if self.mm_budget: self.mm_budget.reset_cache() @@ -2969,6 +2988,38 @@ class GPUModelRunner( ) return afd_metadata + def _build_afd_metadata( + self, ubatch_slices: UBatchSlices | None, num_tokens_unpadded: int + ): + afd_metadata = None + if self.afd_config: + # For prefill, compute tokens per stage based on actual token + # counts + afd_tokens_start_loc = [0] + afd_tokens_lens = [] + if ubatch_slices and len(ubatch_slices) > 1: + afd_tokens_start_loc = [ub.token_slice.start for ub in ubatch_slices] + afd_reqs_start_loc = [ub.request_slice.start for ub in ubatch_slices] + logger.info( + f"afd_tokens_start_loc: {afd_tokens_start_loc} " + f"afd_reqs_start_loc: {afd_reqs_start_loc} " + f"ubatch_slices: {ubatch_slices}" + ) + afd_tokens_lens = [ub.num_tokens for ub in ubatch_slices] + else: + afd_tokens_start_loc = [0] + afd_reqs_start_loc = [0] + afd_tokens_lens = [num_tokens_unpadded] + afd_metadata = AFDMetadata( + afd_tokens_start_loc=afd_tokens_start_loc, + afd_reqs_start_loc=afd_reqs_start_loc, + afd_stage_idx=0, + afd_connector=self.afd_connector, + afd_tokens_lens=afd_tokens_lens, + num_of_stages=len(ubatch_slices) if ubatch_slices else 1, + ) + return afd_metadata + @torch.inference_mode() def execute_model( self, @@ -5517,6 +5568,11 @@ class GPUModelRunner( if hasattr(self, "afd_connector") and self.afd_connector: self.afd_connector.init_afd_connector() + def initialize_afd_connector(self) -> None: + """Initialize AFD connector if available.""" + if 
hasattr(self, "afd_connector") and self.afd_connector: + self.afd_connector.init_afd_connector() + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 9e17c718c5513..8a1c9d90abbbf 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -127,6 +127,10 @@ class UBatchWrapper: comm_sms: int = envs.VLLM_DBO_COMM_SMS set_comm_sms = lambda sms: None + if ( + vllm_config.parallel_config.enable_expert_parallel + and not vllm_config.afd_config + ): if ( vllm_config.parallel_config.enable_expert_parallel and not vllm_config.afd_config From cd16bcff1ea7373d706c69abe8976609d9854aa8 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Thu, 18 Dec 2025 15:56:20 +0800 Subject: [PATCH 07/19] [Chore] resolve some bugs due to merge --- vllm/v1/worker/gpu_ffn_model_runner.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 5 +++-- vllm/v1/worker/gpu_ubatch_wrapper.py | 4 ---- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/vllm/v1/worker/gpu_ffn_model_runner.py b/vllm/v1/worker/gpu_ffn_model_runner.py index cb08c9c05ae58..cd9940ef5e7a7 100644 --- a/vllm/v1/worker/gpu_ffn_model_runner.py +++ b/vllm/v1/worker/gpu_ffn_model_runner.py @@ -220,7 +220,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): hidden_states, dim=0 ) ffn_output = self.model.compute_ffn_output( - current_layer_idx, gathered_hidden_states + gathered_hidden_states, current_layer_idx ) # Extract the output corresponding to current rank start_idx = hidden_states.shape[0] * get_tensor_model_parallel_rank() @@ -229,7 +229,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): else: # Single TP case rank_ffn_output = self.model.compute_ffn_output( - current_layer_idx, hidden_states + hidden_states, current_layer_idx ) return rank_ffn_output diff --git a/vllm/v1/worker/gpu_model_runner.py 
b/vllm/v1/worker/gpu_model_runner.py index fff775a9f8241..ed6b9cc98f3a2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3211,8 +3211,9 @@ class GPUModelRunner( record_function_or_nullcontext("gpu_model_runner: forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, ): - logger.info(f"input_ids: {input_ids.shape}") - if inputs_embeds: + if input_ids is not None: + logger.info(f"input_ids: {input_ids.shape}") + if inputs_embeds is not None: logger.info(f"inputs_embeds: {inputs_embeds.shape}") model_output = self._model_forward( input_ids=input_ids, diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 8a1c9d90abbbf..9e17c718c5513 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -127,10 +127,6 @@ class UBatchWrapper: comm_sms: int = envs.VLLM_DBO_COMM_SMS set_comm_sms = lambda sms: None - if ( - vllm_config.parallel_config.enable_expert_parallel - and not vllm_config.afd_config - ): if ( vllm_config.parallel_config.enable_expert_parallel and not vllm_config.afd_config From f74bb82909d5a12749fc9aeeb676b52bc7d767be Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Thu, 18 Dec 2025 15:56:43 +0800 Subject: [PATCH 08/19] [Chore] code lint --- vllm/v1/worker/gpu_ffn_model_runner.py | 6 ++++-- vllm/v1/worker/gpu_model_runner.py | 8 ++++++-- vllm/v1/worker/gpu_ubatch_wrapper.py | 9 ++++++--- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/vllm/v1/worker/gpu_ffn_model_runner.py b/vllm/v1/worker/gpu_ffn_model_runner.py index cd9940ef5e7a7..3f85267d8536a 100644 --- a/vllm/v1/worker/gpu_ffn_model_runner.py +++ b/vllm/v1/worker/gpu_ffn_model_runner.py @@ -130,8 +130,10 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): try: hidden_states, recv_metadata = self.connector.recv_attn_output() - if hasattr(self.connector, 'dp_metadata_list'): - dp_metadata = 
self.connector.dp_metadata_list.get(recv_metadata.stage_idx, None) + if hasattr(self.connector, "dp_metadata_list"): + dp_metadata = self.connector.dp_metadata_list.get( + recv_metadata.stage_idx, None + ) else: dp_metadata = None current_layer_idx = recv_metadata.layer_idx diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ed6b9cc98f3a2..a09f292d98e14 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3192,7 +3192,9 @@ class GPUModelRunner( # Mark KV scales as calculated after the first forward pass self.calculate_kv_scales = False - afd_metadata = self._build_afd_metadata(ubatch_slices_padded, num_tokens_unpadded) + afd_metadata = self._build_afd_metadata( + ubatch_slices_padded, num_tokens_unpadded + ) self.profiler.step() # Run the model. @@ -4326,7 +4328,9 @@ class GPUModelRunner( if num_tokens_across_dp is not None: num_tokens_across_dp[:] = num_tokens_padded - afd_metadata = self._build_afd_metadata(ubatch_slices_padded, num_tokens_unpadded) + afd_metadata = self._build_afd_metadata( + ubatch_slices_padded, num_tokens_unpadded + ) with ( self.maybe_randomize_inputs(input_ids, inputs_embeds), diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 9e17c718c5513..7a44b6cbd42b9 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -405,9 +405,12 @@ class UBatchWrapper: afd_metadata.input_ids_list.append(sliced_input_ids) afd_metadata.positions_list.append(sliced_positions) afd_metadata.inputs_embeds_list.append(sliced_inputs_embeds) - afd_metadata.intermediate_tensors_list.append(sliced_intermediate_tensors) + afd_metadata.intermediate_tensors_list.append( + sliced_intermediate_tensors + ) afd_metadata.attn_metadata_list.append( - attn_metadata[i] if attn_metadata is not None else None) + attn_metadata[i] if attn_metadata is not None else None + ) 
afd_metadata.dp_metadata_list.append(ubatch_dp_metadata) return afd_metadata @@ -481,7 +484,7 @@ class UBatchWrapper: # num_tokens, we don't have a non-ubatched one. Without this # check, the cudagraph wrapper will try to capture a cudagraph # for this shape during a normal run. - + if cudagraph_runtime_mode is CUDAGraphMode.FULL: assert batch_descriptor is not None if batch_descriptor.num_tokens in self.cudagraphs: From 26ddfa299cc2835ab5ac9d0e6f3ddba3b85e5db0 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Thu, 18 Dec 2025 17:02:39 +0800 Subject: [PATCH 09/19] [Chore] remove duplicate code --- vllm/v1/worker/gpu_model_runner.py | 56 ------------------------------ 1 file changed, 56 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a09f292d98e14..2409c9071f94b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -642,25 +642,6 @@ class GPUModelRunner( with_stack=False, ) - profile_dir = ( - "./profiler_logs/attn" - if self.afd_config is not None and self.afd_config.afd_role == "attention" - else "./profiler_logs/normal" - ) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - schedule=torch.profiler.schedule( - wait=6000 + 4000, warmup=1, active=30, repeat=1 - ), - on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir), - record_shapes=True, - profile_memory=False, - with_stack=False, - ) - def reset_mm_cache(self) -> None: if self.mm_budget: self.mm_budget.reset_cache() @@ -2988,38 +2969,6 @@ class GPUModelRunner( ) return afd_metadata - def _build_afd_metadata( - self, ubatch_slices: UBatchSlices | None, num_tokens_unpadded: int - ): - afd_metadata = None - if self.afd_config: - # For prefill, compute tokens per stage based on actual token - # counts - afd_tokens_start_loc = [0] - afd_tokens_lens = [] - if ubatch_slices and len(ubatch_slices) > 1: - 
afd_tokens_start_loc = [ub.token_slice.start for ub in ubatch_slices] - afd_reqs_start_loc = [ub.request_slice.start for ub in ubatch_slices] - logger.info( - f"afd_tokens_start_loc: {afd_tokens_start_loc} " - f"afd_reqs_start_loc: {afd_reqs_start_loc} " - f"ubatch_slices: {ubatch_slices}" - ) - afd_tokens_lens = [ub.num_tokens for ub in ubatch_slices] - else: - afd_tokens_start_loc = [0] - afd_reqs_start_loc = [0] - afd_tokens_lens = [num_tokens_unpadded] - afd_metadata = AFDMetadata( - afd_tokens_start_loc=afd_tokens_start_loc, - afd_reqs_start_loc=afd_reqs_start_loc, - afd_stage_idx=0, - afd_connector=self.afd_connector, - afd_tokens_lens=afd_tokens_lens, - num_of_stages=len(ubatch_slices) if ubatch_slices else 1, - ) - return afd_metadata - @torch.inference_mode() def execute_model( self, @@ -5573,11 +5522,6 @@ class GPUModelRunner( if hasattr(self, "afd_connector") and self.afd_connector: self.afd_connector.init_afd_connector() - def initialize_afd_connector(self) -> None: - """Initialize AFD connector if available.""" - if hasattr(self, "afd_connector") and self.afd_connector: - self.afd_connector.init_afd_connector() - def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. 
From 8276320a8a691adf8747d4d93b12e4c2de208e2d Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Thu, 18 Dec 2025 17:03:15 +0800 Subject: [PATCH 10/19] [Bugfix] compute ffn output param order --- vllm/v1/worker/gpu_ffn_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_ffn_model_runner.py b/vllm/v1/worker/gpu_ffn_model_runner.py index 3f85267d8536a..921824f408552 100644 --- a/vllm/v1/worker/gpu_ffn_model_runner.py +++ b/vllm/v1/worker/gpu_ffn_model_runner.py @@ -351,7 +351,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): hidden_states, dim=0 ) ffn_output = self.model.compute_ffn_output( - current_layer_idx, gathered_hidden_states + gathered_hidden_states, current_layer_idx ) # Extract the output corresponding to current rank @@ -361,7 +361,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): else: # Single TP case rank_ffn_output = self.model.compute_ffn_output( - current_layer_idx, hidden_states + hidden_states, current_layer_idx ) return rank_ffn_output From 6a8d35a9b6c36136c56c470e3bfbfa51a2c44c15 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Thu, 18 Dec 2025 17:32:42 +0800 Subject: [PATCH 11/19] [Chore] remove p2p connector duplicate code --- .../afd_connector/p2p_connector.py | 321 ------------------ 1 file changed, 321 deletions(-) diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py index 3c3359fae96f5..dcb781c55540e 100644 --- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -308,324 +308,3 @@ class P2PAFDConnector(AFDConnectorBase): self._current_afd_connector_metadata.recv_handle_list = work_list self._current_afd_connector_metadata.layer_idx = layer_idx return hidden_states, self._current_afd_connector_metadata -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import re 
-from datetime import timedelta - -import torch -from torch.distributed.distributed_c10d import _get_default_group, _update_default_pg - -from vllm.config import VllmConfig -from vllm.distributed.parallel_state import ( - GroupCoordinator, - TensorMetadata, - init_afd_process_group, - init_model_parallel_group, -) -from vllm.logger import init_logger -from vllm.forward_context import ( - DPMetadata, - get_forward_context, -) - -from .base import AFDConnectorBase -from .metadata import AFDConnectorMetadata - -logger = init_logger(__name__) - - -class DefaultProcessGroupSwitcher: - def __init__(self, default_group, new_default_group): - self.default_group = default_group - self.new_default_group = new_default_group - - def __enter__(self): - _update_default_pg(self.new_default_group) - - def __exit__(self, exc_type, exc_value, traceback): - _update_default_pg(self.default_group) - - -class P2PAFDConnector(AFDConnectorBase): - def __init__( - self, - rank: int, - local_rank: int, - config: "VllmConfig", - ) -> None: - self.rank = rank - self.local_rank = local_rank - self.config = config - self._initialized: bool = False - self._need_recv_metadata: bool = True - self._tensor_metadata_list: dict[int, TensorMetadata] = {} - self._current_afd_connector_metadata: AFDConnectorMetadata | None = None - self.num_hidden_layers: int = ( - self.config.model_config.hf_config.num_hidden_layers - ) - self.recv_attn_output_counter: int = 0 - self.recv_ffn_output_counter: int = 0 - self.dp_metadata_list: dict[int, DPMetadata] = {} - - def close(self) -> None: - """Close the connector and release resources.""" - # TODO: Implement proper resource clean up if needed. 
- pass - - def init_afd_connector(self) -> None: - """Initialize the AFD connector.""" - afd_size = self.config.afd_config.afd_extra_config.get("afd_size") - role = self.config.afd_config.afd_role - attn_size, ffn_size = map(int, re.match(r"(\d+)\D+(\d+)", afd_size).groups()) - world_rank = self.rank if role == "attention" else self.rank + attn_size - afd_pg = init_afd_process_group( - backend="nccl", - init_method=( - f"tcp://{self.config.afd_config.afd_host}" - f":{self.config.afd_config.afd_port}" - ), - world_size=ffn_size + attn_size, - rank=world_rank, - group_name="afd", - timeout=timedelta(minutes=2), - ) - - # Construct rank lists for sub groups. - # Each group contains one attention and one ffn rank. - ffn_ranks = [i for i in range(ffn_size, ffn_size + attn_size)] - attn_ranks = [i for i in range(attn_size)] - assert len(ffn_ranks) == len(attn_ranks), ( - "ffn_ranks and attn_ranks must have the same length" - ) - default_pg_switcher = DefaultProcessGroupSwitcher(_get_default_group(), afd_pg) - with default_pg_switcher: - sub_group_ranks = [] - for i in range(len(ffn_ranks)): - ranks = [attn_ranks[i], ffn_ranks[i]] - sub_group_ranks.append(ranks) - # Create two independent groups: - # a2e_group: for attention -> expert/ffn communication (send_attn, recv_attn) - # e2a_group: for expert/ffn -> attention communication (send_ffn, recv_ffn) - # The communication domain (rank range) is the same, but different group_name - # creates independent groups. - self.a2e_group = init_model_parallel_group( - sub_group_ranks, - self.local_rank, - backend="nccl", - group_name="a2e", - ) - self.e2a_group = init_model_parallel_group( - sub_group_ranks, - self.local_rank, - backend="nccl", - group_name="e2a", - ) - - self._initialized = True - - def is_initialized(self) -> bool: - """Check if the connector is initialized and ready to use. - - Returns: - bool: True if the connector is initialized, False otherwise. 
- """ - return self._initialized - - def _build_tensor_metadata_list( - self, - tensor_metadata: TensorMetadata, - connector_metadata: AFDConnectorMetadata, - ) -> dict[int, TensorMetadata]: - tensor_metadata_list = {} - num_of_stages = connector_metadata.num_of_stages - for idx in range(num_of_stages): - if idx == 0: - tensor_metadata_list[0] = tensor_metadata - else: - new_size = list(tensor_metadata.size) - new_size[0] = connector_metadata.afd_tokens_lens[idx] - tensor_metadata_list[idx] = TensorMetadata( - tensor_metadata.device, - tensor_metadata.dtype, - torch.Size(new_size), - ) - return tensor_metadata_list - - def _send_metadata( - self, - metadata: AFDConnectorMetadata, - hidden_states: torch.Tensor, - dst: int, - process_group: GroupCoordinator, - ) -> None: - if not torch.distributed.is_initialized() or process_group.world_size == 1: - return [] - assert dst < process_group.world_size, f"Invalid dst rank ({dst})" - - tensor_metadata = TensorMetadata( - hidden_states.device.type, hidden_states.dtype, hidden_states.size() - ) - metadata_tuple = (metadata, tensor_metadata) - process_group.send_object(metadata_tuple, dst=dst) - self._tensor_metadata_list = self._build_tensor_metadata_list( - tensor_metadata, metadata - ) - - def _recv_metadata( - self, - src: int, - process_group: GroupCoordinator, - ) -> None: - (self._current_afd_connector_metadata, tensor_metadata) = ( - process_group.recv_object(src=src) - ) - self._tensor_metadata_list = self._build_tensor_metadata_list( - tensor_metadata, self._current_afd_connector_metadata - ) - if self.config.parallel_config.data_parallel_size > 1: - logger.info("jcz recv_metadata num_of_stages:{}".format(self._current_afd_connector_metadata.num_of_stages)) - for stage_idx in range(self._current_afd_connector_metadata.num_of_stages): - num_tokens_per_ubatch = self._tensor_metadata_list[stage_idx].size[0] - self.dp_metadata_list[stage_idx] = DPMetadata.make( - self.config.parallel_config, - num_tokens_per_ubatch, - 
torch.tensor([num_tokens_per_ubatch] * self.config.parallel_config.data_parallel_size, - device="cpu", dtype=torch.int32), - ) - logger.info("jcz recv_metadata self.dp_metadata_list:{}".format(self.dp_metadata_list)) - - def _send_hidden_states( - self, - hidden_states: torch.Tensor, - dst: int, - process_group: GroupCoordinator, - ) -> None: - if not torch.distributed.is_initialized() or process_group.world_size == 1: - return [] - assert dst < process_group.world_size, f"Invalid dst rank ({dst})" - assert not hidden_states.is_cpu, "Hidden states must be on GPU" - torch.distributed.send( - hidden_states, - dst=process_group.ranks[dst], - group=process_group.device_group, - ) - - def _recv_hidden_states( - self, - src: int, - process_group: GroupCoordinator, - tensor_metadata: TensorMetadata, - ) -> tuple[torch.Tensor, list]: - if not torch.distributed.is_initialized() or process_group.world_size == 1: - return {}, [] - assert src < process_group.world_size, f"Invalid src rank ({src})" - - hidden_states = torch.empty( - tensor_metadata.size, - dtype=tensor_metadata.dtype, - device=tensor_metadata.device, - ) - torch.distributed.recv( - hidden_states, - src=process_group.ranks[src], - group=process_group.device_group, - ) - return hidden_states, [] - - # ------------------------------------------------------------------------- - # attn -> ffn - # ------------------------------------------------------------------------- - - def send_attn_output( - self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata - ) -> None: - """ - Called by ATTN side to send intermediate tensors - generated by ATTN instances to FFN. 
- """ - try: - dst = (self.a2e_group.rank_in_group + 1) % self.a2e_group.world_size - if metadata.layer_idx == 0 and metadata.stage_idx == 0: - self._send_metadata(metadata, hidden_states, dst, self.a2e_group) - self._current_afd_connector_metadata = metadata - self._send_hidden_states(hidden_states, dst, self.a2e_group) - except Exception as e: - raise RuntimeError(f"Communication error: {e}") - - def recv_ffn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: - """ - Called by the ATTN side to receive MOE output intermediate tensors, - possibly dispatching from the receiver to other GPUs. - """ - src = (self.e2a_group.rank_in_group - 1) % self.e2a_group.world_size - stage_idx = ( - self.recv_ffn_output_counter - % self._current_afd_connector_metadata.num_of_stages - ) - hidden_states, work_list = self._recv_hidden_states( - src, - self.e2a_group, - self._tensor_metadata_list[stage_idx], - ) - self._current_afd_connector_metadata.recv_handle_list = work_list - self.recv_ffn_output_counter = ( - self.recv_ffn_output_counter + 1 - ) % self._current_afd_connector_metadata.num_of_stages - return hidden_states, self._current_afd_connector_metadata - - # ------------------------------------------------------------------------- - # ffn -> attn - # ------------------------------------------------------------------------- - - def send_ffn_output( - self, - hidden_states: torch.Tensor, - metadata: AFDConnectorMetadata, - ) -> None: - """ - Called by FFN side to send intermediate tensors generated by FFN - instances back to the sender (should be the same GPU as source). 
- """ - dst = (self.e2a_group.rank_in_group + 1) % self.e2a_group.world_size - self._send_hidden_states(hidden_states, dst, self.e2a_group) - self.recv_attn_output_counter += 1 - if ( - self.recv_attn_output_counter - % ( - self._current_afd_connector_metadata.num_of_stages - * self.num_hidden_layers - ) - == 0 - ): - self._need_recv_metadata = True - self.recv_attn_output_counter = 0 - - def recv_attn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: - """ - Called by the FFN side to receive intermediate tensors from ATTN. - Handles receiving and possibly dispatching tensors. - """ - src = (self.a2e_group.rank_in_group - 1) % self.a2e_group.world_size - if self._need_recv_metadata: - self._recv_metadata(src, self.a2e_group) - self._need_recv_metadata = False - - stage_idx = ( - self.recv_attn_output_counter - % self._current_afd_connector_metadata.num_of_stages - ) - layer_idx = ( - self.recv_attn_output_counter - // self._current_afd_connector_metadata.num_of_stages - ) - hidden_states, work_list = self._recv_hidden_states( - src, - self.a2e_group, - self._tensor_metadata_list[stage_idx], - ) - self._current_afd_connector_metadata.recv_handle_list = work_list - self._current_afd_connector_metadata.layer_idx = layer_idx - self._current_afd_connector_metadata.stage_idx = stage_idx - return hidden_states, self._current_afd_connector_metadata From 11d7d5bf594c3ef743fef4b496705d5e39290034 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Fri, 19 Dec 2025 16:02:07 +0800 Subject: [PATCH 12/19] [Chore] some log info --- vllm/model_executor/models/step3_text.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index c012c26f83afa..6355573a68774 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -303,6 +303,10 @@ class Step3TextDecoderLayer(nn.Module): else: hidden_states, residual = 
self.input_layernorm(hidden_states, residual) + # query, key and positions must have the same number of tokens + # /model_executor/layers/rotary_embedding/base.py + # positions.shape=torch.Size([8192]), hidden_states.shape=torch.Size([4096, 3712]) + logger.info(f"{positions.shape=}, {hidden_states.shape=}") hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -408,7 +412,9 @@ class Step3TextModel(nn.Module): if recv_handle is not None: for work in recv_handle: work.wait() + logger.info(f"Step3TextModel {layer.layer_idx=}: {hidden_states.shape=}, {positions.shape=}") current_hidden, residual = layer(positions, hidden_states, residual) + logger.info(f"create attn metadata: {current_hidden.shape=}") metadata = AFDConnectorMetadata.create_attention_metadata( layer_idx=layer.layer_idx, stage_idx=afd_metadata.afd_stage_idx, @@ -580,7 +586,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) - logger.info(f"{params_dict.keys()=}") + # logger.info(f"{params_dict.keys()=}") loaded_params: set[str] = set() expert_params_mapping = [ @@ -592,10 +598,10 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): disable_moe_stacked_params = [data[1] for data in expert_params_mapping] for name, loaded_weight in weights: - logger.info( - f"{self.afd_role=}, {name=}, is_moe: {self.is_moe_weight(name)}, " - f"is_common: {self.is_common_weight(name)}" - ) + # logger.info( + # f"{self.afd_role=}, {name=}, is_moe: {self.is_moe_weight(name)}, " + # f"is_common: {self.is_common_weight(name)}" + # ) if self.afd_role == "attention" and self.is_moe_weight(name): continue @@ -660,7 +666,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): start_idx, end_idx, ) in qkv_params_mapping: - logger.info(f"{weight_name=}, {name=}") + # logger.info(f"{weight_name=}, {name=}") if weight_name not in name: continue name = name.replace(weight_name, param_name) From 
65ea10c8f453e5294370b5f98d08280f49936602 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Fri, 19 Dec 2025 16:03:47 +0800 Subject: [PATCH 13/19] [Chore] bring back deleted code --- .../afd_connector/p2p_connector.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py index dcb781c55540e..58be99bf117ea 100644 --- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -15,6 +15,10 @@ from vllm.distributed.parallel_state import ( init_model_parallel_group, ) from vllm.logger import init_logger +from vllm.forward_context import ( + DPMetadata, + get_forward_context, +) from .base import AFDConnectorBase from .metadata import AFDConnectorMetadata @@ -59,6 +63,7 @@ class P2PAFDConnector(AFDConnectorBase): self.recv_attn_output_counter: int = 0 self.recv_ffn_output_counter: int = 0 + self.dp_metadata_list: dict[int, DPMetadata] = {} def close(self) -> None: """Close the connector and release resources.""" @@ -175,6 +180,19 @@ class P2PAFDConnector(AFDConnectorBase): self._tensor_metadata_list = self._build_tensor_metadata_list( tensor_metadata, self._current_afd_connector_metadata ) + logger.info(f"{self.config.parallel_config.data_parallel_size=}") + if self.config.parallel_config.data_parallel_size > 1: + logger.info("jcz recv_metadata num_of_stages:{}".format(self._current_afd_connector_metadata.num_of_stages)) + for stage_idx in range(self._current_afd_connector_metadata.num_of_stages): + num_tokens_per_ubatch = self._tensor_metadata_list[stage_idx].size[0] + logger.info(f"{stage_idx=}, {num_tokens_per_ubatch=}") + self.dp_metadata_list[stage_idx] = DPMetadata.make( + self.config.parallel_config, + num_tokens_per_ubatch, + torch.tensor([num_tokens_per_ubatch] * self.config.parallel_config.data_parallel_size, + device="cpu", dtype=torch.int32), + ) + 
logger.info("jcz recv_metadata self.dp_metadata_list:{}".format(self.dp_metadata_list)) def _send_hidden_states( self, @@ -307,4 +325,5 @@ class P2PAFDConnector(AFDConnectorBase): ) self._current_afd_connector_metadata.recv_handle_list = work_list self._current_afd_connector_metadata.layer_idx = layer_idx + self._current_afd_connector_metadata.stage_idx = stage_idx return hidden_states, self._current_afd_connector_metadata From bde36017fa2324f26e11f65602125df282899fd5 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Fri, 19 Dec 2025 16:04:47 +0800 Subject: [PATCH 14/19] [Chore] adjust log info --- vllm/v1/worker/gpu_ffn_model_runner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/worker/gpu_ffn_model_runner.py b/vllm/v1/worker/gpu_ffn_model_runner.py index 921824f408552..79ba7e6eb5c2e 100644 --- a/vllm/v1/worker/gpu_ffn_model_runner.py +++ b/vllm/v1/worker/gpu_ffn_model_runner.py @@ -137,10 +137,10 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): else: dp_metadata = None current_layer_idx = recv_metadata.layer_idx - logger.info( - f"layer {current_layer_idx} moe recv hidden states type:{type(hidden_states)}, shape:{hidden_states.shape}" - f" dp_metadata: {dp_metadata}" - ) + # logger.info( + # f"layer {current_layer_idx} moe recv hidden states type:{type(hidden_states)}, shape:{hidden_states.shape}" + # f" dp_metadata: {dp_metadata}" + # ) num_tokens = hidden_states.shape[0] if recv_metadata is not None and recv_metadata.recv_handle_list is not None: for work in recv_metadata.recv_handle_list: From 6d305dda383ee6107125a7605fcefbd7d80e84c7 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Fri, 19 Dec 2025 16:11:47 +0800 Subject: [PATCH 15/19] [Chore] add p2p connector debug log info --- .../distributed/afd_transfer/afd_connector/p2p_connector.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py index 
58be99bf117ea..0605facbfaffb 100644 --- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -139,9 +139,11 @@ class P2PAFDConnector(AFDConnectorBase): for idx in range(num_of_stages): if idx == 0: tensor_metadata_list[0] = tensor_metadata + logger.info(f"build tensor metadata: stage_{idx=}, size={tensor_metadata.size}") else: new_size = list(tensor_metadata.size) new_size[0] = connector_metadata.afd_tokens_lens[idx] + logger.info(f"build tensor metadata: stage_{idx=}, {new_size=}, {connector_metadata.afd_tokens_lens=}") tensor_metadata_list[idx] = TensorMetadata( tensor_metadata.device, tensor_metadata.dtype, @@ -165,6 +167,7 @@ class P2PAFDConnector(AFDConnectorBase): ) metadata_tuple = (metadata, tensor_metadata) process_group.send_object(metadata_tuple, dst=dst) + logger.info(f"_send_metadata called build tensor metadata") self._tensor_metadata_list = self._build_tensor_metadata_list( tensor_metadata, metadata ) @@ -177,6 +180,7 @@ class P2PAFDConnector(AFDConnectorBase): (self._current_afd_connector_metadata, tensor_metadata) = ( process_group.recv_object(src=src) ) + logger.info(f"_recv_metadata called build tensor metadata") self._tensor_metadata_list = self._build_tensor_metadata_list( tensor_metadata, self._current_afd_connector_metadata ) @@ -225,6 +229,7 @@ class P2PAFDConnector(AFDConnectorBase): dtype=tensor_metadata.dtype, device=tensor_metadata.device, ) + # logger.info(f"{__file__}: p2p recv hidden states: {hidden_states.shape=}, {tensor_metadata.size=}") torch.distributed.recv( hidden_states, src=process_group.ranks[src], @@ -262,6 +267,7 @@ class P2PAFDConnector(AFDConnectorBase): self.recv_ffn_output_counter % self._current_afd_connector_metadata.num_of_stages ) + logger.info(f"{stage_idx=}") hidden_states, work_list = self._recv_hidden_states( src, self.e2a_group, From 2a98ab3c8ea831c6a202225644edd671051e43ec Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Mon, 22 
Dec 2025 14:29:03 +0800 Subject: [PATCH 16/19] [Chore]: step3 forward_with_afd --- vllm/model_executor/models/step3_text.py | 91 ++++++++++++++++-------- 1 file changed, 60 insertions(+), 31 deletions(-) diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 6355573a68774..75a38805ef5aa 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -398,43 +398,72 @@ class Step3TextModel(nn.Module): positions: torch.Tensor, afd_metadata: AFDMetadata, ) -> tuple[torch.Tensor, torch.Tensor]: + forward_conext = get_forward_context() recv_handle = None - logger.info(f"{__file__}: forward with afd called, may blocked here") - for layer in islice(self.layers, self.start_layer, self.end_layer): - afd_connector = afd_metadata.afd_connector - afd_metadata.afd_stage_idx = dbo_current_ubatch_id() - if layer.layer_idx > 0: - hidden_states, recv_metadata = afd_connector.recv_ffn_output() - if recv_metadata.recv_handle_list is not None: - recv_handle = recv_metadata.recv_handle_list + ubatch_hidden_states = [] + ubatch_residual = [] - if recv_handle is not None: - for work in recv_handle: - work.wait() - logger.info(f"Step3TextModel {layer.layer_idx=}: {hidden_states.shape=}, {positions.shape=}") - current_hidden, residual = layer(positions, hidden_states, residual) - logger.info(f"create attn metadata: {current_hidden.shape=}") - metadata = AFDConnectorMetadata.create_attention_metadata( - layer_idx=layer.layer_idx, - stage_idx=afd_metadata.afd_stage_idx, - seq_len=current_hidden.shape[0], - dtype=current_hidden.dtype, - device=current_hidden.device, - num_of_stages=afd_metadata.num_of_stages, - afd_tokens_lens=afd_metadata.afd_tokens_lens, + start_idx = 0 + for pos in afd_metadata.positions_list: + num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0] + end_idx = start_idx + num_tokens + ubatch_hidden_states.append(hidden_states[start_idx:end_idx]) + ubatch_residual.append( + 
residual[start_idx:end_idx] if residual is not None else None ) - afd_connector.send_attn_output(current_hidden, metadata) + start_idx = end_idx - if dbo_enabled(): - dbo_yield() + for layer in islice(self.layers, self.start_layer, self.end_layer): + for stage_i in range(forward_conext.afd_metadata.num_of_stages): + afd_connector = afd_metadata.afd_connector + forward_conext.attn_metadata = afd_metadata.attn_metadata_list[stage_i] + forward_conext.dp_metadata = afd_metadata.dp_metadata_list[stage_i] - hidden_states, recv_metadata = afd_connector.recv_ffn_output() - if recv_metadata.recv_handle_list is not None: - recv_handle = recv_metadata.recv_handle_list - if recv_handle is not None: - for work in recv_handle: - work.wait() + residual = ubatch_residual[stage_i] + + if layer.layer_idx > 0: + hidden_states, recv_metadata = afd_connector.recv_ffn_output() + if recv_metadata.recv_handle_list is not None: + recv_handle = recv_metadata.recv_handle_list + else: + hidden_states = ubatch_hidden_states[stage_i] + + if recv_handle is not None: + for work in recv_handle: + work.wait() + + current_positions = afd_metadata.positions_list[stage_i] + hidden_states, residual = layer( + current_positions, hidden_states, residual + ) + + ubatch_hidden_states[stage_i] = hidden_states + ubatch_residual[stage_i] = residual + logger.info(f"create attn metadata:, {afd_metadata.afd_tokens_lens=}") + metadata = AFDConnectorMetadata.create_attention_metadata( + layer_idx=layer.layer_idx, + stage_idx=stage_i, + seq_len=hidden_states.shape[0], + dtype=hidden_states.dtype, + device=hidden_states.device, + num_of_stages=afd_metadata.num_of_stages, + afd_tokens_lens=afd_metadata.afd_tokens_lens, + ) + afd_connector.send_attn_output(hidden_states, metadata) + + # Recv last layer FFN output. 
+ for stage_i in range(afd_metadata.num_of_stages): + ubatch_hidden_states[stage_i], recv_metadata = ( + afd_connector.recv_ffn_output() + ) + + # Re-assemble the batch + hidden_states = torch.cat(ubatch_hidden_states, dim=0) + if any(r is not None for r in ubatch_residual): + residual = torch.cat(ubatch_residual, dim=0) + else: + residual = None return hidden_states, residual From 27ae2e761c17e692a294ae1ce4a37c1092f270d7 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Mon, 22 Dec 2025 15:25:28 +0800 Subject: [PATCH 17/19] [Chore] clean up debug info --- .../afd_connector/p2p_connector.py | 30 +++++++++++-------- vllm/model_executor/models/step3_text.py | 15 ---------- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py index 0605facbfaffb..f85679400d1c6 100644 --- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -60,7 +60,7 @@ class P2PAFDConnector(AFDConnectorBase): self.num_hidden_layers: int = ( self.config.model_config.hf_config.num_hidden_layers ) - + self.recv_attn_output_counter: int = 0 self.recv_ffn_output_counter: int = 0 self.dp_metadata_list: dict[int, DPMetadata] = {} @@ -139,11 +139,9 @@ class P2PAFDConnector(AFDConnectorBase): for idx in range(num_of_stages): if idx == 0: tensor_metadata_list[0] = tensor_metadata - logger.info(f"build tensor metadata: stage_{idx=}, size={tensor_metadata.size}") else: new_size = list(tensor_metadata.size) new_size[0] = connector_metadata.afd_tokens_lens[idx] - logger.info(f"build tensor metadata: stage_{idx=}, {new_size=}, {connector_metadata.afd_tokens_lens=}") tensor_metadata_list[idx] = TensorMetadata( tensor_metadata.device, tensor_metadata.dtype, @@ -167,7 +165,6 @@ class P2PAFDConnector(AFDConnectorBase): ) metadata_tuple = (metadata, tensor_metadata) process_group.send_object(metadata_tuple, 
dst=dst) - logger.info(f"_send_metadata called build tensor metadata") self._tensor_metadata_list = self._build_tensor_metadata_list( tensor_metadata, metadata ) @@ -180,23 +177,32 @@ class P2PAFDConnector(AFDConnectorBase): (self._current_afd_connector_metadata, tensor_metadata) = ( process_group.recv_object(src=src) ) - logger.info(f"_recv_metadata called build tensor metadata") self._tensor_metadata_list = self._build_tensor_metadata_list( tensor_metadata, self._current_afd_connector_metadata ) - logger.info(f"{self.config.parallel_config.data_parallel_size=}") if self.config.parallel_config.data_parallel_size > 1: - logger.info("jcz recv_metadata num_of_stages:{}".format(self._current_afd_connector_metadata.num_of_stages)) + logger.info( + "jcz recv_metadata num_of_stages:{}".format( + self._current_afd_connector_metadata.num_of_stages + ) + ) for stage_idx in range(self._current_afd_connector_metadata.num_of_stages): num_tokens_per_ubatch = self._tensor_metadata_list[stage_idx].size[0] - logger.info(f"{stage_idx=}, {num_tokens_per_ubatch=}") self.dp_metadata_list[stage_idx] = DPMetadata.make( self.config.parallel_config, num_tokens_per_ubatch, - torch.tensor([num_tokens_per_ubatch] * self.config.parallel_config.data_parallel_size, - device="cpu", dtype=torch.int32), + torch.tensor( + [num_tokens_per_ubatch] + * self.config.parallel_config.data_parallel_size, + device="cpu", + dtype=torch.int32, + ), ) - logger.info("jcz recv_metadata self.dp_metadata_list:{}".format(self.dp_metadata_list)) + logger.info( + "jcz recv_metadata self.dp_metadata_list:{}".format( + self.dp_metadata_list + ) + ) def _send_hidden_states( self, @@ -229,7 +235,6 @@ class P2PAFDConnector(AFDConnectorBase): dtype=tensor_metadata.dtype, device=tensor_metadata.device, ) - # logger.info(f"{__file__}: p2p recv hidden states: {hidden_states.shape=}, {tensor_metadata.size=}") torch.distributed.recv( hidden_states, src=process_group.ranks[src], @@ -267,7 +272,6 @@ class 
P2PAFDConnector(AFDConnectorBase): self.recv_ffn_output_counter % self._current_afd_connector_metadata.num_of_stages ) - logger.info(f"{stage_idx=}") hidden_states, work_list = self._recv_hidden_states( src, self.e2a_group, diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 75a38805ef5aa..a09bf0c8bd37f 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -303,10 +303,6 @@ class Step3TextDecoderLayer(nn.Module): else: hidden_states, residual = self.input_layernorm(hidden_states, residual) - # query, key and positions must have the same number of tokens - # /model_executor/layers/rotary_embedding/base.py - # positions.shape=torch.Size([8192]), hidden_states.shape=torch.Size([4096, 3712]) - logger.info(f"{positions.shape=}, {hidden_states.shape=}") hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -342,7 +338,6 @@ class Step3TextDecoderLayer(nn.Module): hidden_states = share_output + moe_output else: hidden_states = self.mlp(hidden_states) - logger.info(f"{type(hidden_states)=}") return hidden_states @@ -353,7 +348,6 @@ class Step3TextModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - logger.info(f"{quant_config=}") afd_config = vllm_config.afd_config self.vocab_size = config.vocab_size self.config = config @@ -440,7 +434,6 @@ class Step3TextModel(nn.Module): ubatch_hidden_states[stage_i] = hidden_states ubatch_residual[stage_i] = residual - logger.info(f"create attn metadata:, {afd_metadata.afd_tokens_lens=}") metadata = AFDConnectorMetadata.create_attention_metadata( layer_idx=layer.layer_idx, stage_idx=stage_i, @@ -515,7 +508,6 @@ class Step3TextModel(nn.Module): hidden_states, layer_idx, ) -> torch.Tensor | IntermediateTensors: - logger.info(f"{type(self.layers)=}, {type(layer_idx)=}") hidden_states = 
self.layers[layer_idx].compute_ffn_output(hidden_states) return hidden_states @@ -582,7 +574,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - logger.info(f"{__file__}: load_weights!") qkv_params_mapping = [ # (param_name, shard_name, relative_start_idx, relative_end_idx) ( @@ -615,7 +606,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) - # logger.info(f"{params_dict.keys()=}") loaded_params: set[str] = set() expert_params_mapping = [ @@ -627,10 +617,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): disable_moe_stacked_params = [data[1] for data in expert_params_mapping] for name, loaded_weight in weights: - # logger.info( - # f"{self.afd_role=}, {name=}, is_moe: {self.is_moe_weight(name)}, " - # f"is_common: {self.is_common_weight(name)}" - # ) if self.afd_role == "attention" and self.is_moe_weight(name): continue @@ -695,7 +681,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): start_idx, end_idx, ) in qkv_params_mapping: - # logger.info(f"{weight_name=}, {name=}") if weight_name not in name: continue name = name.replace(weight_name, param_name) From 60d65cdf5c7a2c56af5a8d9f2f2193b85b4dc520 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Mon, 22 Dec 2025 15:37:07 +0800 Subject: [PATCH 18/19] [Chore] remove unused method --- vllm/model_executor/models/step3_text.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index a09bf0c8bd37f..eb723582a0ccb 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -322,14 +322,6 @@ class Step3TextDecoderLayer(nn.Module): return hidden_states, residual - def compute_attn_output( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: torch.Tensor | None, - ): - pass - 
def compute_ffn_output(self, hidden_states): assert self.afd_role == "ffn" if self.use_moe: From 9f9a583f04e61a069b0e1e110e31972b7eb56e95 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Wed, 24 Dec 2025 16:12:35 +0800 Subject: [PATCH 19/19] [Refactor] p2p connector recv_ffn_output --- .../afd_connector/p2p_connector.py | 14 ++++------ vllm/model_executor/models/deepseek_v2.py | 28 +++---------------- vllm/model_executor/models/step3_text.py | 13 ++------- 3 files changed, 12 insertions(+), 43 deletions(-) diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py index f85679400d1c6..a48c6e0e752d1 100644 --- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -225,7 +225,7 @@ class P2PAFDConnector(AFDConnectorBase): src: int, process_group: GroupCoordinator, tensor_metadata: TensorMetadata, - ) -> tuple[torch.Tensor, list]: + ) -> torch.Tensor: if not torch.distributed.is_initialized() or process_group.world_size == 1: return {}, [] assert src < process_group.world_size, f"Invalid src rank ({src})" @@ -240,7 +240,7 @@ class P2PAFDConnector(AFDConnectorBase): src=process_group.ranks[src], group=process_group.device_group, ) - return hidden_states, [] + return hidden_states # ------------------------------------------------------------------------- # attn -> ffn @@ -262,7 +262,7 @@ class P2PAFDConnector(AFDConnectorBase): except Exception as e: raise RuntimeError(f"Communication error: {e}") - def recv_ffn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: + def recv_ffn_output(self) -> torch.Tensor: """ Called by the ATTN side to receive MOE output intermediate tensors, possibly dispatching from the receiver to other GPUs. 
@@ -272,16 +272,15 @@ class P2PAFDConnector(AFDConnectorBase): self.recv_ffn_output_counter % self._current_afd_connector_metadata.num_of_stages ) - hidden_states, work_list = self._recv_hidden_states( + hidden_states = self._recv_hidden_states( src, self.e2a_group, self._tensor_metadata_list[stage_idx], ) - self._current_afd_connector_metadata.recv_handle_list = work_list self.recv_ffn_output_counter = ( self.recv_ffn_output_counter + 1 ) % self._current_afd_connector_metadata.num_of_stages - return hidden_states, self._current_afd_connector_metadata + return hidden_states # ------------------------------------------------------------------------- # ffn -> attn @@ -328,12 +327,11 @@ class P2PAFDConnector(AFDConnectorBase): self.recv_attn_output_counter // self._current_afd_connector_metadata.num_of_stages ) - hidden_states, work_list = self._recv_hidden_states( + hidden_states = self._recv_hidden_states( src, self.a2e_group, self._tensor_metadata_list[stage_idx], ) - self._current_afd_connector_metadata.recv_handle_list = work_list self._current_afd_connector_metadata.layer_idx = layer_idx self._current_afd_connector_metadata.stage_idx = stage_idx return hidden_states, self._current_afd_connector_metadata diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 6b7842b00f54a..9b9cf7a59126f 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -1347,19 +1347,13 @@ class DeepseekV2Model(nn.Module): afd_metadata: AFDMetadata, llama_4_scaling: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - recv_handle = None for layer in islice(self.layers, self.start_layer, self.end_layer): afd_connector = afd_metadata.afd_connector afd_metadata.afd_stage_idx = dbo_current_ubatch_id() if layer.layer_idx > 0: - hidden_states, recv_metadata = afd_connector.recv_ffn_output() - if recv_metadata.recv_handle_list is not None: - recv_handle = 
recv_metadata.recv_handle_list + hidden_states = afd_connector.recv_ffn_output() - if recv_handle is not None: - for work in recv_handle: - work.wait() current_hidden, residual = layer( positions, hidden_states, residual, llama_4_scaling ) @@ -1377,12 +1371,7 @@ class DeepseekV2Model(nn.Module): if dbo_enabled(): dbo_yield() - hidden_states, recv_metadata = afd_connector.recv_ffn_output() - if recv_metadata.recv_handle_list is not None: - recv_handle = recv_metadata.recv_handle_list - if recv_handle is not None: - for work in recv_handle: - work.wait() + hidden_states = afd_connector.recv_ffn_output() return hidden_states, residual @@ -1395,7 +1384,6 @@ class DeepseekV2Model(nn.Module): llama_4_scaling: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: forward_conext = get_forward_context() - recv_handle = None ubatch_hidden_states = [] ubatch_residual = [] @@ -1421,16 +1409,10 @@ class DeepseekV2Model(nn.Module): residual = ubatch_residual[stage_i] if layer.layer_idx > 0: - hidden_states, recv_metadata = afd_connector.recv_ffn_output() - if recv_metadata.recv_handle_list is not None: - recv_handle = recv_metadata.recv_handle_list + hidden_states = afd_connector.recv_ffn_output() else: hidden_states = ubatch_hidden_states[stage_i] - if recv_handle is not None: - for work in recv_handle: - work.wait() - current_positions = afd_metadata.positions_list[stage_i] hidden_states, residual = layer( current_positions, hidden_states, residual, llama_4_scaling @@ -1452,9 +1434,7 @@ class DeepseekV2Model(nn.Module): # Recv last layer FFN output. 
for stage_i in range(afd_metadata.num_of_stages): - ubatch_hidden_states[stage_i], recv_metadata = ( - afd_connector.recv_ffn_output() - ) + ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output() # Re-assemble the batch hidden_states = torch.cat(ubatch_hidden_states, dim=0) diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index eb723582a0ccb..6e87aeab49b3e 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -385,7 +385,6 @@ class Step3TextModel(nn.Module): afd_metadata: AFDMetadata, ) -> tuple[torch.Tensor, torch.Tensor]: forward_conext = get_forward_context() - recv_handle = None ubatch_hidden_states = [] ubatch_residual = [] @@ -409,16 +408,10 @@ class Step3TextModel(nn.Module): residual = ubatch_residual[stage_i] if layer.layer_idx > 0: - hidden_states, recv_metadata = afd_connector.recv_ffn_output() - if recv_metadata.recv_handle_list is not None: - recv_handle = recv_metadata.recv_handle_list + hidden_states = afd_connector.recv_ffn_output() else: hidden_states = ubatch_hidden_states[stage_i] - if recv_handle is not None: - for work in recv_handle: - work.wait() - current_positions = afd_metadata.positions_list[stage_i] hidden_states, residual = layer( current_positions, hidden_states, residual @@ -439,9 +432,7 @@ class Step3TextModel(nn.Module): # Recv last layer FFN output. for stage_i in range(afd_metadata.num_of_stages): - ubatch_hidden_states[stage_i], recv_metadata = ( - afd_connector.recv_ffn_output() - ) + ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output() # Re-assemble the batch hidden_states = torch.cat(ubatch_hidden_states, dim=0)