diff --git a/examples/online_serving/afd_deepseek_v2/README.md b/examples/online_serving/afd_deepseek_v2/README.md new file mode 100644 index 0000000000000..4f8cf4a4e40db --- /dev/null +++ b/examples/online_serving/afd_deepseek_v2/README.md @@ -0,0 +1,19 @@ +# P2P Connector +P2P connector is used for testing the AFD (Attention-FFN Disaggregation) implementation for deepseek-v2-lite models. It uses torch.distributed to send/recv intermediate tensors between attn and ffn instances. + +When the --enable-dbo flag is enabled, the num_afd_stages parameter becomes ineffective, and the actual number of microbatches is 2. + +Currently, the P2PConnector only supports scenarios where the number of attention (A) dies equals the number of FFN (F) dies. Asymmetric configurations will be supported in future updates. + +1. Attn + +``` +vllm serve "/path/to/DeepSeek-V2-Lite" --data_parallel_size=2 --enable_expert_parallel --enforce_eager --enable-dbo --dbo-prefill-token-threshold 12 --dbo-decode-token-threshold 2 --afd-config '{"afd_connector":"p2pconnector", "afd_role": "attention", "afd_host":"127.0.0.1", "afd_port":"29500","num_afd_stages":"2","afd_extra_config":{"afd_size":"2A2F"}}' + +``` + +2. 
FFN + +``` +vllm serve "/path/to/DeepSeek-V2-Lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --afd-config '{"afd_connector":"p2pconnector", "num_afd_stages":"2", "afd_role": "ffn", "afd_host":"127.0.0.1", "afd_port":"29500", "afd_extra_config":{"afd_size":"2A2F"}}' +``` \ No newline at end of file diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 3c77fad41d077..25e6272cca1d9 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.config.attention import AttentionConfig +from vllm.config.afd import AFDConfig from vllm.config.cache import CacheConfig from vllm.config.compilation import ( CompilationConfig, @@ -66,6 +67,8 @@ __all__ = [ "KVEventsConfig", # From vllm.config.kv_transfer "KVTransferConfig", + # AFD (Attention FFN Disaggregation) configuration + "AFDConfig", # From vllm.config.load "LoadConfig", # From vllm.config.lora diff --git a/vllm/config/afd.py b/vllm/config/afd.py new file mode 100644 index 0000000000000..7277b568f72bd --- /dev/null +++ b/vllm/config/afd.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from dataclasses import field +from typing import Any, Literal + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +@config +@dataclass +class AFDConfig: + """Configuration for AFD (Attention FFN Disaggregation) distributed + computation.""" + + afd_connector: str = "dummy" + """The AFD connector for vLLM to communicate between attention and FFN + nodes. Available connectors: 'dummy', 'p2pconnector'""" + + afd_role: Literal["attention", "ffn"] = "attention" + """Role of this vLLM instance in AFD. 
'attention' for attention workers, + 'ffn' for FFN servers.""" + + afd_port: int = 1239 + """Port number for stepmesh parameter server communication.""" + + afd_host: str = "127.0.0.1" + """Host address for stepmesh parameter server communication.""" + + num_afd_stages: int = 3 + """Number of pipeline stages for stage parallelism.""" + + num_attention_servers: int = 1 + """Number of attention servers.""" + + num_ffn_servers: int = 1 + """Number of FFN servers.""" + + afd_server_rank: int = 0 + """Rank of this AFD server.""" + + afd_extra_config: dict[str, Any] = field(default_factory=dict) + """Extra configuration for specific AFD connectors.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + # AFD configuration affects the computation graph structure + # as it changes how FFN computation is performed + factors: list[Any] = [ + self.afd_connector, + self.afd_role, + self.num_afd_stages, + self.num_attention_servers, + self.num_ffn_servers, + ] + return hashlib.sha256(str(factors).encode()).hexdigest() + + @property + def is_attention_server(self) -> bool: + """Check if this instance is configured as an attention server.""" + return self.afd_role == "attention" + + @property + def is_ffn_server(self) -> bool: + """Check if this instance is configured as an FFN server.""" + return self.afd_role == "ffn" diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 11504fb083558..29d12a8ccff8e 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -553,7 +553,7 @@ class ParallelConfig: if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. 
- + from vllm.v1.executor import ray_utils backend: DistributedExecutorBackend = "mp" diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 0439dc52e7e6f..610473e9a8e6d 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -28,6 +28,7 @@ from vllm.utils import random_uuid from vllm.utils.hashing import safe_hash from .attention import AttentionConfig +from .afd import AFDConfig from .cache import CacheConfig from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode from .device import DeviceConfig @@ -227,6 +228,8 @@ class VllmConfig: """The configurations for event publishing.""" ec_transfer_config: ECTransferConfig | None = None """The configurations for distributed EC cache transfer.""" + afd_config: AFDConfig | None = None + """AFD (Attention FFN Disaggregation) configuration.""" # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. @@ -318,6 +321,10 @@ class VllmConfig: vllm_factors.append(self.ec_transfer_config.compute_hash()) else: vllm_factors.append("None") + if self.afd_config: + vllm_factors.append(self.afd_config.compute_hash()) + else: + vllm_factors.append("None") if self.additional_config: if isinstance(additional_config := self.additional_config, dict): additional_config_hash = safe_hash( @@ -872,16 +879,17 @@ class VllmConfig: if self.parallel_config.use_ubatching: a2a_backend = self.parallel_config.all2all_backend - assert a2a_backend in [ - "deepep_low_latency", - "deepep_high_throughput", - ], ( - "Microbatching currently only supports the deepep_low_latency and " - f"deepep_high_throughput all2all backend. {a2a_backend} is not " - "supported. To fix use --all2all-backend=deepep_low_latency or " - "--all2all-backend=deepep_high_throughput and install the DeepEP" - " kernels." 
- ) + if self.afd_config is None: + assert a2a_backend in [ + "deepep_low_latency", + "deepep_high_throughput", + ], ( + "Microbatching currently only supports the deepep_low_latency and " + f"deepep_high_throughput all2all backend. {a2a_backend} is not " + "supported. To fix use --all2all-backend=deepep_low_latency or " + "--all2all-backend=deepep_high_throughput and install the DeepEP" + " kernels." + ) if not self.model_config.disable_cascade_attn: self.model_config.disable_cascade_attn = True diff --git a/vllm/distributed/afd_transfer/__init__.py b/vllm/distributed/afd_transfer/__init__.py new file mode 100644 index 0000000000000..fc90a145a44fa --- /dev/null +++ b/vllm/distributed/afd_transfer/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""AFD (Attention-FFN Disaggregation) transfer components. + +This module provides the distributed infrastructure for AFD, enabling +disaggregated FFN computation across different machines while keeping +attention computation local. 
+""" + +from .afd_connector import AFDConnectorBase, AFDConnectorFactory, AFDConnectorMetadata + +__all__ = ["AFDConnectorBase", "AFDConnectorMetadata", "AFDConnectorFactory"] diff --git a/vllm/distributed/afd_transfer/afd_connector/__init__.py b/vllm/distributed/afd_transfer/afd_connector/__init__.py new file mode 100644 index 0000000000000..56f6e91160af0 --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""AFD connector implementations for different transport backends.""" + +from .base import AFDConnectorBase +from .factory import AFDConnectorFactory +from .metadata import AFDConnectorMetadata + +__all__ = [ + "AFDConnectorBase", + "AFDConnectorFactory", + "AFDConnectorMetadata", +] diff --git a/vllm/distributed/afd_transfer/afd_connector/base.py b/vllm/distributed/afd_transfer/afd_connector/base.py new file mode 100644 index 0000000000000..97ebaeacaf86e --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/base.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +AFDConnectorBase Class for Distributed AFD FFN computation + +The class provides the four core AFD communication interfaces: +1. send_attn_output(): Send attention output to FFN servers (Attention Worker) +2. recv_ffn_output(): Receive FFN computation result (Attention Worker) +3. recv_attn_output(): Receive attention output from workers (FFN Server) +4. send_ffn_output(): Send FFN computation result back (FFN Server) +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +import torch + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + from .metadata import AFDConnectorMetadata + + +class AFDConnectorBase(ABC): + """ + Abstract base class for AFD connectors. 
+ + This provides the four core interfaces for AFD communication between + attention workers and FFN servers. + """ + + @abstractmethod + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + """Initialize the AFD connector. + + Args: + rank: Global rank of this process + local_rank: Local rank within the node + config: VllmConfig containing AFDConfig + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the connector and release resources.""" + raise NotImplementedError + + @abstractmethod + def init_afd_connector(self) -> None: + """Initialize the AFD connector.""" + raise NotImplementedError + + @property + @abstractmethod + def is_initialized(self) -> bool: + """Check if the connector is initialized and ready to use. + + Returns: + bool: True if the connector is initialized, False otherwise. + """ + raise NotImplementedError + + def get_connector_rank(self) -> int: + """Get the rank of this connector.""" + return getattr(self, "rank", 0) + + def get_connector_local_rank(self) -> int: + """Get the local rank of this connector.""" + return getattr(self, "local_rank", 0) + + @abstractmethod + def send_attn_output( + self, + hidden_states: torch.Tensor, + metadata: "AFDConnectorMetadata", + ) -> Any: + """Send attention output to FFN servers. + + Args: + hidden_states: Attention output tensor + metadata: AFD metadata containing layer_idx, stage_idx, seq_len info + + Returns: + Any: Handle for tracking this request (backend-specific) + """ + raise NotImplementedError + + @abstractmethod + def recv_ffn_output( + self, + handle: Any, + ) -> torch.Tensor: + """Wait for and receive FFN computation result. 
+ + Args: + handle: Handle returned by send_attn_output() + + Returns: + torch.Tensor: FFN computation result + """ + raise NotImplementedError + + @abstractmethod + def recv_attn_output( + self, + timeout_ms: int | None = None, + ) -> tuple[torch.Tensor, "AFDConnectorMetadata"]: + """Receive attention output from attention workers. + + Args: + timeout_ms: Optional timeout in milliseconds + + Returns: + tuple: (hidden_states, metadata) + - hidden_states: Concatenated attention outputs + - metadata: Inferred AFD metadata containing + seq_lens and other info + """ + raise NotImplementedError + + @abstractmethod + def send_ffn_output( + self, + ffn_output: torch.Tensor, + metadata: "AFDConnectorMetadata", + ) -> None: + """Send FFN computation result back to attention workers. + + Args: + ffn_output: Computed FFN result + metadata: AFD metadata containing seq_lens + for splitting and routing info + """ + raise NotImplementedError diff --git a/vllm/distributed/afd_transfer/afd_connector/dummy_connector.py b/vllm/distributed/afd_transfer/afd_connector/dummy_connector.py new file mode 100644 index 0000000000000..c462646f9d0e9 --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/dummy_connector.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Dummy AFD Connector for testing and local development. + +This connector provides a no-op AFDConnectorBase interface, +useful for testing and development scenarios where actual +distributed FFN computation is not needed. +""" + +import time +from collections import deque +from typing import TYPE_CHECKING, Any + +import torch + +from vllm.logger import init_logger + +from .base import AFDConnectorBase +from .metadata import AFDConnectorMetadata + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class DummyAFDConnector(AFDConnectorBase): + """Dummy AFD connector that returns zero tensors. 
+ + This connector is useful for: + 1. Testing AFD infrastructure without actual remote computation + 2. Development scenarios where FFN computation should be disabled + 3. Fallback behavior when remote FFN servers are unavailable + """ + + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + """Initialize the dummy AFD connector. + + Args: + rank: Global rank of this process + local_rank: Local rank within the node + config: VllmConfig containing AFDConfig + """ + self.afd_config = config.afd_config + self.rank = rank + self.local_rank = local_rank + self._is_initialized = False + self.hidden_size = config.model_config.hf_config.hidden_size + self.num_stages = config.afd_config.num_afd_stages + + self.events: deque = deque(maxlen=self.num_stages) + + logger.info("DummyAFDConnector initialized for rank %s", rank) + + self.init_afd_connector() + + def init_afd_connector(self) -> None: + """Initialize the dummy connector. + + This is a no-op for the dummy connector. + """ + if self._is_initialized: + return + + logger.info("Initializing DummyAFDConnector (no-op)") + self._is_initialized = True + + def close(self) -> None: + """Close the dummy connector. + + This is a no-op for the dummy connector. + """ + if not self._is_initialized: + return + + logger.info("Closing DummyAFDConnector (no-op)") + self._is_initialized = False + + @property + def is_initialized(self) -> bool: + """Check if the connector is initialized. + + Returns: + bool: True if initialized, False otherwise + """ + return self._is_initialized + + def send_attn_output( + self, + hidden_states: torch.Tensor, + metadata: AFDConnectorMetadata, + ) -> Any: + """ + Send attention output to FFN servers (dummy implementation). 
+ """ + logger.debug( + "DummyAFDConnector: send_attn_output layer=%s, stage=%s", + metadata.layer_idx, + metadata.stage_idx, + ) + + # Validate metadata consistency + if not metadata.validate_tensor_shape(hidden_states.shape): + raise ValueError( + "Tensor shape %s doesn't match metadata %s", + hidden_states.shape, + metadata, + ) + + if not metadata.is_single_sequence: + raise ValueError("Attention side should have single sequence") + + self.events.append((None, metadata)) + + return None + + def recv_ffn_output( + self, + timeout_ms: float | None = None, + ) -> torch.Tensor: + """Receive FFN computation result (dummy implementation).""" + logger.debug("DummyAFDConnector: recv_ffn_output timeout_ms=%s", timeout_ms) + + _, metadata = self.events.popleft() + seq_len = metadata.seq_lens[0] # Single sequence for attention side + return torch.zeros( + seq_len, + self.hidden_size, + dtype=metadata.dtype, + device=metadata.device, + ) + + def recv_attn_output( + self, + timeout_ms: int | None = None, + ) -> tuple[torch.Tensor, AFDConnectorMetadata]: + """ + Receive attention output from attention workers (dummy implementation). 
+ """ + logger.debug("DummyAFDConnector: recv_attn_output timeout_ms=%s", timeout_ms) + + # Generate dummy data that simulates multiple attention workers + dummy_seq_lens = [ + 2, + 2, + 2, + ] # Variable sequence lengths from different workers + total_tokens = sum(dummy_seq_lens) + + dummy_tensor = torch.zeros( + total_tokens, self.hidden_size, dtype=torch.bfloat16, device="cuda" + ) + + # Create dummy metadata + dummy_metadata = AFDConnectorMetadata.create_ffn_metadata( + layer_idx=0, # Dummy layer + stage_idx=0, # Dummy stage + dtype=torch.bfloat16, + device=torch.device("cuda"), + seq_lens=dummy_seq_lens, + request_id=f"dummy_ffn_batch_{time.time()}", + ) + + # Cache metadata for send_ffn_output + self._current_metadata = dummy_metadata + time.sleep(1) + + return dummy_tensor, dummy_metadata + + def send_ffn_output( + self, + ffn_output: torch.Tensor, + metadata: AFDConnectorMetadata, + ) -> None: + """Send FFN computation result back (dummy implementation).""" + logger.debug( + "DummyAFDConnector: send_ffn_output layer=%s, stage=%s", + metadata.layer_idx, + metadata.stage_idx, + ) + + # Validate that ffn_output shape matches metadata + if not metadata.validate_tensor_shape(ffn_output.shape): + logger.warning( + "FFN output shape %s doesn't match metadata %s", + ffn_output.shape, + metadata, + ) + + # Log the splitting information for debugging + logger.debug( + "DummyAFDConnector: Split FFN output into %s parts with lengths %s", + metadata.num_sequences, + metadata.seq_lens, + ) + + # Simulate splitting (for logging purposes) + if metadata.get_split_indices(): + split_outputs = torch.split(ffn_output, metadata.seq_lens, dim=0) + logger.debug( + "DummyAFDConnector: Split shapes: %s", + [s.shape for s in split_outputs], + ) + + time.sleep(1) + # No-op for dummy connector - just log the operation diff --git a/vllm/distributed/afd_transfer/afd_connector/factory.py b/vllm/distributed/afd_transfer/afd_connector/factory.py new file mode 100644 index 
0000000000000..757c0ed9fc661 --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/factory.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Factory for creating AFD connectors based on configuration.""" + +import importlib +from collections.abc import Callable +from typing import TYPE_CHECKING + +from vllm.logger import init_logger + +from .base import AFDConnectorBase + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class AFDConnectorFactory: + _registry: dict[str, Callable[[], type[AFDConnectorBase]]] = {} + + @classmethod + def register_connector(cls, name: str, module_path: str, class_name: str) -> None: + """Register a connector with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> type[AFDConnectorBase]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_connector( + cls, rank: int, local_rank: int, config: "VllmConfig" + ) -> AFDConnectorBase: + """Create an AFD connector based on the configuration. 
+ + Args: + rank: Global rank of this process + local_rank: Local rank within the node + config: VllmConfig containing AFDConfig + + Returns: + AFDConnectorBase: The created connector instance + + Raises: + ValueError: If the transport backend is not supported + ImportError: If required dependencies are not available + """ + afd_config = config.afd_config + connector_name = afd_config.afd_connector + + if connector_name not in cls._registry: + raise ValueError(f"Unsupported connector type: {connector_name}") + + connector_cls = cls._registry[connector_name]() + assert issubclass(connector_cls, AFDConnectorBase) + return connector_cls(rank, local_rank, config) + + @classmethod + def get_connector_class(cls, connector_name: str) -> type[AFDConnectorBase]: + """Get the connector class for a given connector name. + + Args: + connector_name: The connector name + + Returns: + type[AFDConnectorBase]: The connector class + + Raises: + ValueError: If the connector name is not supported + """ + if connector_name not in cls._registry: + raise ValueError(f"Unsupported connector type: {connector_name}") + + return cls._registry[connector_name]() + + +# Register various connectors here. +# The registration should not be done in each individual file, as we want to +# only load the files corresponding to the current connector. 
+ +AFDConnectorFactory.register_connector( + "dummy", + "vllm.distributed.afd_transfer.afd_connector.dummy_connector", + "DummyAFDConnector", +) + +AFDConnectorFactory.register_connector( + "p2pconnector", + "vllm.distributed.afd_transfer.afd_connector.p2p_connector", + "P2PAFDConnector", +) diff --git a/vllm/distributed/afd_transfer/afd_connector/metadata.py b/vllm/distributed/afd_transfer/afd_connector/metadata.py new file mode 100644 index 0000000000000..af259ebd6691c --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/metadata.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""AFD metadata definitions for communication between attention and +FFN workers.""" + +import time +import typing +from dataclasses import dataclass, field +from typing import Any + +import torch + + +class FFNNeedForwardData: + def __init__( + self, + moe_comm_method: typing.Any, + num_input_tokens: int, + with_prefill: bool, + total_num_scheduled_tokens: int | None, + is_dummy_run: bool = False, + ): + self.moe_comm_method = moe_comm_method + self.num_input_tokens = num_input_tokens + self.with_prefill = with_prefill + self.total_num_scheduled_tokens = total_num_scheduled_tokens + self.is_dummy_run = is_dummy_run + + +@dataclass +class AFDConnectorMetadata: + """Lightweight AFD metadata containing core information needed for + communication.""" + + layer_idx: int + stage_idx: int + seq_lens: list[int] # Length of each sequence, supports variable length and + # multiple sequences + dtype: torch.dtype + device: torch.device + topk_idx: torch.Tensor | None = None # indices token which expert to be sended + topk_weights: torch.Tensor | None = None # the expert weights + moe_expert_num: int | None = None # number of moe experts + shared_expert_num: int | None = None # number of share experts + scale: torch.Tensor | None = None # quant scale + expertTokenNumsOut: torch.Tensor | None = ( + None # The 
number of tokens received by each expert is used as input for the subsequent GMM. + ) + recv_handle_list: list[Any] | None = ( + None # the communication handles (list of Work objects returned by torch.distributed.irecv) + ) + + # Optional fields for debugging and extensibility + request_id: str | None = None + timestamp: float | None = None + """ffn need forward data""" + ffn_need_forward_data: FFNNeedForwardData | None = None + num_of_stages: int = 1 + afd_tokens_lens: list = field(default_factory=list) + + def __post_init__(self): + """Validate data consistency.""" + if not self.seq_lens: + raise ValueError("seq_lens cannot be empty") + if any(length <= 0 for length in self.seq_lens): + raise ValueError("All sequence lengths must be positive") + + @property + def total_tokens(self) -> int: + """Total number of tokens.""" + return sum(self.seq_lens) + + @property + def num_sequences(self) -> int: + """Number of sequences.""" + return len(self.seq_lens) + + @property + def is_single_sequence(self) -> bool: + """Whether this is a single sequence (attention side characteristic).""" + return len(self.seq_lens) == 1 + + @property + def is_multi_sequence(self) -> bool: + """Whether this is multiple sequences (FFN side characteristic).""" + return len(self.seq_lens) > 1 + + @classmethod + def create_attention_metadata( + cls, + layer_idx: int, + stage_idx: int, + seq_len: int, + dtype: torch.dtype, + device: torch.device, + request_id: str | None = None, + ffn_need_forward_data: FFNNeedForwardData | None = None, + num_of_stages: int = 1, + afd_tokens_lens: list[int] = [], + ) -> "AFDConnectorMetadata": + """Create metadata for attention side (single sequence).""" + return cls( + layer_idx=layer_idx, + stage_idx=stage_idx, + seq_lens=[seq_len], + dtype=dtype, + device=device, + request_id=request_id, + ffn_need_forward_data=ffn_need_forward_data, + timestamp=time.time(), + num_of_stages=num_of_stages, + afd_tokens_lens=afd_tokens_lens, + ) + + @classmethod + def 
create_ffn_metadata( + cls, + layer_idx: int, + stage_idx: int, + seq_lens: list[int], + dtype: torch.dtype, + device: torch.device, + request_id: str | None = None, + ) -> "AFDConnectorMetadata": + """Create metadata for FFN side (multiple sequences).""" + return cls( + layer_idx=layer_idx, + stage_idx=stage_idx, + seq_lens=seq_lens.copy(), # Prevent external modification + dtype=dtype, + device=device, + request_id=request_id, + timestamp=time.time(), + ) + + def get_split_indices(self) -> list[int]: + """Get tensor split indices for FFN side output splitting.""" + if len(self.seq_lens) <= 1: + return [] + + indices = [] + cumsum = 0 + for length in self.seq_lens[:-1]: # Exclude the last one + cumsum += length + indices.append(cumsum) + return indices + + def validate_tensor_shape(self, tensor_shape: tuple[int, ...]) -> bool: + """Validate if tensor shape is consistent with metadata.""" + if len(tensor_shape) < 1: + return False + return tensor_shape[0] == self.total_tokens + + def to_dict(self) -> dict: + """Convert to dictionary format for serialization and debugging.""" + return { + "layer_idx": self.layer_idx, + "stage_idx": self.stage_idx, + "seq_lens": self.seq_lens, + "dtype": self.dtype, + "device": self.device, + "total_tokens": self.total_tokens, + "num_sequences": self.num_sequences, + "request_id": self.request_id, + "timestamp": self.timestamp, + } + + def __repr__(self) -> str: + """Friendly string representation.""" + return ( + f"AFDConnectorMetadata(layer={self.layer_idx}, " + f"stage={self.stage_idx}, seq_lens={self.seq_lens}, " + f"total_tokens={self.total_tokens}, dtype={self.dtype}, " + f"device={self.device}, request_id={self.request_id})" + ) diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py new file mode 100644 index 0000000000000..a48c6e0e752d1 --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -0,0 +1,337 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import re +from datetime import timedelta + +import torch +from torch.distributed.distributed_c10d import _get_default_group, _update_default_pg + +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import ( + GroupCoordinator, + TensorMetadata, + init_afd_process_group, + init_model_parallel_group, +) +from vllm.logger import init_logger +from vllm.forward_context import ( + DPMetadata, + get_forward_context, +) + +from .base import AFDConnectorBase +from .metadata import AFDConnectorMetadata + +logger = init_logger(__name__) + + +class DefaultProcessGroupSwitcher: + def __init__(self, default_group, new_default_group): + self.default_group = default_group + self.new_default_group = new_default_group + + def __enter__(self): + _update_default_pg(self.new_default_group) + + def __exit__(self, exc_type, exc_value, traceback): + _update_default_pg(self.default_group) + + +class P2PAFDConnector(AFDConnectorBase): + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ) -> None: + self.rank = rank + self.local_rank = local_rank + self.config = config + self._initialized: bool = False + self._need_recv_metadata: bool = True + self._tensor_metadata_list: dict[int, TensorMetadata] = {} + self._current_afd_connector_metadata: AFDConnectorMetadata | None = None + if getattr(self.config.model_config.hf_config, "text_config", None) is not None: + self.num_hidden_layers: int = ( + self.config.model_config.hf_config.text_config.num_hidden_layers + ) + else: + self.num_hidden_layers: int = ( + self.config.model_config.hf_config.num_hidden_layers + ) + + self.recv_attn_output_counter: int = 0 + self.recv_ffn_output_counter: int = 0 + self.dp_metadata_list: dict[int, DPMetadata] = {} + + def close(self) -> None: + """Close the connector and release resources.""" + # TODO: Implement proper resource clean up if needed. 
+ pass + + def init_afd_connector(self) -> None: + """Initialize the AFD connector.""" + afd_size = self.config.afd_config.afd_extra_config.get("afd_size") + role = self.config.afd_config.afd_role + attn_size, ffn_size = map(int, re.match(r"(\d+)\D+(\d+)", afd_size).groups()) + world_rank = self.rank if role == "attention" else self.rank + attn_size + afd_pg = init_afd_process_group( + backend="nccl", + init_method=( + f"tcp://{self.config.afd_config.afd_host}" + f":{self.config.afd_config.afd_port}" + ), + world_size=ffn_size + attn_size, + rank=world_rank, + group_name="afd", + timeout=timedelta(minutes=2), + ) + + # Construct rank lists for sub groups. + # Each group contains one attention and one ffn rank. + ffn_ranks = [i for i in range(ffn_size, ffn_size + attn_size)] + attn_ranks = [i for i in range(attn_size)] + assert len(ffn_ranks) == len(attn_ranks), ( + "ffn_ranks and attn_ranks must have the same length" + ) + default_pg_switcher = DefaultProcessGroupSwitcher(_get_default_group(), afd_pg) + with default_pg_switcher: + sub_group_ranks = [] + for i in range(len(ffn_ranks)): + ranks = [attn_ranks[i], ffn_ranks[i]] + sub_group_ranks.append(ranks) + # Create two independent groups: + # a2e_group: for attention -> expert/ffn communication (send_attn, recv_attn) + # e2a_group: for expert/ffn -> attention communication (send_ffn, recv_ffn) + # The communication domain (rank range) is the same, but different group_name + # creates independent groups. + self.a2e_group = init_model_parallel_group( + sub_group_ranks, + self.local_rank, + backend="nccl", + group_name="a2e", + ) + self.e2a_group = init_model_parallel_group( + sub_group_ranks, + self.local_rank, + backend="nccl", + group_name="e2a", + ) + + self._initialized = True + + def is_initialized(self) -> bool: + """Check if the connector is initialized and ready to use. + + Returns: + bool: True if the connector is initialized, False otherwise. 
+ """ + return self._initialized + + def _build_tensor_metadata_list( + self, + tensor_metadata: TensorMetadata, + connector_metadata: AFDConnectorMetadata, + ) -> dict[int, TensorMetadata]: + tensor_metadata_list = {} + num_of_stages = connector_metadata.num_of_stages + for idx in range(num_of_stages): + if idx == 0: + tensor_metadata_list[0] = tensor_metadata + else: + new_size = list(tensor_metadata.size) + new_size[0] = connector_metadata.afd_tokens_lens[idx] + tensor_metadata_list[idx] = TensorMetadata( + tensor_metadata.device, + tensor_metadata.dtype, + torch.Size(new_size), + ) + return tensor_metadata_list + + def _send_metadata( + self, + metadata: AFDConnectorMetadata, + hidden_states: torch.Tensor, + dst: int, + process_group: GroupCoordinator, + ) -> None: + if not torch.distributed.is_initialized() or process_group.world_size == 1: + return [] + assert dst < process_group.world_size, f"Invalid dst rank ({dst})" + + tensor_metadata = TensorMetadata( + hidden_states.device.type, hidden_states.dtype, hidden_states.size() + ) + metadata_tuple = (metadata, tensor_metadata) + process_group.send_object(metadata_tuple, dst=dst) + self._tensor_metadata_list = self._build_tensor_metadata_list( + tensor_metadata, metadata + ) + + def _recv_metadata( + self, + src: int, + process_group: GroupCoordinator, + ) -> None: + (self._current_afd_connector_metadata, tensor_metadata) = ( + process_group.recv_object(src=src) + ) + self._tensor_metadata_list = self._build_tensor_metadata_list( + tensor_metadata, self._current_afd_connector_metadata + ) + if self.config.parallel_config.data_parallel_size > 1: + logger.info( + "jcz recv_metadata num_of_stages:{}".format( + self._current_afd_connector_metadata.num_of_stages + ) + ) + for stage_idx in range(self._current_afd_connector_metadata.num_of_stages): + num_tokens_per_ubatch = self._tensor_metadata_list[stage_idx].size[0] + self.dp_metadata_list[stage_idx] = DPMetadata.make( + self.config.parallel_config, + 
num_tokens_per_ubatch, + torch.tensor( + [num_tokens_per_ubatch] + * self.config.parallel_config.data_parallel_size, + device="cpu", + dtype=torch.int32, + ), + ) + logger.info( + "recv_metadata self.dp_metadata_list:{}".format( + self.dp_metadata_list + ) + ) + + def _send_hidden_states( + self, + hidden_states: torch.Tensor, + dst: int, + process_group: GroupCoordinator, + ) -> None: + if not torch.distributed.is_initialized() or process_group.world_size == 1: + return + assert dst < process_group.world_size, f"Invalid dst rank ({dst})" + assert not hidden_states.is_cpu, "Hidden states must be on GPU" + torch.distributed.send( + hidden_states, + dst=process_group.ranks[dst], + group=process_group.device_group, + ) + + def _recv_hidden_states( + self, + src: int, + process_group: GroupCoordinator, + tensor_metadata: TensorMetadata, + ) -> torch.Tensor: + if not torch.distributed.is_initialized() or process_group.world_size == 1: + return {}, []  # FIXME(review): signature promises torch.Tensor; fix or drop this dead branch + assert src < process_group.world_size, f"Invalid src rank ({src})" + + hidden_states = torch.empty( + tensor_metadata.size, + dtype=tensor_metadata.dtype, + device=tensor_metadata.device, + ) + torch.distributed.recv( + hidden_states, + src=process_group.ranks[src], + group=process_group.device_group, + ) + return hidden_states + + # ------------------------------------------------------------------------- + # attn -> ffn + # ------------------------------------------------------------------------- + + def send_attn_output( + self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata + ) -> None: + """ + Called by ATTN side to send intermediate tensors + generated by ATTN instances to FFN.
+ """ + try: + dst = (self.a2e_group.rank_in_group + 1) % self.a2e_group.world_size + if metadata.layer_idx == 0 and metadata.stage_idx == 0: + self._send_metadata(metadata, hidden_states, dst, self.a2e_group) + self._current_afd_connector_metadata = metadata + self._send_hidden_states(hidden_states, dst, self.a2e_group) + except Exception as e: + raise RuntimeError(f"Communication error: {e}") + + def recv_ffn_output(self) -> torch.Tensor: + """ + Called by the ATTN side to receive MOE output intermediate tensors, + possibly dispatching from the receiver to other GPUs. + """ + src = (self.e2a_group.rank_in_group - 1) % self.e2a_group.world_size + stage_idx = ( + self.recv_ffn_output_counter + % self._current_afd_connector_metadata.num_of_stages + ) + hidden_states = self._recv_hidden_states( + src, + self.e2a_group, + self._tensor_metadata_list[stage_idx], + ) + self.recv_ffn_output_counter = ( + self.recv_ffn_output_counter + 1 + ) % self._current_afd_connector_metadata.num_of_stages + return hidden_states + + # ------------------------------------------------------------------------- + # ffn -> attn + # ------------------------------------------------------------------------- + + def send_ffn_output( + self, + hidden_states: torch.Tensor, + metadata: AFDConnectorMetadata, + ) -> None: + """ + Called by FFN side to send intermediate tensors generated by FFN + instances back to the sender (should be the same GPU as source). + """ + dst = (self.e2a_group.rank_in_group + 1) % self.e2a_group.world_size + self._send_hidden_states(hidden_states, dst, self.e2a_group) + self.recv_attn_output_counter += 1 + if ( + self.recv_attn_output_counter + % ( + self._current_afd_connector_metadata.num_of_stages + * self.num_hidden_layers + ) + == 0 + ): + self._need_recv_metadata = True + self.recv_attn_output_counter = 0 + + def recv_attn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: + """ + Called by the FFN side to receive intermediate tensors from ATTN. 
+ Handles receiving and possibly dispatching tensors. + """ + src = (self.a2e_group.rank_in_group - 1) % self.a2e_group.world_size + if self._need_recv_metadata: + self._recv_metadata(src, self.a2e_group) + self._need_recv_metadata = False + + stage_idx = ( + self.recv_attn_output_counter + % self._current_afd_connector_metadata.num_of_stages + ) + layer_idx = ( + self.recv_attn_output_counter + // self._current_afd_connector_metadata.num_of_stages + ) + hidden_states = self._recv_hidden_states( + src, + self.a2e_group, + self._tensor_metadata_list[stage_idx], + ) + self._current_afd_connector_metadata.layer_idx = layer_idx + self._current_afd_connector_metadata.stage_idx = stage_idx + return hidden_states, self._current_afd_connector_metadata diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index f5ada5a009ec3..2661d28ba5dd7 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -41,6 +41,14 @@ import torch.distributed import torch.distributed._functional_collectives as funcol import torch.distributed._symmetric_memory from torch.distributed import Backend, ProcessGroup +from torch.distributed.distributed_c10d import ( + PrefixStore, + Store, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, +) import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( @@ -322,7 +330,6 @@ class GroupCoordinator: self_device_group = None self_cpu_group = None - for ranks in group_ranks: device_group = torch.distributed.new_group( ranks, backend=torch_distributed_backend @@ -1036,6 +1043,63 @@ _INNER_DP_WORLD: GroupCoordinator | None = None _NODE_COUNT: int | None = None +def init_afd_process_group( + backend: str | Backend = None, + init_method: str | None = None, + timeout: timedelta | None = None, + world_size: int = -1, + rank: int = -1, + store: Store | None = None, + group_name: str = None, + pg_options: Any | None = None, +): + assert 
(store is None) or (init_method is None), ( + "Cannot specify both init_method and store." + ) + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + if backend: + backend = Backend(backend) + else: + backend = Backend("undefined") + + if timeout is None: + timeout = default_pg_timeout + + if store is None: + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + store = PrefixStore(group_name, store) + + pg_options_param_name = ( + "backend_options" if tuple(int(p) for p in str(torch.__version__).split("+")[0].split(".")[:2]) >= (2, 6) else "pg_options" + ) + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + ) + global _AFD + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + _AFD = pg + return pg + + +_WORLD: GroupCoordinator | None = None +_NODE_COUNT: int | None = None + + def get_world_group() -> GroupCoordinator: assert _WORLD is not None, "world group is not initialized" return _WORLD diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1442c83a1504a..487ef1af82b64 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -35,6 +35,7 @@ import vllm.envs as envs from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( AttentionConfig, + AFDConfig, CacheConfig, CompilationConfig, ConfigType, @@ -73,7 +74,7 @@ from vllm.config.model import ( RunnerOption, TokenizerMode, ) -from vllm.config.multimodal import MMCacheType, MMEncoderTPMode +from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig from vllm.config.observability import DetailedTraceModules from vllm.config.parallel import DistributedExecutorBackend,
ExpertPlacementStrategy from vllm.config.scheduler import SchedulerPolicy @@ -574,6 +575,8 @@ class EngineArgs: CacheConfig.kv_offloading_backend ) tokens_only: bool = False + # AFD config + afd_config: AFDConfig | None = None def __post_init__(self): # support `EngineArgs(compilation_config={...})` @@ -1153,6 +1156,8 @@ class EngineArgs: "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"] ) vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"]) + vllm_group.add_argument("--afd-config", **vllm_kwargs["afd_config"]) + vllm_group.add_argument( "--optimization-level", **vllm_kwargs["optimization_level"] ) @@ -1732,6 +1737,7 @@ class EngineArgs: profiler_config=self.profiler_config, additional_config=self.additional_config, optimization_level=self.optimization_level, + afd_config=self.afd_config, ) return config diff --git a/vllm/entrypoints/afd_ffn_server.py b/vllm/entrypoints/afd_ffn_server.py new file mode 100644 index 0000000000000..27aa6c7ddd344 --- /dev/null +++ b/vllm/entrypoints/afd_ffn_server.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""vLLM AFD FFN Server Entry Point + +This script provides a standalone entry point for running FFN servers in an AFD +(Attention-FFN Disaggregation) setup. FFN servers handle remote FFN computation +for attention workers. 
+ +Usage: + python -m vllm.entrypoints.afd_ffn_server /path/to/model \ + --tensor-parallel-size 8 \ + --afd-config '{"afd_connector": "dummy", "afd_role": "ffn"}' \ +""" + +import threading +from typing import Any + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.logger import init_logger +from vllm.utils.argparse_utils import FlexibleArgumentParser + +logger = init_logger(__name__) + + +class AFDFFNServer: + """AFD FFN Server main class.""" + + def __init__(self, args: Any): + engine_args = AsyncEngineArgs.from_cli_args(args) + self.vllm_config = engine_args.create_engine_config() + logger.info("Start AFD FFN Server with vllm_config: %s", self.vllm_config) + + def start(self) -> None: + """Start the AFD FFN server.""" + try: + # Import here to avoid circular imports + from vllm.v1.executor.abstract import Executor + + # Create configurations + executor_class = Executor.get_class(self.vllm_config) + self.model_executor = executor_class(vllm_config=self.vllm_config) + # Start the FFN server loop + self._run_server_loop() + + except Exception as e: + logger.error("Failed to start AFD FFN server: %s", e) + raise + + def _run_server_loop(self) -> None: + """Start FFN workers and wait for completion""" + logger.info("AFD FFN Server started, workers running...") + try: + # Tell workers to start FFN server loops (one-time call) + self.model_executor.collective_rpc("start_ffn_server_loop") + + # Main thread waits without busy polling + shutdown_event = threading.Event() + shutdown_event.wait() # Block until interrupted + + except KeyboardInterrupt: + logger.info("Server shutting down...") + self.model_executor.collective_rpc("stop_ffn_server_loop") + except Exception as e: + logger.error("Server error: %s", e) + raise + + +def main(args: Any) -> None: + """Main entry point for AFD FFN server.""" + try: + server = AFDFFNServer(args) + server.start() + except KeyboardInterrupt: + logger.info("Interrupted by user") + except Exception as e: + logger.error("Server 
error: %s", e) + raise + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + # Add model as positional argument (like vllm serve) + parser.add_argument("model", type=str, help="Model name or path") + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + # Set the model from positional argument + args.model = args.model + + main(args) diff --git a/vllm/entrypoints/cli/fserver.py b/vllm/entrypoints/cli/fserver.py new file mode 100644 index 0000000000000..087cc9d280d7f --- /dev/null +++ b/vllm/entrypoints/cli/fserver.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""vLLM AFD FFN Server CLI command.""" + +import argparse + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.entrypoints.afd_ffn_server import main +from vllm.entrypoints.cli.types import CLISubcommand + + +class FServerCommand(CLISubcommand): + """Command for running vLLM AFD FFN Server.""" + + def __init__(self): + self.name = "fserver" + super().__init__() + + def subparser_init(self, subparsers): + """Initialize the fserver subparser.""" + parser = subparsers.add_parser( + self.name, + help="Start vLLM AFD FFN Server", + description="Start vLLM AFD FFN Server for Attention-FFN Disaggregation", + usage="vllm fserver MODEL --afd-config CONFIG [options]", + ) + + # Add model as positional argument (like vllm serve) + parser.add_argument("model", type=str, help="Model name or path") + + # Use AsyncEngineArgs to add all vLLM engine arguments + parser = AsyncEngineArgs.add_cli_args(parser) + + return parser + + def validate(self, args: argparse.Namespace) -> None: + """Validate arguments for fserver command.""" + # Validate that afd_config is provided + if not hasattr(args, "afd_config") or not args.afd_config: + raise ValueError("--afd-config is required for FFN server") + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + """Run the fserver command.""" + # Call 
the main function from afd_ffn_server directly with parsed args + main(args) + + +def cmd_init() -> list[CLISubcommand]: + """Initialize fserver command.""" + return [FServerCommand()] diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index a3e73eb7a4c9d..65bb51b2fd950 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -16,6 +16,7 @@ logger = init_logger(__name__) def main(): import vllm.entrypoints.cli.benchmark.main import vllm.entrypoints.cli.collect_env + import vllm.entrypoints.cli.fserver import vllm.entrypoints.cli.openai import vllm.entrypoints.cli.run_batch import vllm.entrypoints.cli.serve @@ -25,6 +26,7 @@ def main(): CMD_MODULES = [ vllm.entrypoints.cli.openai, vllm.entrypoints.cli.serve, + vllm.entrypoints.cli.fserver, vllm.entrypoints.cli.benchmark.main, vllm.entrypoints.cli.collect_env, vllm.entrypoints.cli.run_batch, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 77c7253aef06e..3337360e83910 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -195,7 +195,6 @@ def run_multi_api_server(args: argparse.Namespace): assert external_dp_lb or hybrid_dp_lb or dp_rank == 0 api_server_manager: APIServerProcessManager | None = None - with launch_core_engines( vllm_config, executor_class, log_stats, num_api_servers ) as (local_engine_manager, coordinator, addresses): diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index bc8855a76e2a2..4be0537d34415 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1325,7 +1325,6 @@ async def run_server_worker( listen_address, sock, args, client_config=None, **uvicorn_kwargs ) -> None: """Run a single API server worker.""" - if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: ToolParserManager.import_tool_parser(args.tool_parser_plugin) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 
7a569ec32eac9..9fc8da8b1787a 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -4,7 +4,7 @@ import time from collections import defaultdict from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, NamedTuple import torch @@ -12,7 +12,9 @@ import torch import vllm.envs as envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig +from vllm.distributed.afd_transfer import AFDConnectorBase from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.ubatch_utils import UBatchSlices @@ -94,6 +96,46 @@ class DPMetadata: # NOTE: local_sizes should only be set by the chunked_sizes context manager local_sizes: list[int] | None = None + @staticmethod + def num_stage_tokens_across_dp( + num_stage_tokens: list[int], dp_size: int, dp_rank: int + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Gather the stage token counts across all DP ranks. 
+ Args: + num_stage_tokens: list of token counts per stage for current rank + dp_size: number of DP ranks + dp_rank: current DP rank + Returns: + stage_tokens_across_dp_cpu: [num_stages, dp_size] tensor + max_stage_tokens_across_dp_cpu: [num_stages] tensor with max + tokens per stage + """ + import torch.distributed as dist + + from vllm.distributed.parallel_state import get_dp_group + from vllm.platforms import current_platform + + device = current_platform.device_type + group = get_dp_group().device_group + + num_stages = len(num_stage_tokens) + stage_tokens_across_dp = torch.zeros( + (num_stages, dp_size), device=device, dtype=torch.int32 + ) + stage_tokens_across_dp[:, dp_rank] = torch.tensor( + num_stage_tokens, device=device, dtype=torch.int32 + ) + + # AllReduce to gather from all ranks + dist.all_reduce(stage_tokens_across_dp, group=group) + stage_tokens_across_dp_cpu = stage_tokens_across_dp.cpu() + + # Compute max tokens per stage + max_stage_tokens_across_dp_cpu = torch.max(stage_tokens_across_dp_cpu, dim=1)[0] + + return stage_tokens_across_dp_cpu, max_stage_tokens_across_dp_cpu + @staticmethod def make( parallel_config: ParallelConfig, @@ -182,6 +224,23 @@ class DPMetadata: return torch.cumsum(num_tokens_across_sp_cpu, dim=0) +@dataclass +class AFDMetadata: + afd_tokens_start_loc: list[int] + afd_reqs_start_loc: list[int] + afd_stage_idx: int + afd_connector: "AFDConnectorBase" + afd_tokens_lens: list[int] # padded lengths for tensor slicing + num_of_stages: int + + input_ids_list: list[torch.Tensor] = field(default_factory=list) + positions_list: list[torch.Tensor] = field(default_factory=list) + inputs_embeds_list: list[torch.Tensor] = field(default_factory=list) + intermediate_tensors_list: list[IntermediateTensors] = field(default_factory=list) + attn_metadata_list: list[AttentionMetadata] = field(default_factory=list) + dp_metadata_list: list[DPMetadata] = field(default_factory=list) + + @dataclass class ForwardContext: # copy from 
vllm_config.compilation_config.static_forward_context @@ -198,6 +257,7 @@ class ForwardContext: virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass dp_metadata: DPMetadata | None = None + afd_metadata: AFDMetadata | None = None # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE. # by default NONE, no cudagraph is used. cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE @@ -235,6 +295,7 @@ def create_forward_context( cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor: BatchDescriptor | None = None, ubatch_slices: UBatchSlices | None = None, + afd_metadata: AFDMetadata | None = None, ): return ForwardContext( no_compile_layers=vllm_config.compilation_config.static_forward_context, @@ -244,6 +305,7 @@ def create_forward_context( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, + afd_metadata=afd_metadata, ) @@ -272,6 +334,7 @@ def set_forward_context( cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor: BatchDescriptor | None = None, ubatch_slices: UBatchSlices | None = None, + afd_metadata: AFDMetadata | None = None, ): """A context manager that stores the current forward context, can be attention metadata, etc. @@ -317,6 +380,7 @@ def set_forward_context( cudagraph_runtime_mode, batch_descriptor, ubatch_slices, + afd_metadata=afd_metadata, ) try: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 402f0bf69ceaa..da6f148702a87 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -594,7 +594,6 @@ class ColumnParallelLinear(LinearBase): # Matrix multiply. assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_, bias) - if self.gather_output and self.tp_size > 1: # All-gather across the partitions. 
output = tensor_model_parallel_all_gather(output_parallel) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index b22cdb6d6c80c..2b386495f0a5f 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -45,6 +45,7 @@ from vllm.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) +from vllm.distributed.afd_transfer.afd_connector.metadata import AFDConnectorMetadata from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul @@ -101,6 +102,9 @@ elif current_platform.is_xpu(): logger = init_logger(__name__) +from vllm.forward_context import AFDMetadata +from vllm.v1.worker.ubatching import dbo_current_ubatch_id, dbo_enabled, dbo_yield + class DeepseekAttention(nn.Module): """Normal MHA implementation used by Deepseek v1.""" @@ -383,7 +387,6 @@ class DeepseekV2MoE(nn.Module): final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states ) - return final_hidden_states.view(num_tokens, hidden_dim) @@ -1123,6 +1126,8 @@ class DeepseekV2DecoderLayer(nn.Module): quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config + afd_config = vllm_config.afd_config + self.afd_role = afd_config.afd_role if afd_config is not None else None self.hidden_size = config.hidden_size max_position_embeddings = getattr(config, "max_position_embeddings", 8192) moe_layer_freq = getattr(config, "moe_layer_freq", 1) @@ -1148,42 +1153,47 @@ class DeepseekV2DecoderLayer(nn.Module): attn_cls = DeepseekV2MLAAttention else: attn_cls = DeepseekV2Attention - self.self_attn = attn_cls( - vllm_config=vllm_config, - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=qk_nope_head_dim, - qk_rope_head_dim=qk_rope_head_dim, - v_head_dim=v_head_dim, - 
q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None, - kv_lora_rank=kv_lora_rank, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - topk_indices_buffer=topk_indices_buffer, - ) - - if ( - config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % moe_layer_freq == 0 - ): - self.mlp = DeepseekV2MoE( + if self.afd_role is None or self.afd_role == "attention": + self.self_attn = attn_cls( + vllm_config=vllm_config, config=config, - parallel_config=parallel_config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") + else None, + kv_lora_rank=kv_lora_rank, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - else: - self.mlp = DeepseekV2MLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", + prefix=f"{prefix}.self_attn", + topk_indices_buffer=topk_indices_buffer, ) + + if self.afd_role is None or self.afd_role == "ffn": + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % moe_layer_freq == 0 + ): + self.mlp = DeepseekV2MoE( + config=config, + parallel_config=parallel_config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( 
config.hidden_size, eps=config.rms_norm_eps @@ -1227,6 +1237,9 @@ class DeepseekV2DecoderLayer(nn.Module): # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + if self.afd_role == "attention": + return hidden_states, residual + hidden_states = self.mlp(hidden_states) if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: @@ -1239,6 +1252,49 @@ class DeepseekV2DecoderLayer(nn.Module): return hidden_states, residual + def compute_attn_output( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> torch.Tensor: # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1.0 / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1.0 / self.routed_scaling_factor + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + + return hidden_states, residual + + def compute_ffn_output(self, hidden_states): + assert self.afd_role == "ffn" + hidden_states = self.mlp(hidden_states) + if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. 
+ # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1.0 / self.routed_scaling_factor + return hidden_states + @support_torch_compile class DeepseekV2Model(nn.Module): @@ -1293,6 +1349,112 @@ class DeepseekV2Model(nn.Module): def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) + def forward_with_afd( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + positions: torch.Tensor, + afd_metadata: AFDMetadata, + llama_4_scaling: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + for layer in islice(self.layers, self.start_layer, self.end_layer): + afd_connector = afd_metadata.afd_connector + afd_metadata.afd_stage_idx = dbo_current_ubatch_id() + + if layer.layer_idx > 0: + hidden_states = afd_connector.recv_ffn_output() + + current_hidden, residual = layer( + positions, hidden_states, residual, llama_4_scaling + ) + metadata = AFDConnectorMetadata.create_attention_metadata( + layer_idx=layer.layer_idx, + stage_idx=afd_metadata.afd_stage_idx, + seq_len=current_hidden.shape[0], + dtype=current_hidden.dtype, + device=current_hidden.device, + num_of_stages=afd_metadata.num_of_stages, + afd_tokens_lens=afd_metadata.afd_tokens_lens, + ) + afd_connector.send_attn_output(current_hidden, metadata) + + if dbo_enabled(): + dbo_yield() + + hidden_states = afd_connector.recv_ffn_output() + + return hidden_states, residual + + def forward_with_afd_v2( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + positions: torch.Tensor, + afd_metadata: AFDMetadata, + llama_4_scaling: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + forward_conext = get_forward_context() + + ubatch_hidden_states = [] + ubatch_residual = [] + + start_idx = 0 + for pos in afd_metadata.positions_list: + # DeepSeekV2 uses MROPE with shape (3, num_tokens), so use shape[1] if ndim==2 + # Otherwise use shape[0] as requested + num_tokens = 
pos.shape[1] if pos.ndim == 2 else pos.shape[0] + end_idx = start_idx + num_tokens + ubatch_hidden_states.append(hidden_states[start_idx:end_idx]) + ubatch_residual.append( + residual[start_idx:end_idx] if residual is not None else None + ) + start_idx = end_idx + + for layer in islice(self.layers, self.start_layer, self.end_layer): + for stage_i in range(forward_conext.afd_metadata.num_of_stages): + afd_connector = afd_metadata.afd_connector + forward_conext.attn_metadata = afd_metadata.attn_metadata_list[stage_i] + forward_conext.dp_metadata = afd_metadata.dp_metadata_list[stage_i] + + residual = ubatch_residual[stage_i] + + if layer.layer_idx > 0: + hidden_states = afd_connector.recv_ffn_output() + else: + hidden_states = ubatch_hidden_states[stage_i] + + current_positions = afd_metadata.positions_list[stage_i] + hidden_states, residual = layer( + current_positions, hidden_states, residual, llama_4_scaling + ) + + ubatch_hidden_states[stage_i] = hidden_states + ubatch_residual[stage_i] = residual + + metadata = AFDConnectorMetadata.create_attention_metadata( + layer_idx=layer.layer_idx, + stage_idx=stage_i, + seq_len=hidden_states.shape[0], + dtype=hidden_states.dtype, + device=hidden_states.device, + num_of_stages=afd_metadata.num_of_stages, + afd_tokens_lens=afd_metadata.afd_tokens_lens, + ) + afd_connector.send_attn_output(hidden_states, metadata) + + # Recv last layer FFN output. 
+ for stage_i in range(afd_metadata.num_of_stages): + ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output() + + # Re-assemble the batch + hidden_states = torch.cat(ubatch_hidden_states, dim=0) + if any(r is not None for r in ubatch_residual): + residual = torch.cat(ubatch_residual, dim=0) + else: + residual = None + + return hidden_states, residual + def forward( self, input_ids: torch.Tensor, @@ -1325,10 +1487,18 @@ class DeepseekV2Model(nn.Module): else: llama_4_scaling = None - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer( - positions, hidden_states, residual, llama_4_scaling + forward_ctx = get_forward_context() + afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None + + if afd_metadata is not None: + hidden_states, residual = self.forward_with_afd_v2( + hidden_states, residual, positions, afd_metadata, llama_4_scaling ) + else: + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions, hidden_states, residual, llama_4_scaling + ) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -1338,6 +1508,12 @@ hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def compute_ffn_output( + self, hidden_states, layer_idx + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.layers[layer_idx].compute_ffn_output(hidden_states) + return hidden_states + class DeepseekV2MixtureOfExperts(MixtureOfExperts): moe_mlp_layers: list[DeepseekV2MoE] @@ -1404,6 +1580,10 @@ class DeepseekV2ForCausalLM( if self.use_mha: self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"] + self.afd_config = vllm_config.afd_config + self.afd_role = ( + self.afd_config.afd_role if self.afd_config is not None else None + ) # `packed_modules_mapping` needs to be modified before # initializing DeepseekV2Model, as it is passed inplace to # quantization config init
and may be used to select the @@ -1452,12 +1632,16 @@ class DeepseekV2ForCausalLM( continue assert isinstance(layer, DeepseekV2DecoderLayer) - if isinstance(layer.mlp, DeepseekV2MoE): + if (self.afd_role is None or self.afd_role == "ffn") and isinstance( + layer.mlp, DeepseekV2MoE + ): # Pick last one layer since the first ones may be dense layers. example_moe = layer.mlp self.moe_mlp_layers.append(layer.mlp) self.moe_layers.append(layer.mlp.experts) + if self.afd_role == "attention": + return self.extract_moe_parameters(example_moe) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -1475,6 +1659,12 @@ class DeepseekV2ForCausalLM( ) return hidden_states + def compute_ffn_output( + self, hidden_states, current_layer_idx + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx) + return hidden_states + def compute_logits( self, hidden_states: torch.Tensor, @@ -1518,6 +1708,13 @@ class DeepseekV2ForCausalLM( # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) + if self.afd_role == "attention": + vllm_config = get_current_vllm_config() + num_redundant_experts = ( + vllm_config.parallel_config.eplb_config.num_redundant_experts + ) + else: + num_redundant_experts = self.num_redundant_experts expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", @@ -1528,7 +1725,7 @@ class DeepseekV2ForCausalLM( if rocm_aiter_moe_shared_expert_enabled else 0 ), - num_redundant_experts=self.num_redundant_experts, + num_redundant_experts=num_redundant_experts, ) params_dict = dict(self.named_parameters()) @@ -1537,6 +1734,9 @@ class DeepseekV2ForCausalLM( if "rotary_emb.inv_freq" in name: continue + if self.afd_role == "attention" and self.is_moe_weight(name): + continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) if spec_layer is not None: 
continue # skip spec decode layers for main model @@ -1640,7 +1840,8 @@ class DeepseekV2ForCausalLM( # Anyway, this is an expert weight and should not be # attempted to load as other weights later is_expert_weight = True - + if self.afd_role is not None and self.afd_role == "attention": + continue # Do not modify `name` since the loop may continue here # Instead, create a new variable name_mapped = chunk_name.replace(weight_name, param_name) @@ -1670,6 +1871,12 @@ class DeepseekV2ForCausalLM( loaded_params.add(name_mapped) break else: + if ( + self.afd_role == "ffn" + and not self.is_moe_weight(name) + and not self.is_common_weight(name) + ): + continue if is_expert_weight: # We've checked that this is an expert weight # However it's not mapped locally to this rank @@ -1698,6 +1905,29 @@ class DeepseekV2ForCausalLM( return loaded_params + def is_moe_weight(self, name): + if ( + "shared_experts" in name + or "experts" in name + or "gate" in name + or "up" in name + or "down" in name + ): + return True + return False + + def is_common_weight(self, name): + if ( + "lm_head" in name + or "model.norm.weight" in name + or "embed_tokens" in name + or "input_layernorm" in name + or "post_attention_layernorm" in name + ): + # or "model.layers.0.self_attn.o_proj.weight" in name:# for init kv cache + return True + return False + class DeepseekForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 7077f1a22e8d7..6e87aeab49b3e 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -11,12 +11,16 @@ from torch import nn from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.config import AFDConfig, CacheConfig, ModelConfig, VllmConfig from vllm.distributed import ( get_pp_group, get_tensor_model_parallel_world_size, 
tensor_model_parallel_all_reduce, ) +from vllm.distributed.afd_transfer.afd_connector.metadata import ( + AFDConnectorMetadata, +) +from vllm.forward_context import AFDMetadata, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -37,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.step3_vl import Step3TextConfig +from vllm.v1.worker.ubatching import dbo_current_ubatch_id, dbo_enabled, dbo_yield from .interfaces import SupportsPP from .utils import ( @@ -228,54 +233,59 @@ class Step3TextDecoderLayer(nn.Module): config: Step3TextConfig, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, + afd_config: AFDConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size + self.afd_role = afd_config.afd_role if afd_config is not None else None - self.self_attn = Step3TextAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - norm_eps=config.rms_norm_eps, - max_position_embedding=config.max_position_embedding, - head_dim=config.head_dim, - share_q_dim=config.share_q_dim, - rope_parameters=config.rope_parameters, - prefix=f"{prefix}.self_attn", - ) - - layer_idx = int(prefix.split("layers.")[1].split(".")[0]) - moe_layers_enum = getattr(config, "moe_layers_enum", None) - if moe_layers_enum is not None: - moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")] - else: - # Default to 1dense. 
- moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] - - if layer_idx in moe_layers_idx: - self.moe = FusedMoEBlock( - config=config, quant_config=quant_config, prefix=f"{prefix}.moe" - ) - self.share_expert = Step3TextMLP( + if self.afd_role is None or self.afd_role == "attention": + self.self_attn = Step3TextAttention( hidden_size=self.hidden_size, - intermediate_size=config.share_expert_dim, - hidden_act="silu", + num_heads=config.num_attention_heads, + num_kv_heads=1, + cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.share_expert", + norm_eps=config.rms_norm_eps, + max_position_embedding=config.max_position_embedding, + head_dim=config.head_dim, + share_q_dim=config.share_q_dim, + rope_parameters=config.rope_parameters, + prefix=f"{prefix}.self_attn", ) - self.use_moe = True - else: - self.mlp = Step3TextMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act="silu", - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.use_moe = False + + self.layer_idx = int(prefix.split("layers.")[1].split(".")[0]) + + if self.afd_role is None or self.afd_role == "ffn": + moe_layers_enum = getattr(config, "moe_layers_enum", None) + if moe_layers_enum is not None: + moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")] + else: + # Default to 1dense. 
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] + + if self.layer_idx in moe_layers_idx: + self.moe = FusedMoEBlock( + config=config, quant_config=quant_config, prefix=f"{prefix}.moe" + ) + self.share_expert = Step3TextMLP( + hidden_size=self.hidden_size, + intermediate_size=config.share_expert_dim, + hidden_act="silu", + quant_config=quant_config, + prefix=f"{prefix}.share_expert", + ) + self.use_moe = True + else: + self.mlp = Step3TextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act="silu", + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.use_moe = False self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps @@ -300,6 +310,9 @@ class Step3TextDecoderLayer(nn.Module): hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + if self.afd_role == "attention": + return hidden_states, residual + if self.use_moe: share_output = self.share_expert(hidden_states) moe_output = self.moe(hidden_states) @@ -309,6 +322,16 @@ class Step3TextDecoderLayer(nn.Module): return hidden_states, residual + def compute_ffn_output(self, hidden_states): + assert self.afd_role == "ffn" + if self.use_moe: + share_output = self.share_expert(hidden_states) + moe_output = self.moe(hidden_states) + hidden_states = share_output + moe_output + else: + hidden_states = self.mlp(hidden_states) + return hidden_states + @support_torch_compile class Step3TextModel(nn.Module): @@ -317,6 +340,7 @@ class Step3TextModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + afd_config = vllm_config.afd_config self.vocab_size = config.vocab_size self.config = config @@ -336,6 +360,7 @@ class Step3TextModel(nn.Module): config=config, cache_config=cache_config, quant_config=quant_config, + afd_config=afd_config, 
prefix=prefix, ), prefix=f"{prefix}.layers", @@ -352,6 +377,72 @@ class Step3TextModel(nn.Module): def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) + def forward_with_afd( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + positions: torch.Tensor, + afd_metadata: AFDMetadata, + ) -> tuple[torch.Tensor, torch.Tensor]: + forward_conext = get_forward_context() + + ubatch_hidden_states = [] + ubatch_residual = [] + + start_idx = 0 + for pos in afd_metadata.positions_list: + num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0] + end_idx = start_idx + num_tokens + ubatch_hidden_states.append(hidden_states[start_idx:end_idx]) + ubatch_residual.append( + residual[start_idx:end_idx] if residual is not None else None + ) + start_idx = end_idx + + for layer in islice(self.layers, self.start_layer, self.end_layer): + for stage_i in range(forward_conext.afd_metadata.num_of_stages): + afd_connector = afd_metadata.afd_connector + forward_conext.attn_metadata = afd_metadata.attn_metadata_list[stage_i] + forward_conext.dp_metadata = afd_metadata.dp_metadata_list[stage_i] + + residual = ubatch_residual[stage_i] + + if layer.layer_idx > 0: + hidden_states = afd_connector.recv_ffn_output() + else: + hidden_states = ubatch_hidden_states[stage_i] + + current_positions = afd_metadata.positions_list[stage_i] + hidden_states, residual = layer( + current_positions, hidden_states, residual + ) + + ubatch_hidden_states[stage_i] = hidden_states + ubatch_residual[stage_i] = residual + metadata = AFDConnectorMetadata.create_attention_metadata( + layer_idx=layer.layer_idx, + stage_idx=stage_i, + seq_len=hidden_states.shape[0], + dtype=hidden_states.dtype, + device=hidden_states.device, + num_of_stages=afd_metadata.num_of_stages, + afd_tokens_lens=afd_metadata.afd_tokens_lens, + ) + afd_connector.send_attn_output(hidden_states, metadata) + + # Recv last layer FFN output. 
+ for stage_i in range(afd_metadata.num_of_stages): + ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output() + + # Re-assemble the batch + hidden_states = torch.cat(ubatch_hidden_states, dim=0) + if any(r is not None for r in ubatch_residual): + residual = torch.cat(ubatch_residual, dim=0) + else: + residual = None + + return hidden_states, residual + def forward( self, input_ids: torch.Tensor, @@ -370,8 +461,19 @@ class Step3TextModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual) + forward_ctx = get_forward_context() + afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None + + if afd_metadata is not None: + hidden_states, residual = self.forward_with_afd( + hidden_states, + residual, + positions, + afd_metadata, + ) + else: + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -384,6 +486,14 @@ class Step3TextModel(nn.Module): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def compute_ffn_output( + self, + hidden_states, + layer_idx, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.layers[layer_idx].compute_ffn_output(hidden_states) + return hidden_states + class Step3TextForCausalLM(nn.Module, SupportsPP): def __init__( @@ -398,6 +508,11 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): self.config = config self.vllm_config = vllm_config + self.afd_config = vllm_config.afd_config + self.afd_role = ( + self.afd_config.afd_role if self.afd_config is not None else None + ) + self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix) if get_pp_group().is_last_rank: @@ -429,6 +544,14 @@ class Step3TextForCausalLM(nn.Module, 
SupportsPP): ) return hidden_states + def compute_ffn_output( + self, + hidden_states, + current_layer_idx, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx) + return hidden_states + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits @@ -477,9 +600,13 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): disable_moe_stacked_params = [data[1] for data in expert_params_mapping] for name, loaded_weight in weights: + if self.afd_role == "attention" and self.is_moe_weight(name): + continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue + if any( disable_moe_stacked_param in name for disable_moe_stacked_param in disable_moe_stacked_params @@ -498,6 +625,10 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): param_name, weight_name, shard_id = mapping if weight_name not in name: continue + + if self.afd_role is not None and self.afd_role == "attention": + continue + name = name.replace(weight_name, param_name) # Skip layers on other devices. 
if is_pp_missing_parameter(name, self): @@ -521,6 +652,12 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): loaded_params.add(name) break else: + if ( + self.afd_role == "ffn" + and not self.is_moe_weight(name) + and not self.is_common_weight(name) + ): + continue for ( param_name, weight_name, @@ -552,3 +689,25 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params + + def is_moe_weight(self, name): + if ( + "shared_expert" in name + or "experts" in name + or "gate" in name + or "up" in name + or "down" in name + ): + return True + return False + + def is_common_weight(self, name): + if ( + "lm_head" in name + or "model.norm.weight" in name + or "embed" in name + or "input_layernorm" in name + or "post_attention_layernorm" in name + ): + return True + return False diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 3c965721b9dae..4154c0f839a19 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -1126,6 +1126,16 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) return hidden_states + def compute_ffn_output( + self, + hidden_states, + current_layer_idx, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.language_model.compute_ffn_output( + hidden_states, current_layer_idx + ) + return hidden_states + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 5f8883c164b3e..0c5434682375c 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -105,6 +105,10 @@ class EngineCore: if executor_fail_callback is not None: self.model_executor.register_failure_callback(executor_fail_callback) + self.afd_config = vllm_config.afd_config + if self.afd_config and self.afd_config.afd_role == "ffn": + return + self.available_gpu_memory_for_kv_cache = -1 # Setup KV Caches and update 
CacheConfig after profiling. @@ -614,6 +618,7 @@ class EngineCoreProc(EngineCore): executor_fail_callback = lambda: self.input_queue.put_nowait( (EngineCoreRequestType.EXECUTOR_FAILED, b"") ) + self.afd_config = vllm_config.afd_config self.engine_index = engine_index identity = self.engine_index.to_bytes(length=2, byteorder="little") @@ -868,7 +873,6 @@ class EngineCoreProc(EngineCore): set_process_title("EngineCore") decorate_logs() engine_core = EngineCoreProc(*args, **kwargs) - engine_core.run_busy_loop() except SystemExit: @@ -891,6 +895,23 @@ class EngineCoreProc(EngineCore): def run_busy_loop(self): """Core busy loop of the EngineCore.""" + if self.afd_config and self.afd_config.afd_role == "ffn": + logger.info("AFD FFN Server started, workers running...") + try: + # Tell workers to start FFN server loops (one-time call) + self.model_executor.collective_rpc("start_ffn_server_loop") + + # Main thread waits without busy polling + shutdown_event = threading.Event() + shutdown_event.wait() # Block until interrupted + + except KeyboardInterrupt: + logger.info("Server shutting down...") + self.model_executor.collective_rpc("stop_ffn_server_loop") + except Exception as e: + logger.error("Server error: %s", e) + raise + # Loop until process is sent a SIGINT or SIGTERM while True: # 1) Poll the input queue until there is work to do. @@ -1205,6 +1226,7 @@ class DPEngineCoreProc(EngineCoreProc): # Initialize the engine. 
dp_rank = vllm_config.parallel_config.data_parallel_rank + self.afd_config = vllm_config.afd_config super().__init__( vllm_config, local_client, @@ -1287,6 +1309,22 @@ class DPEngineCoreProc(EngineCoreProc): def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" + if self.afd_config and self.afd_config.afd_role == "ffn": + logger.info("AFD FFN Server started, workers running...") + try: + # Tell workers to start FFN server loops (one-time call) + self.model_executor.collective_rpc("start_ffn_server_loop") + + # Main thread waits without busy polling + shutdown_event = threading.Event() + shutdown_event.wait() # Block until interrupted + + except KeyboardInterrupt: + logger.info("Server shutting down...") + self.model_executor.collective_rpc("stop_ffn_server_loop") + except Exception as e: + logger.error("Server error: %s", e) + raise # Loop until process is sent a SIGINT or SIGTERM while True: diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 24bf66c42f312..ec0805102d2ac 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -16,7 +16,7 @@ import msgspec import zmq from vllm import envs -from vllm.config import CacheConfig, ParallelConfig, VllmConfig +from vllm.config import AFDConfig, CacheConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy @@ -908,6 +908,7 @@ def launch_core_engines( vllm_config.cache_config, local_engine_manager, coordinator.proc if coordinator else None, + vllm_config.afd_config, ) @@ -919,6 +920,7 @@ def wait_for_engine_startup( cache_config: CacheConfig, proc_manager: CoreEngineProcManager | None, coord_process: Process | None, + afd_config: AFDConfig | None = None, ): # Wait for engine core process(es) to send ready messages. 
local_count = parallel_config.data_parallel_size_local @@ -1020,6 +1022,13 @@ def wait_for_engine_startup( conn_pending[0 if local else 1] -= 1 start_pending[0 if local else 1] += 1 engine.state = CoreEngineState.CONNECTED + elif ( + status == "READY" + and engine.state == CoreEngineState.CONNECTED + and afd_config + and afd_config.afd_role == "ffn" + ): + engine.state = CoreEngineState.READY elif status == "READY" and engine.state == CoreEngineState.CONNECTED: # Setup KV cache config with initialization state from # engine core process. Sum values from all engines in DP case. diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 649875fe8b7c1..c112980a95a3e 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -213,6 +213,9 @@ class MultiprocExecutor(Executor): self.output_rank = self._get_output_rank() + self.afd_config = self.vllm_config.afd_config + self.afd_role = self.afd_config.afd_role if self.afd_config else None + def start_worker_monitor(self, inline=False) -> None: workers = self.workers self_ref = weakref.ref(self) @@ -565,6 +568,9 @@ class WorkerProc: # environment variable overrides after this point) enable_envs_cache() + self.afd_config = vllm_config.afd_config + self.afd_role = self.afd_config.afd_role if self.afd_config else None + @staticmethod def make_worker_process( vllm_config: VllmConfig, diff --git a/vllm/v1/worker/gpu_ffn_model_runner.py b/vllm/v1/worker/gpu_ffn_model_runner.py new file mode 100644 index 0000000000000..79ba7e6eb5c2e --- /dev/null +++ b/vllm/v1/worker/gpu_ffn_model_runner.py @@ -0,0 +1,450 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time +from typing import TYPE_CHECKING, Any + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.distributed.afd_transfer.afd_connector.factory import AFDConnectorFactory +from 
vllm.distributed.afd_transfer.afd_connector.metadata import AFDConnectorMetadata +from vllm.distributed.communication_op import tensor_model_parallel_all_gather +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_world_group, + graph_capture, +) +from vllm.forward_context import set_forward_context, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model_loader +from vllm.utils.mem_utils import DeviceMemoryProfiler, GiB_bytes +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin + +if TYPE_CHECKING: + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec + +logger = init_logger(__name__) + + +class GPUFFNModelRunner(LoRAModelRunnerMixin): + def __init__(self, vllm_config: VllmConfig, device: torch.device): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + self.dtype = self.model_config.dtype + self.load_config = vllm_config.load_config + + self.afd_config = vllm_config.afd_config + if not self.afd_config or not self.afd_config.is_ffn_server: + raise ValueError( + "AFD config must be provided with afd_role='ffn' for FFN server" + ) + + self._counter = 0 + + # Initialize torch.profile for performance monitoring + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=6000 * 27 + 4000 * 27 * 2, warmup=1, active=30, repeat=1 + ), + on_trace_ready=torch.profiler.tensorboard_trace_handler( + "./profiler_logs/ffn" + ), + record_shapes=True, + profile_memory=False, + with_stack=False, + ) + + # Initialize CUDA graph support + self.use_cuda_graph = not self.model_config.enforce_eager + + # self.cudagraph_batch_sizes sorts in ascending order. 
+ # The batch sizes in the config are in descending order. + self.cudagraph_batch_sizes = list( + reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes) + ) + + # Storage for captured graphs + self._cuda_graphs: dict[ + tuple[int, int], torch.cuda.CUDAGraph + ] = {} # {(layer_idx, num_tokens): CUDAGraph} + self._graph_memory_pool = None + + assert self.afd_config.is_ffn_server + self.connector = AFDConnectorFactory.create_connector( + get_world_group().rank, get_world_group().local_rank, self.vllm_config + ) + + if getattr(self.model_config.hf_config, "text_config", None) is not None: + self.num_layers = self.model_config.hf_config.text_config.num_hidden_layers + else: + self.num_layers = self.model_config.hf_config.num_hidden_layers + + def get_model(self) -> nn.Module: + return self.model + + def initialize_afd_connector(self) -> None: + self.connector.init_afd_connector() + + def load_model(self, **kwargs) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: # noqa: SIM117 + time_before_load = time.perf_counter() + model_loader = get_model_loader(self.load_config) + + if not hasattr(self, "model"): + logger.info("Loading model from scratch...") + self.model = model_loader.load_model( + vllm_config=self.vllm_config, model_config=self.model_config + ) + else: + logger.info("Model was already initialized. 
Loading weights inplace...") + model_loader.load_weights(self.model, model_config=self.model_config) + time_after_load = time.perf_counter() + self.model_memory_usage = m.consumed_memory + logger.info( + "Model loading took %.4f GiB and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load, + ) + + logger.info("AFD FFN Model loaded successfully") + + def _get_current_layer_idx(self) -> int: + return (self._counter // self.afd_config.num_afd_stages) % self.num_layers + + @torch.inference_mode() + def execute_model(self, scheduler_output=None, intermediate_tensors=None): + """Execute FFN computation for a single request""" + # scheduler_output and intermediate_tensors are unused in FFN server + # mode + self.profiler.step() + + try: + hidden_states, recv_metadata = self.connector.recv_attn_output() + if hasattr(self.connector, "dp_metadata_list"): + dp_metadata = self.connector.dp_metadata_list.get( + recv_metadata.stage_idx, None + ) + else: + dp_metadata = None + current_layer_idx = recv_metadata.layer_idx + # logger.info( + # f"layer {current_layer_idx} moe recv hidden states type:{type(hidden_states)}, shape:{hidden_states.shape}" + # f" dp_metadata: {dp_metadata}" + # ) + num_tokens = hidden_states.shape[0] + if recv_metadata is not None and recv_metadata.recv_handle_list is not None: + for work in recv_metadata.recv_handle_list: + work.wait() + # Try to use CUDA graph if available + cuda_graph_info = self._find_cuda_graph(current_layer_idx, num_tokens) + if cuda_graph_info is not None: + # Use captured CUDA graph for computation + with set_forward_context( + attn_metadata=None, vllm_config=self.vllm_config + ): + get_forward_context().dp_metadata = dp_metadata + rank_ffn_output = self._execute_with_cuda_graph( + hidden_states, cuda_graph_info + ) + else: + # Fallback to eager mode + with set_forward_context( + attn_metadata=None, vllm_config=self.vllm_config + ): + get_forward_context().dp_metadata = dp_metadata + 
rank_ffn_output = self._execute_eager_mode( + hidden_states, current_layer_idx + ) + + recv_metadata.recv_handle_list = None + self.connector.send_ffn_output(rank_ffn_output, recv_metadata) + except Exception as e: + raise ValueError(f"Error computing FFN: {e}") from e + finally: + self._counter += 1 + if self._counter == self.num_layers * self.afd_config.num_afd_stages: + self._counter = 0 + return None # FFN server doesn't return ModelRunnerOutput + + def _execute_with_cuda_graph( + self, hidden_states: torch.Tensor, cuda_graph_info: dict + ): + """Execute FFN computation using captured CUDA graph.""" + graph = cuda_graph_info["graph"] + input_tensor = cuda_graph_info["input_hidden_states"] + output_tensor = cuda_graph_info["output"] + + # Copy input data to graph's input tensor + # Handle padding if necessary + actual_tokens = hidden_states.shape[0] + graph_tokens = input_tensor.shape[0] + + if actual_tokens <= graph_tokens: + # Copy actual data and pad with zeros if needed + input_tensor[:actual_tokens].copy_(hidden_states) + if actual_tokens < graph_tokens: + input_tensor[actual_tokens:].zero_() + else: + raise ValueError( + f"Input size {actual_tokens} exceeds graph capacity {graph_tokens}" + ) + + # Replay the captured graph + graph.replay() + + # Return only the actual output (without padding) + return output_tensor[:actual_tokens].clone() + + def _execute_eager_mode( + self, + hidden_states: torch.Tensor, + current_layer_idx: int, + recv_metadata: AFDConnectorMetadata = None, + ): + """Execute FFN computation in eager mode (fallback).""" + # Step the profiler for performance monitoring + + # Handle TP case: all-gather tensors from all TP ranks + tp_world_size = get_tensor_model_parallel_world_size() + if tp_world_size > 1: + # All-gather hidden states from all TP ranks + gathered_hidden_states = tensor_model_parallel_all_gather( + hidden_states, dim=0 + ) + ffn_output = self.model.compute_ffn_output( + gathered_hidden_states, current_layer_idx + ) + # 
Extract the output corresponding to current rank + start_idx = hidden_states.shape[0] * get_tensor_model_parallel_rank() + end_idx = start_idx + hidden_states.shape[0] + rank_ffn_output = ffn_output[start_idx:end_idx, :] + else: + # Single TP case + rank_ffn_output = self.model.compute_ffn_output( + hidden_states, current_layer_idx + ) + + return rank_ffn_output + + # Methods required for interface compatibility with GPUModelRunner + def profile_run(self) -> None: + """FFN servers don't need profiling.""" + pass + + def get_kv_cache_spec(self) -> dict[str, "KVCacheSpec"]: + """FFN servers don't use KV cache.""" + return {} + + def initialize_kv_cache(self, kv_cache_config: "KVCacheConfig") -> None: + """FFN servers don't use KV cache.""" + pass + + def _dummy_run(self, num_tokens: int = 1, **kwargs) -> torch.Tensor: + """FFN servers don't need dummy runs.""" + # Return a dummy tensor for interface compatibility + return torch.zeros( + num_tokens, + self.model_config.hf_config.hidden_size, + dtype=self.dtype, + device=self.device, + ) + + def capture_model(self) -> int: + """Capture CUDA graphs for FFN operations.""" + if not self.use_cuda_graph: + logger.warning("Skipping CUDA graph capture.") + return 0 + + logger.info("Starting CUDA graph capture for FFN operations...") + start_time = time.perf_counter() + start_free_gpu_memory = torch.cuda.mem_get_info()[0] + + # Create memory pool for graphs + if self._graph_memory_pool is None: + self._graph_memory_pool = torch.cuda.graph_pool_handle() + + # Capture graphs for each layer and different batch sizes + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. 
+ with graph_capture(device=self.device): + for layer_idx in range(self.num_layers): + for num_tokens in reversed(self.cudagraph_batch_sizes): + with set_forward_context( + attn_metadata=None, vllm_config=self.vllm_config + ): + self._capture_graph_for_layer_and_size(layer_idx, num_tokens) + + end_time = time.perf_counter() + end_free_gpu_memory = torch.cuda.mem_get_info()[0] + elapsed_time = end_time - start_time + cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory + + logger.info( + "FFN CUDA graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, + cuda_graph_size / (1 << 30), + ) + return cuda_graph_size + + def _capture_graph_for_layer_and_size(self, layer_idx: int, num_tokens: int): + """Capture CUDA graph for specific layer and number of tokens.""" + # Create dummy hidden states + dummy_hidden_states = torch.randn( + num_tokens, + self.model_config.hf_config.hidden_size, + dtype=self.dtype, + device=self.device, + ) + + # Warm up the operations for this specific layer + for _ in range(self.vllm_config.compilation_config.cudagraph_num_of_warmups): + self._run_ffn_computation( + dummy_hidden_states, layer_idx=layer_idx, capture_mode=True + ) + + # Create and capture the graph + graph = torch.cuda.CUDAGraph() + + # Start graph capture + with torch.cuda.graph(graph, pool=self._graph_memory_pool): + output = self._run_ffn_computation( + dummy_hidden_states, layer_idx=layer_idx, capture_mode=True + ) + + # Store the captured graph with layer and token count as key + self._cuda_graphs[(layer_idx, num_tokens)] = { + "graph": graph, + "input_hidden_states": dummy_hidden_states, + "output": output, + } + + logger.debug( + "Captured CUDA graph for layer %s with %s tokens", layer_idx, num_tokens + ) + + def _run_ffn_computation( + self, + hidden_states: torch.Tensor, + layer_idx: int | None = None, + capture_mode: bool = False, + ): + """Run FFN computation for graph capture or replay.""" + if layer_idx is None: + current_layer_idx = 
self._get_current_layer_idx() if not capture_mode else 0 + else: + current_layer_idx = layer_idx + + tp_world_size = get_tensor_model_parallel_world_size() + if tp_world_size > 1: + # Handle TP case: all-gather tensors from all TP ranks + gathered_hidden_states = tensor_model_parallel_all_gather( + hidden_states, dim=0 + ) + ffn_output = self.model.compute_ffn_output( + gathered_hidden_states, current_layer_idx + ) + + # Extract the output corresponding to current rank + start_idx = hidden_states.shape[0] * get_tensor_model_parallel_rank() + end_idx = start_idx + hidden_states.shape[0] + rank_ffn_output = ffn_output[start_idx:end_idx, :] + else: + # Single TP case + rank_ffn_output = self.model.compute_ffn_output( + hidden_states, current_layer_idx + ) + + return rank_ffn_output + + def _find_cuda_graph(self, layer_idx: int, num_tokens: int): + """Find the smallest graph that can handle the given layer and + number of tokens.""" + if not self.use_cuda_graph: + return None + + # Find the minimum capture size that can handle num_tokens for this + # layer + for capture_size in self.cudagraph_batch_sizes: + if num_tokens <= capture_size: + return self._cuda_graphs.get((layer_idx, capture_size)) + return None + + def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None: + """FFN servers don't use samplers.""" + pass + + def update_config(self, overrides: dict[str, Any]) -> None: + """Update configuration for FFN model runner.""" + allowed_config_names = {"load_config", "model_config"} + for config_name, config_overrides in overrides.items(): + assert config_name in allowed_config_names, ( + f"Config `{config_name}` not supported. 
" + f"Allowed configs: {allowed_config_names}" + ) + config = getattr(self, config_name) + from vllm.config import update_config + + new_config = update_config(config, config_overrides) + setattr(self, config_name, new_config) + + def reload_weights(self) -> None: + """Reload model weights for FFN model runner.""" + assert getattr(self, "model", None) is not None, ( + "Cannot reload weights before model is loaded." + ) + model_loader = get_model_loader(self.load_config) + logger.info("Reloading weights inplace...") + model = self.get_model() + model_loader.load_weights(model, model_config=self.model_config) + + @property + def lora_config(self): + """FFN servers don't support LoRA.""" + return None + + @property + def is_pooling_model(self) -> bool: + """FFN servers are not pooling models.""" + return False + + def _dummy_pooler_run(self, hidden_states: torch.Tensor): + """FFN servers don't have poolers.""" + pass + + def get_supported_tasks(self): + """Get supported tasks for FFN model runner.""" + return [] + + def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: + """Get number of input tokens for FFN model runner.""" + return num_scheduled_tokens + + def take_draft_token_ids(self, **kwargs): + """FFN servers don't support draft tokens.""" + pass + + @property + def eplb_state(self): + """FFN servers don't have EPLB state.""" + return None + + def ensure_kv_transfer_shutdown(self): + """FFN servers don't need KV transfer shutdown.""" + pass + + def save_tensorized_model( + self, + tensorizer_config: "TensorizerConfig", + ) -> None: + """FFN servers don't support tensorized model saving.""" + raise NotImplementedError("FFN servers don't support tensorized model saving") diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 414ae33c6251f..6a9b46b478297 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -37,6 +37,7 @@ from vllm.config import ( get_layers_from_vllm_config, 
update_config, ) +from vllm.distributed.afd_transfer import AFDConnectorFactory from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group @@ -45,11 +46,13 @@ from vllm.distributed.parallel_state import ( get_dcp_group, get_pp_group, get_tp_group, + get_world_group, graph_capture, is_global_first_rank, prepare_communication_buffer_for_model, ) from vllm.forward_context import ( + AFDMetadata, BatchDescriptor, set_forward_context, ) @@ -547,6 +550,16 @@ class GPUModelRunner( # means this layer will perform attention using the keys and values # from the KV cache of `shared_kv_cache_layers[layer_name]`. self.shared_kv_cache_layers: dict[str, str] = {} + + # init AFD config + self.afd_config = vllm_config.afd_config + if self.afd_config and self.afd_config.afd_role == "attention": + self.afd_connector = AFDConnectorFactory.create_connector( + get_world_group().rank, get_world_group().local_rank, vllm_config + ) + self.afd_connector.init_afd_connector() + self.num_stages = self.afd_config.num_afd_stages + self.kv_sharing_fast_prefill_eligible_layers: set[str] = set() self.kv_sharing_fast_prefill_logits_indices = None @@ -606,6 +619,25 @@ class GPUModelRunner( self.kv_connector_output: KVConnectorOutput | None = None self.layerwise_nvtx_hooks_registered = False + profile_dir = ( + "./profiler_logs/attn" + if self.afd_config is not None and self.afd_config.afd_role == "attention" + else "./profiler_logs/normal" + ) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=6000 + 4000, warmup=1, active=30, repeat=1 + ), + on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir), + record_shapes=True, + profile_memory=False, + with_stack=False, + ) + def reset_mm_cache(self) -> None: 
if self.mm_budget: self.mm_budget.reset_cache() @@ -1303,7 +1335,7 @@ class GPUModelRunner( assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - + logger.info(f"nums_reqs: {num_reqs}") # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. self.input_batch.block_table.commit_block_table(num_reqs) @@ -2877,6 +2909,7 @@ class GPUModelRunner( # decoder. allow_dp_padding = ( self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + or self.afd_config is not None ) should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = ( @@ -2958,6 +2991,38 @@ class GPUModelRunner( pyt_hooks.register_hooks(self.model, self.model.__class__.__name__) self.layerwise_nvtx_hooks_registered = True + def _build_afd_metadata( + self, ubatch_slices: UBatchSlices | None, num_tokens_unpadded: int + ): + afd_metadata = None + if self.afd_config: + # For prefill, compute tokens per stage based on actual token + # counts + afd_tokens_start_loc = [0] + afd_tokens_lens = [] + if ubatch_slices and len(ubatch_slices) > 1: + afd_tokens_start_loc = [ub.token_slice.start for ub in ubatch_slices] + afd_reqs_start_loc = [ub.request_slice.start for ub in ubatch_slices] + logger.info( + f"afd_tokens_start_loc: {afd_tokens_start_loc} " + f"afd_reqs_start_loc: {afd_reqs_start_loc} " + f"ubatch_slices: {ubatch_slices}" + ) + afd_tokens_lens = [ub.num_tokens for ub in ubatch_slices] + else: + afd_tokens_start_loc = [0] + afd_reqs_start_loc = [0] + afd_tokens_lens = [num_tokens_unpadded] + afd_metadata = AFDMetadata( + afd_tokens_start_loc=afd_tokens_start_loc, + afd_reqs_start_loc=afd_reqs_start_loc, + afd_stage_idx=0, + afd_connector=self.afd_connector, + afd_tokens_lens=afd_tokens_lens, + num_of_stages=len(ubatch_slices) if ubatch_slices else 1, + ) + return afd_metadata + @torch.inference_mode() def execute_model( self, @@ -3130,6 +3195,11 @@ class GPUModelRunner( # Mark KV scales as 
calculated after the first forward pass self.calculate_kv_scales = False + afd_metadata = self._build_afd_metadata( + ubatch_slices_padded, num_tokens_unpadded + ) + + self.profiler.step() # Run the model. # Use persistent buffers for CUDA graphs. with ( @@ -3141,10 +3211,15 @@ class GPUModelRunner( cudagraph_runtime_mode=cudagraph_mode, batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, + afd_metadata=afd_metadata, ), record_function_or_nullcontext("gpu_model_runner: forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, ): + if input_ids is not None: + logger.info(f"input_ids: {input_ids.shape}") + if inputs_embeds is not None: + logger.info(f"inputs_embeds: {inputs_embeds.shape}") model_output = self._model_forward( input_ids=input_ids, positions=positions, @@ -4269,6 +4344,10 @@ class GPUModelRunner( if num_tokens_across_dp is not None: num_tokens_across_dp[:] = num_tokens_padded + afd_metadata = self._build_afd_metadata( + ubatch_slices_padded, num_tokens_unpadded + ) + with ( self.maybe_randomize_inputs(input_ids, inputs_embeds), set_forward_context( @@ -4279,6 +4358,7 @@ class GPUModelRunner( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, + afd_metadata=afd_metadata, ), ): outputs = self.model( @@ -4814,7 +4894,6 @@ class GPUModelRunner( kv_cache_spec, kv_cache_group_id, ) - attn_groups.append(attn_group) return attn_groups @@ -5491,6 +5570,11 @@ class GPUModelRunner( kv_transfer_group.register_kv_caches(kv_caches) kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks) + def initialize_afd_connector(self) -> None: + """Initialize AFD connector if available.""" + if hasattr(self, "afd_connector") and self.afd_connector: + self.afd_connector.init_afd_connector() + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. 
diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index af09129e67b1e..7a44b6cbd42b9 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -14,6 +14,7 @@ from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed import get_ep_group from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id from vllm.forward_context import ( + AFDMetadata, DPMetadata, create_forward_context, get_forward_context, @@ -126,7 +127,10 @@ class UBatchWrapper: comm_sms: int = envs.VLLM_DBO_COMM_SMS set_comm_sms = lambda sms: None - if vllm_config.parallel_config.enable_expert_parallel: + if ( + vllm_config.parallel_config.enable_expert_parallel + and not vllm_config.afd_config + ): # Currently only DeepEP highthroughput supports SM control so this # only affects that case. ep_group = get_ep_group() @@ -303,6 +307,7 @@ class UBatchWrapper: dp_metadata, batch_descriptor, cudagraph_runtime_mode, + afd_metadata, ) -> list[UbatchMetadata]: # Create one forward context per ubatch forward_contexts = [] @@ -314,6 +319,7 @@ class UBatchWrapper: dp_metadata=dp_metadata[i], batch_descriptor=batch_descriptor, cudagraph_runtime_mode=cudagraph_runtime_mode, + afd_metadata=afd_metadata, ) ) @@ -353,6 +359,62 @@ class UBatchWrapper: return ubatch_metadata + def _make_afd_ubatch_metadata( + self, + ubatch_slices, + attn_metadata, + input_ids, + positions, + inputs_embeds, + intermediate_tensors, + dp_metadata, + afd_metadata, + ) -> AFDMetadata: + if ubatch_slices is None: + afd_metadata.input_ids_list.append(input_ids) + afd_metadata.positions_list.append(positions) + afd_metadata.inputs_embeds_list.append(inputs_embeds) + afd_metadata.intermediate_tensors_list.append(intermediate_tensors) + afd_metadata.attn_metadata_list.append(attn_metadata) + afd_metadata.dp_metadata_list.append(dp_metadata) + else: + for i, ubatch_slice in enumerate(ubatch_slices): + ( + sliced_input_ids, + 
sliced_positions, + sliced_inputs_embeds, + sliced_intermediate_tensors, + ) = self._slice_model_inputs( + ubatch_slice.token_slice, + input_ids, + positions, + inputs_embeds, + intermediate_tensors, + ) + + dp_size = self.vllm_config.parallel_config.data_parallel_size + ubatch_num_tokens_across_dp = torch.tensor( + [ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32 + ) + ubatch_dp_metadata = DPMetadata.make( + self.vllm_config.parallel_config, + ubatch_slice.num_tokens, + ubatch_num_tokens_across_dp, + ) + + afd_metadata.input_ids_list.append(sliced_input_ids) + afd_metadata.positions_list.append(sliced_positions) + afd_metadata.inputs_embeds_list.append(sliced_inputs_embeds) + afd_metadata.intermediate_tensors_list.append( + sliced_intermediate_tensors + ) + afd_metadata.attn_metadata_list.append( + attn_metadata[i] if attn_metadata is not None else None + ) + afd_metadata.dp_metadata_list.append(ubatch_dp_metadata) + + return afd_metadata + def _slice_model_inputs( self, tokens_slice: slice, @@ -385,6 +447,34 @@ class UBatchWrapper: batch_descriptor = forward_context.batch_descriptor ubatch_slices = forward_context.ubatch_slices cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode + afd_metadata = forward_context.afd_metadata + + attn_metadata = forward_context.attn_metadata + input_ids = kwargs["input_ids"] + positions = kwargs["positions"] + intermediate_tensors = kwargs["intermediate_tensors"] + inputs_embeds = kwargs["inputs_embeds"] + compute_stream = torch.cuda.current_stream() + + dp_metadata = forward_context.dp_metadata + + if self.vllm_config.afd_config: + afd_metadata = self._make_afd_ubatch_metadata( + ubatch_slices=ubatch_slices, + attn_metadata=attn_metadata, + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + dp_metadata=dp_metadata, + afd_metadata=afd_metadata, + ) + forward_context.afd_metadata = afd_metadata + if cudagraph_runtime_mode is 
CUDAGraphMode.NONE: + return self.runnable(*args, **kwargs) + else: + assert self.cudagraph_wrapper is not None + return self.cudagraph_wrapper(*args, **kwargs) # If there's no ubatching, just run the runnable object if ubatch_slices is None: @@ -394,6 +484,7 @@ class UBatchWrapper: # num_tokens, we don't have a non-ubatched one. Without this # check, the cudagraph wrapper will try to capture a cudagraph # for this shape during a normal run. + if cudagraph_runtime_mode is CUDAGraphMode.FULL: assert batch_descriptor is not None if batch_descriptor.num_tokens in self.cudagraphs: @@ -405,18 +496,9 @@ class UBatchWrapper: assert self.cudagraph_wrapper is not None return self.cudagraph_wrapper(*args, **kwargs) - attn_metadata = forward_context.attn_metadata num_tokens = ( ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start ) * 2 - input_ids = kwargs["input_ids"] - positions = kwargs["positions"] - intermediate_tensors = kwargs["intermediate_tensors"] - inputs_embeds = kwargs["inputs_embeds"] - compute_stream = torch.cuda.current_stream() - - dp_metadata = forward_context.dp_metadata - # We shouldn't be here unless we are running with multiple DP ranks assert dp_metadata is not None ubatch_dp_metadata = [] @@ -448,6 +530,7 @@ class UBatchWrapper: dp_metadata=ubatch_dp_metadata, batch_descriptor=batch_descriptor, cudagraph_runtime_mode=CUDAGraphMode.NONE, + afd_metadata=afd_metadata, ) with self.sm_control: return self._capture_ubatches(ubatch_metadata, self.model) @@ -470,6 +553,7 @@ class UBatchWrapper: dp_metadata=ubatch_dp_metadata, batch_descriptor=batch_descriptor, cudagraph_runtime_mode=CUDAGraphMode.NONE, + afd_metadata=afd_metadata, ) with self.sm_control: return self._run_ubatches(ubatch_metadata, self.model) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 68fe0853370f7..1c35633f5fa2f 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -6,7 +6,7 @@ import gc import os from contextlib 
import AbstractContextManager, nullcontext from types import NoneType -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Optional, cast import numpy as np import torch @@ -52,6 +52,8 @@ from vllm.v1.outputs import ( ModelRunnerOutput, ) from vllm.v1.utils import report_usage_stats +from vllm.v1.worker.gpu_ffn_model_runner import GPUFFNModelRunner +from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.workspace import init_workspace_manager @@ -248,8 +250,11 @@ class Worker(WorkerBase): num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1 init_workspace_manager(self.device, num_ubatches) + self.model_runner: GPUModelRunner | GPUFFNModelRunner | GPUModelRunnerV2 + if self.vllm_config.afd_config and self.vllm_config.afd_config.is_ffn_server: + self.model_runner = GPUFFNModelRunner(self.vllm_config, self.device) # Construct the model runner - if self.use_v2_model_runner: + elif self.use_v2_model_runner: from vllm.v1.worker.gpu.model_runner import ( GPUModelRunner as GPUModelRunnerV2, ) @@ -574,8 +579,16 @@ class Worker(WorkerBase): @torch.inference_mode() def execute_model( - self, scheduler_output: "SchedulerOutput" - ) -> ModelRunnerOutput | None: + self, scheduler_output: Optional["SchedulerOutput"] = None + ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None: + # FFN server mode: direct execution without pipeline parallelism + if self.vllm_config.afd_config and self.vllm_config.afd_config.is_ffn_server: + return self.model_runner.execute_model(scheduler_output) + + if scheduler_output is None: + raise ValueError("scheduler_output is required in normal inference mode") + + # Normal inference mode intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -682,6 +695,53 @@ class 
Worker(WorkerBase): # worker will always be healthy as long as it's running. return + def start_ffn_server_loop(self) -> None: + """Start FFN server loop for AFD FFN workers""" + if not ( + self.vllm_config.afd_config and self.vllm_config.afd_config.is_ffn_server + ): + return + + self.model_runner.capture_model() + self.model_runner.initialize_afd_connector() + + if self.profiler: + self.profiler.start() + for _ in range(1000): # FIXME: hardcoded profiler iterations + self.model_runner.execute_model(scheduler_output=None) + torch.cuda.synchronize() # Ensure GPU operations complete + self.profiler.stop() + print(self.profiler.key_averages().table(sort_by="self_cuda_time_total")) + + import threading + + self._ffn_shutdown_event = threading.Event() + + def ffn_worker_loop(): + # Set CUDA device for this thread (thread-local context) + torch.cuda.set_device(self.device) + logger.info("FFN worker loop started") + + try: + while not self._ffn_shutdown_event.is_set(): + # Execute FFN computation + self.model_runner.execute_model(scheduler_output=None) + except Exception as e: + logger.error("FFN worker loop error: %s", e) + raise + + self._ffn_thread = threading.Thread(target=ffn_worker_loop, daemon=True) + self._ffn_thread.start() + logger.info("FFN server loop started in worker") + + def stop_ffn_server_loop(self) -> None: + """Stop FFN server loop""" + if hasattr(self, "_ffn_shutdown_event"): + self._ffn_shutdown_event.set() + if hasattr(self, "_ffn_thread"): + self._ffn_thread.join(timeout=5) + logger.info("FFN server loop stopped") + def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None: from vllm.distributed.parallel_state import get_ep_group