Merge e7254d8994a4caf49e4cd08b604657b7ee8ae418 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
jiangkuaixue123 2025-12-25 00:06:54 +00:00 committed by GitHub
commit 92d38e41c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 2614 additions and 118 deletions

View File

@ -0,0 +1,19 @@
# P2P Connector
The P2P connector is used for testing the AFD (Attention-FFN Disaggregation) implementation with DeepSeek-V2-Lite models. It uses torch.distributed to send/recv intermediate tensors between attention and FFN instances.
When the --enable-dbo flag is enabled, the num_afd_stages parameter is ignored and the actual number of microbatches is fixed at 2.
Currently, the P2PConnector only supports symmetric configurations where the number of attention (A) dies equals the number of FFN (F) dies. Asymmetric configurations will be supported in future updates.
1. Attn
```
vllm serve "/path/to/DeepSeek-V2-Lite" --data_parallel_size=2 --enable_expert_parallel --enforce_eager --enable-dbo --dbo-prefill-token-threshold 12 --dbo-decode-token-threshold 2 --afd-config '{"afd_connector":"p2pconnector", "afd_role": "attention", "afd_host":"127.0.0.1", "afd_port":"29500","num_afd_stages":"2","afd_extra_config":{"afd_size":"2A2F"}}'
```
2. FFN
```
vllm fserver "/path/to/DeepSeek-V2-Lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --afd-config '{"afd_connector":"p2pconnector", "num_afd_stages":"2", "afd_role": "ffn", "afd_host":"127.0.0.1", "afd_port":"29500", "afd_extra_config":{"afd_size":"2A2F"}}'
```

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config.attention import AttentionConfig
from vllm.config.afd import AFDConfig
from vllm.config.cache import CacheConfig
from vllm.config.compilation import (
CompilationConfig,
@ -66,6 +67,8 @@ __all__ = [
"KVEventsConfig",
# From vllm.config.kv_transfer
"KVTransferConfig",
# AFD (Attention FFN Disaggregation) configuration
"AFDConfig",
# From vllm.config.load
"LoadConfig",
# From vllm.config.lora

79
vllm/config/afd.py Normal file
View File

@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from dataclasses import field
from typing import Any, Literal
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
@config
@dataclass
class AFDConfig:
    """Configuration for AFD (Attention FFN Disaggregation) distributed
    computation.

    AFD splits a model forward pass so that attention and FFN computation
    can run on different instances, connected by a pluggable connector.
    NOTE: the per-field docstrings below are surfaced as CLI help text by
    the @config machinery, so keep them accurate.
    """

    afd_connector: str = "dummy"
    """The AFD connector for vLLM to communicate between attention and FFN
    nodes. Available connectors: 'dummy', 'p2pconnector'"""

    afd_role: Literal["attention", "ffn"] = "attention"
    """Role of this vLLM instance in AFD. 'attention' for attention workers,
    'ffn' for FFN servers."""

    afd_port: int = 1239
    """Port number for stepmesh parameter server communication."""

    afd_host: str = "127.0.0.1"
    """Host address for stepmesh parameter server communication."""

    num_afd_stages: int = 3
    """Number of pipeline stages for stage parallelism."""

    num_attention_servers: int = 1
    """Number of attention servers."""

    num_ffn_servers: int = 1
    """Number of FFN servers."""

    afd_server_rank: int = 0
    """Rank of this AFD server."""

    afd_extra_config: dict[str, Any] = field(default_factory=dict)
    """Extra configuration for specific AFD connectors."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # AFD configuration affects the computation graph structure
        # as it changes how FFN computation is performed.
        # Host/port/server-rank are deliberately excluded: they change the
        # deployment, not the graph.
        factors: list[Any] = [
            self.afd_connector,
            self.afd_role,
            self.num_afd_stages,
            self.num_attention_servers,
            self.num_ffn_servers,
        ]
        return hashlib.sha256(str(factors).encode()).hexdigest()

    @property
    def is_attention_server(self) -> bool:
        """Check if this instance is configured as an attention server."""
        return self.afd_role == "attention"

    @property
    def is_ffn_server(self) -> bool:
        """Check if this instance is configured as an FFN server."""
        return self.afd_role == "ffn"

View File

@ -553,7 +553,7 @@ class ParallelConfig:
if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
from vllm.v1.executor import ray_utils
backend: DistributedExecutorBackend = "mp"

View File

@ -28,6 +28,7 @@ from vllm.utils import random_uuid
from vllm.utils.hashing import safe_hash
from .attention import AttentionConfig
from .afd import AFDConfig
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
from .device import DeviceConfig
@ -227,6 +228,8 @@ class VllmConfig:
"""The configurations for event publishing."""
ec_transfer_config: ECTransferConfig | None = None
"""The configurations for distributed EC cache transfer."""
afd_config: AFDConfig | None = None
"""AFD (Attention FFN Disaggregation) configuration."""
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
@ -318,6 +321,10 @@ class VllmConfig:
vllm_factors.append(self.ec_transfer_config.compute_hash())
else:
vllm_factors.append("None")
if self.afd_config:
vllm_factors.append(self.afd_config.compute_hash())
else:
vllm_factors.append("None")
if self.additional_config:
if isinstance(additional_config := self.additional_config, dict):
additional_config_hash = safe_hash(
@ -872,16 +879,17 @@ class VllmConfig:
if self.parallel_config.use_ubatching:
a2a_backend = self.parallel_config.all2all_backend
assert a2a_backend in [
"deepep_low_latency",
"deepep_high_throughput",
], (
"Microbatching currently only supports the deepep_low_latency and "
f"deepep_high_throughput all2all backend. {a2a_backend} is not "
"supported. To fix use --all2all-backend=deepep_low_latency or "
"--all2all-backend=deepep_high_throughput and install the DeepEP"
" kernels."
)
if self.afd_config is None:
assert a2a_backend in [
"deepep_low_latency",
"deepep_high_throughput",
], (
"Microbatching currently only supports the deepep_low_latency and "
f"deepep_high_throughput all2all backend. {a2a_backend} is not "
"supported. To fix use --all2all-backend=deepep_low_latency or "
"--all2all-backend=deepep_high_throughput and install the DeepEP"
" kernels."
)
if not self.model_config.disable_cascade_attn:
self.model_config.disable_cascade_attn = True

View File

@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""AFD (Attention-FFN Disaggregation) transfer components.
This module provides the distributed infrastructure for AFD, enabling
disaggregated FFN computation across different machines while keeping
attention computation local.
"""
from .afd_connector import AFDConnectorBase, AFDConnectorFactory, AFDConnectorMetadata
__all__ = ["AFDConnectorBase", "AFDConnectorMetadata", "AFDConnectorFactory"]

View File

@ -0,0 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""AFD connector implementations for different transport backends."""
from .base import AFDConnectorBase
from .factory import AFDConnectorFactory
from .metadata import AFDConnectorMetadata
__all__ = [
"AFDConnectorBase",
"AFDConnectorFactory",
"AFDConnectorMetadata",
]

View File

@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
AFDConnectorBase Class for Distributed AFD FFN computation
The class provides the four core AFD communication interfaces:
1. send_attn_output(): Send attention output to FFN servers (Attention Worker)
2. recv_ffn_output(): Receive FFN computation result (Attention Worker)
3. recv_attn_output(): Receive attention output from workers (FFN Server)
4. send_ffn_output(): Send FFN computation result back (FFN Server)
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
import torch
if TYPE_CHECKING:
from vllm.config import VllmConfig
from .metadata import AFDConnectorMetadata
class AFDConnectorBase(ABC):
"""
Abstract base class for AFD connectors.
This provides the four core interfaces for AFD communication between
attention workers and FFN servers.
"""
@abstractmethod
def __init__(
self,
rank: int,
local_rank: int,
config: "VllmConfig",
):
"""Initialize the AFD connector.
Args:
rank: Global rank of this process
local_rank: Local rank within the node
config: VllmConfig containing AFDConfig
"""
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the connector and release resources."""
raise NotImplementedError
@abstractmethod
def init_afd_connector(self) -> None:
"""Initialize the AFD connector."""
raise NotImplementedError
@property
@abstractmethod
def is_initialized(self) -> bool:
"""Check if the connector is initialized and ready to use.
Returns:
bool: True if the connector is initialized, False otherwise.
"""
raise NotImplementedError
def get_connector_rank(self) -> int:
"""Get the rank of this connector."""
return getattr(self, "rank", 0)
def get_connector_local_rank(self) -> int:
"""Get the local rank of this connector."""
return getattr(self, "local_rank", 0)
@abstractmethod
def send_attn_output(
self,
hidden_states: torch.Tensor,
metadata: "AFDConnectorMetadata",
) -> Any:
"""Send attention output to FFN servers.
Args:
hidden_states: Attention output tensor
metadata: AFD metadata containing layer_idx, stage_idx, seq_len info
Returns:
Any: Handle for tracking this request (backend-specific)
"""
raise NotImplementedError
@abstractmethod
def recv_ffn_output(
self,
handle: Any,
) -> torch.Tensor:
"""Wait for and receive FFN computation result.
Args:
handle: Handle returned by send_attn_output()
Returns:
torch.Tensor: FFN computation result
"""
raise NotImplementedError
@abstractmethod
def recv_attn_output(
self,
timeout_ms: int | None = None,
) -> tuple[torch.Tensor, "AFDConnectorMetadata"]:
"""Receive attention output from attention workers.
Args:
timeout_ms: Optional timeout in milliseconds
Returns:
tuple: (hidden_states, metadata)
- hidden_states: Concatenated attention outputs
- metadata: Inferred AFD metadata containing
seq_lens and other info
"""
raise NotImplementedError
@abstractmethod
def send_ffn_output(
self,
ffn_output: torch.Tensor,
metadata: "AFDConnectorMetadata",
) -> None:
"""Send FFN computation result back to attention workers.
Args:
ffn_output: Computed FFN result
metadata: AFD metadata containing seq_lens
for splitting and routing info
"""
raise NotImplementedError

View File

@ -0,0 +1,211 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Dummy AFD Connector for testing and local development.
This connector provides a no-op AFDConnectorBase interface,
useful for testing and development scenarios where actual
distributed FFN computation is not needed.
"""
import time
from collections import deque
from typing import TYPE_CHECKING, Any
import torch
from vllm.logger import init_logger
from .base import AFDConnectorBase
from .metadata import AFDConnectorMetadata
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
class DummyAFDConnector(AFDConnectorBase):
    """Dummy AFD connector that returns zero tensors.

    This connector is useful for:
    1. Testing AFD infrastructure without actual remote computation
    2. Development scenarios where FFN computation should be disabled
    3. Fallback behavior when remote FFN servers are unavailable
    """

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: "VllmConfig",
    ):
        """Initialize the dummy AFD connector.

        Args:
            rank: Global rank of this process
            local_rank: Local rank within the node
            config: VllmConfig containing AFDConfig
        """
        self.afd_config = config.afd_config
        self.rank = rank
        self.local_rank = local_rank
        self._is_initialized = False
        # Needed to fabricate zero-valued FFN outputs of the right width.
        self.hidden_size = config.model_config.hf_config.hidden_size
        self.num_stages = config.afd_config.num_afd_stages
        # In-flight (handle, metadata) pairs queued by send_attn_output();
        # bounded to at most one outstanding entry per pipeline stage.
        self.events: deque = deque(maxlen=self.num_stages)
        logger.info("DummyAFDConnector initialized for rank %s", rank)
        self.init_afd_connector()

    def init_afd_connector(self) -> None:
        """Initialize the dummy connector.

        This is a no-op for the dummy connector.
        """
        if self._is_initialized:
            return
        logger.info("Initializing DummyAFDConnector (no-op)")
        self._is_initialized = True

    def close(self) -> None:
        """Close the dummy connector.

        This is a no-op for the dummy connector.
        """
        if not self._is_initialized:
            return
        logger.info("Closing DummyAFDConnector (no-op)")
        self._is_initialized = False

    @property
    def is_initialized(self) -> bool:
        """Check if the connector is initialized.

        Returns:
            bool: True if initialized, False otherwise
        """
        return self._is_initialized

    def send_attn_output(
        self,
        hidden_states: torch.Tensor,
        metadata: AFDConnectorMetadata,
    ) -> Any:
        """Send attention output to FFN servers (dummy implementation).

        Nothing is transmitted; the metadata is queued so that a matching
        recv_ffn_output() call can fabricate a zero result.

        Raises:
            ValueError: If the tensor shape disagrees with the metadata, or
                the metadata describes more than one sequence.
        """
        logger.debug(
            "DummyAFDConnector: send_attn_output layer=%s, stage=%s",
            metadata.layer_idx,
            metadata.stage_idx,
        )
        # Validate metadata consistency
        if not metadata.validate_tensor_shape(hidden_states.shape):
            # BUG FIX: ValueError does not %-format extra arguments the way
            # logging calls do, so build the message explicitly.
            raise ValueError(
                f"Tensor shape {hidden_states.shape} doesn't match "
                f"metadata {metadata}"
            )
        if not metadata.is_single_sequence:
            raise ValueError("Attention side should have single sequence")
        self.events.append((None, metadata))
        return None

    def recv_ffn_output(
        self,
        timeout_ms: float | None = None,
    ) -> torch.Tensor:
        """Receive FFN computation result (dummy implementation).

        Pops the oldest metadata queued by send_attn_output() and returns an
        all-zeros tensor with the matching (seq_len, hidden_size) shape.
        NOTE(review): signature takes timeout_ms, while AFDConnectorBase
        declares recv_ffn_output(handle) -- confirm the intended interface.
        """
        logger.debug("DummyAFDConnector: recv_ffn_output timeout_ms=%s", timeout_ms)
        _, metadata = self.events.popleft()
        seq_len = metadata.seq_lens[0]  # Single sequence for attention side
        return torch.zeros(
            seq_len,
            self.hidden_size,
            dtype=metadata.dtype,
            device=metadata.device,
        )

    def recv_attn_output(
        self,
        timeout_ms: int | None = None,
    ) -> tuple[torch.Tensor, AFDConnectorMetadata]:
        """
        Receive attention output from attention workers (dummy implementation).

        Fabricates a zero tensor plus metadata simulating three attention
        workers contributing two tokens each.
        """
        logger.debug("DummyAFDConnector: recv_attn_output timeout_ms=%s", timeout_ms)
        # Generate dummy data that simulates multiple attention workers
        dummy_seq_lens = [
            2,
            2,
            2,
        ]  # Variable sequence lengths from different workers
        total_tokens = sum(dummy_seq_lens)
        dummy_tensor = torch.zeros(
            total_tokens, self.hidden_size, dtype=torch.bfloat16, device="cuda"
        )
        # Create dummy metadata
        dummy_metadata = AFDConnectorMetadata.create_ffn_metadata(
            layer_idx=0,  # Dummy layer
            stage_idx=0,  # Dummy stage
            dtype=torch.bfloat16,
            device=torch.device("cuda"),
            seq_lens=dummy_seq_lens,
            request_id=f"dummy_ffn_batch_{time.time()}",
        )
        # Cache metadata for send_ffn_output
        self._current_metadata = dummy_metadata
        # Simulated network/compute latency.
        time.sleep(1)
        return dummy_tensor, dummy_metadata

    def send_ffn_output(
        self,
        ffn_output: torch.Tensor,
        metadata: AFDConnectorMetadata,
    ) -> None:
        """Send FFN computation result back (dummy implementation).

        Only validates and logs; nothing is transmitted.
        """
        logger.debug(
            "DummyAFDConnector: send_ffn_output layer=%s, stage=%s",
            metadata.layer_idx,
            metadata.stage_idx,
        )
        # Validate that ffn_output shape matches metadata
        if not metadata.validate_tensor_shape(ffn_output.shape):
            logger.warning(
                "FFN output shape %s doesn't match metadata %s",
                ffn_output.shape,
                metadata,
            )
        # Log the splitting information for debugging
        logger.debug(
            "DummyAFDConnector: Split FFN output into %s parts with lengths %s",
            metadata.num_sequences,
            metadata.seq_lens,
        )
        # Simulate splitting (for logging purposes)
        if metadata.get_split_indices():
            split_outputs = torch.split(ffn_output, metadata.seq_lens, dim=0)
            logger.debug(
                "DummyAFDConnector: Split shapes: %s",
                [s.shape for s in split_outputs],
            )
        # Simulated network latency.
        time.sleep(1)
        # No-op for dummy connector - just log the operation

View File

@ -0,0 +1,95 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Factory for creating AFD connectors based on configuration."""
import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from .base import AFDConnectorBase
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
class AFDConnectorFactory:
    """Registry and factory for AFD connector implementations.

    Connector classes are registered lazily by module path so that only the
    module backing the requested connector is ever imported.
    """

    _registry: dict[str, Callable[[], type[AFDConnectorBase]]] = {}

    @classmethod
    def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
        """Register a connector with a lazy-loading module and class name."""
        if name in cls._registry:
            raise ValueError(f"Connector '{name}' is already registered.")

        def _load() -> type[AFDConnectorBase]:
            # Import deferred until the connector is actually requested.
            return getattr(importlib.import_module(module_path), class_name)

        cls._registry[name] = _load

    @classmethod
    def create_connector(
        cls, rank: int, local_rank: int, config: "VllmConfig"
    ) -> AFDConnectorBase:
        """Create an AFD connector based on the configuration.

        Args:
            rank: Global rank of this process
            local_rank: Local rank within the node
            config: VllmConfig containing AFDConfig

        Returns:
            AFDConnectorBase: The created connector instance

        Raises:
            ValueError: If the transport backend is not supported
            ImportError: If required dependencies are not available
        """
        connector_name = config.afd_config.afd_connector
        loader = cls._registry.get(connector_name)
        if loader is None:
            raise ValueError(f"Unsupported connector type: {connector_name}")
        connector_cls = loader()
        assert issubclass(connector_cls, AFDConnectorBase)
        return connector_cls(rank, local_rank, config)

    @classmethod
    def get_connector_class(cls, connector_name: str) -> type[AFDConnectorBase]:
        """Get the connector class for a given connector name.

        Args:
            connector_name: The connector name

        Returns:
            type[AFDConnectorBase]: The connector class

        Raises:
            ValueError: If the connector name is not supported
        """
        loader = cls._registry.get(connector_name)
        if loader is None:
            raise ValueError(f"Unsupported connector type: {connector_name}")
        return loader()
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.
_BUILTIN_AFD_CONNECTORS = (
    (
        "dummy",
        "vllm.distributed.afd_transfer.afd_connector.dummy_connector",
        "DummyAFDConnector",
    ),
    (
        "p2pconnector",
        "vllm.distributed.afd_transfer.afd_connector.p2p_connector",
        "P2PAFDConnector",
    ),
)

for _name, _module_path, _class_name in _BUILTIN_AFD_CONNECTORS:
    AFDConnectorFactory.register_connector(_name, _module_path, _class_name)

View File

@ -0,0 +1,175 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""AFD metadata definitions for communication between attention and
FFN workers."""
import time
import typing
from dataclasses import dataclass, field
from typing import Any
import torch
class FFNNeedForwardData:
def __init__(
self,
moe_comm_method: typing.Any,
num_input_tokens: int,
with_prefill: bool,
total_num_scheduled_tokens: int | None,
is_dummy_run: bool = False,
):
self.moe_comm_method = moe_comm_method
self.num_input_tokens = num_input_tokens
self.with_prefill = with_prefill
self.total_num_scheduled_tokens = total_num_scheduled_tokens
self.is_dummy_run = is_dummy_run
@dataclass
class AFDConnectorMetadata:
"""Lightweight AFD metadata containing core information needed for
communication."""
layer_idx: int
stage_idx: int
seq_lens: list[int] # Length of each sequence, supports variable length and
# multiple sequences
dtype: torch.dtype
device: torch.device
topk_idx: torch.Tensor | None = None # indices token which expert to be sended
topk_weights: torch.Tensor | None = None # the expert weights
moe_expert_num: int | None = None # number of moe experts
shared_expert_num: int | None = None # number of share experts
scale: torch.Tensor | None = None # quant scale
expertTokenNumsOut: torch.Tensor | None = (
None # The number of tokens received by each expert is used as input for the subsequent GMM.
)
recv_handle_list: list[Any] | None = (
None # the communication handles (list of Work objects returned by torch.distributed.irecv)
)
# Optional fields for debugging and extensibility
request_id: str | None = None
timestamp: float | None = None
"""ffn need forward data"""
ffn_need_forward_data: FFNNeedForwardData | None = None
num_of_stages: int = 1
afd_tokens_lens: list = field(default_factory=list)
def __post_init__(self):
"""Validate data consistency."""
if not self.seq_lens:
raise ValueError("seq_lens cannot be empty")
if any(length <= 0 for length in self.seq_lens):
raise ValueError("All sequence lengths must be positive")
@property
def total_tokens(self) -> int:
"""Total number of tokens."""
return sum(self.seq_lens)
@property
def num_sequences(self) -> int:
"""Number of sequences."""
return len(self.seq_lens)
@property
def is_single_sequence(self) -> bool:
"""Whether this is a single sequence (attention side characteristic)."""
return len(self.seq_lens) == 1
@property
def is_multi_sequence(self) -> bool:
"""Whether this is multiple sequences (FFN side characteristic)."""
return len(self.seq_lens) > 1
@classmethod
def create_attention_metadata(
cls,
layer_idx: int,
stage_idx: int,
seq_len: int,
dtype: torch.dtype,
device: torch.device,
request_id: str | None = None,
ffn_need_forward_data: FFNNeedForwardData | None = None,
num_of_stages: int = 1,
afd_tokens_lens: list[int] = [],
) -> "AFDConnectorMetadata":
"""Create metadata for attention side (single sequence)."""
return cls(
layer_idx=layer_idx,
stage_idx=stage_idx,
seq_lens=[seq_len],
dtype=dtype,
device=device,
request_id=request_id,
ffn_need_forward_data=ffn_need_forward_data,
timestamp=time.time(),
num_of_stages=num_of_stages,
afd_tokens_lens=afd_tokens_lens,
)
@classmethod
def create_ffn_metadata(
cls,
layer_idx: int,
stage_idx: int,
seq_lens: list[int],
dtype: torch.dtype,
device: torch.device,
request_id: str | None = None,
) -> "AFDConnectorMetadata":
"""Create metadata for FFN side (multiple sequences)."""
return cls(
layer_idx=layer_idx,
stage_idx=stage_idx,
seq_lens=seq_lens.copy(), # Prevent external modification
dtype=dtype,
device=device,
request_id=request_id,
timestamp=time.time(),
)
def get_split_indices(self) -> list[int]:
"""Get tensor split indices for FFN side output splitting."""
if len(self.seq_lens) <= 1:
return []
indices = []
cumsum = 0
for length in self.seq_lens[:-1]: # Exclude the last one
cumsum += length
indices.append(cumsum)
return indices
def validate_tensor_shape(self, tensor_shape: tuple[int, ...]) -> bool:
"""Validate if tensor shape is consistent with metadata."""
if len(tensor_shape) < 1:
return False
return tensor_shape[0] == self.total_tokens
def to_dict(self) -> dict:
"""Convert to dictionary format for serialization and debugging."""
return {
"layer_idx": self.layer_idx,
"stage_idx": self.stage_idx,
"seq_lens": self.seq_lens,
"dtype": self.dtype,
"device": self.device,
"total_tokens": self.total_tokens,
"num_sequences": self.num_sequences,
"request_id": self.request_id,
"timestamp": self.timestamp,
}
def __repr__(self) -> str:
"""Friendly string representation."""
return (
f"AFDConnectorMetadata(layer={self.layer_idx}, "
f"stage={self.stage_idx}, seq_lens={self.seq_lens}, "
f"total_tokens={self.total_tokens}, dtype={self.dtype}, "
f"device={self.device}, request_id={self.request_id})"
)

View File

@ -0,0 +1,337 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import re
from datetime import timedelta
import torch
from torch.distributed.distributed_c10d import _get_default_group, _update_default_pg
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import (
GroupCoordinator,
TensorMetadata,
init_afd_process_group,
init_model_parallel_group,
)
from vllm.logger import init_logger
from vllm.forward_context import (
DPMetadata,
get_forward_context,
)
from .base import AFDConnectorBase
from .metadata import AFDConnectorMetadata
logger = init_logger(__name__)
class DefaultProcessGroupSwitcher:
    """Context manager that temporarily installs a different default
    torch.distributed process group.

    Used below so that sub-groups created inside the context are built
    relative to the AFD process group rather than the original default;
    the original default group is restored on exit.
    """

    def __init__(self, default_group, new_default_group):
        # default_group: the group to restore on exit.
        # new_default_group: the group active while inside the context.
        self.default_group = default_group
        self.new_default_group = new_default_group

    def __enter__(self):
        _update_default_pg(self.new_default_group)

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original default group even if the body raised.
        _update_default_pg(self.default_group)
class P2PAFDConnector(AFDConnectorBase):
def __init__(
self,
rank: int,
local_rank: int,
config: "VllmConfig",
) -> None:
self.rank = rank
self.local_rank = local_rank
self.config = config
self._initialized: bool = False
self._need_recv_metadata: bool = True
self._tensor_metadata_list: dict[int, TensorMetadata] = {}
self._current_afd_connector_metadata: AFDConnectorMetadata | None = None
if getattr(self.config.model_config.hf_config, "text_config", None) is not None:
self.num_hidden_layers: int = (
self.config.model_config.hf_config.text_config.num_hidden_layers
)
else:
self.num_hidden_layers: int = (
self.config.model_config.hf_config.num_hidden_layers
)
self.recv_attn_output_counter: int = 0
self.recv_ffn_output_counter: int = 0
self.dp_metadata_list: dict[int, DPMetadata] = {}
def close(self) -> None:
"""Close the connector and release resources."""
# TODO: Implement proper resource clean up if needed.
pass
def init_afd_connector(self) -> None:
"""Initialize the AFD connector."""
afd_size = self.config.afd_config.afd_extra_config.get("afd_size")
role = self.config.afd_config.afd_role
attn_size, ffn_size = map(int, re.match(r"(\d+)\D+(\d+)", afd_size).groups())
world_rank = self.rank if role == "attention" else self.rank + attn_size
afd_pg = init_afd_process_group(
backend="nccl",
init_method=(
f"tcp://{self.config.afd_config.afd_host}"
f":{self.config.afd_config.afd_port}"
),
world_size=ffn_size + attn_size,
rank=world_rank,
group_name="afd",
timeout=timedelta(minutes=2),
)
# Construct rank lists for sub groups.
# Each group contains one attention and one ffn rank.
ffn_ranks = [i for i in range(ffn_size, ffn_size + attn_size)]
attn_ranks = [i for i in range(attn_size)]
assert len(ffn_ranks) == len(attn_ranks), (
"ffn_ranks and attn_ranks must have the same length"
)
default_pg_switcher = DefaultProcessGroupSwitcher(_get_default_group(), afd_pg)
with default_pg_switcher:
sub_group_ranks = []
for i in range(len(ffn_ranks)):
ranks = [attn_ranks[i], ffn_ranks[i]]
sub_group_ranks.append(ranks)
# Create two independent groups:
# a2e_group: for attention -> expert/ffn communication (send_attn, recv_attn)
# e2a_group: for expert/ffn -> attention communication (send_ffn, recv_ffn)
# The communication domain (rank range) is the same, but different group_name
# creates independent groups.
self.a2e_group = init_model_parallel_group(
sub_group_ranks,
self.local_rank,
backend="nccl",
group_name="a2e",
)
self.e2a_group = init_model_parallel_group(
sub_group_ranks,
self.local_rank,
backend="nccl",
group_name="e2a",
)
self._initialized = True
def is_initialized(self) -> bool:
"""Check if the connector is initialized and ready to use.
Returns:
bool: True if the connector is initialized, False otherwise.
"""
return self._initialized
def _build_tensor_metadata_list(
self,
tensor_metadata: TensorMetadata,
connector_metadata: AFDConnectorMetadata,
) -> dict[int, TensorMetadata]:
tensor_metadata_list = {}
num_of_stages = connector_metadata.num_of_stages
for idx in range(num_of_stages):
if idx == 0:
tensor_metadata_list[0] = tensor_metadata
else:
new_size = list(tensor_metadata.size)
new_size[0] = connector_metadata.afd_tokens_lens[idx]
tensor_metadata_list[idx] = TensorMetadata(
tensor_metadata.device,
tensor_metadata.dtype,
torch.Size(new_size),
)
return tensor_metadata_list
def _send_metadata(
self,
metadata: AFDConnectorMetadata,
hidden_states: torch.Tensor,
dst: int,
process_group: GroupCoordinator,
) -> None:
if not torch.distributed.is_initialized() or process_group.world_size == 1:
return []
assert dst < process_group.world_size, f"Invalid dst rank ({dst})"
tensor_metadata = TensorMetadata(
hidden_states.device.type, hidden_states.dtype, hidden_states.size()
)
metadata_tuple = (metadata, tensor_metadata)
process_group.send_object(metadata_tuple, dst=dst)
self._tensor_metadata_list = self._build_tensor_metadata_list(
tensor_metadata, metadata
)
def _recv_metadata(
self,
src: int,
process_group: GroupCoordinator,
) -> None:
(self._current_afd_connector_metadata, tensor_metadata) = (
process_group.recv_object(src=src)
)
self._tensor_metadata_list = self._build_tensor_metadata_list(
tensor_metadata, self._current_afd_connector_metadata
)
if self.config.parallel_config.data_parallel_size > 1:
logger.info(
"jcz recv_metadata num_of_stages:{}".format(
self._current_afd_connector_metadata.num_of_stages
)
)
for stage_idx in range(self._current_afd_connector_metadata.num_of_stages):
num_tokens_per_ubatch = self._tensor_metadata_list[stage_idx].size[0]
self.dp_metadata_list[stage_idx] = DPMetadata.make(
self.config.parallel_config,
num_tokens_per_ubatch,
torch.tensor(
[num_tokens_per_ubatch]
* self.config.parallel_config.data_parallel_size,
device="cpu",
dtype=torch.int32,
),
)
logger.info(
"jcz recv_metadata self.dp_metadata_list:{}".format(
self.dp_metadata_list
)
)
def _send_hidden_states(
self,
hidden_states: torch.Tensor,
dst: int,
process_group: GroupCoordinator,
) -> None:
if not torch.distributed.is_initialized() or process_group.world_size == 1:
return []
assert dst < process_group.world_size, f"Invalid dst rank ({dst})"
assert not hidden_states.is_cpu, "Hidden states must be on GPU"
torch.distributed.send(
hidden_states,
dst=process_group.ranks[dst],
group=process_group.device_group,
)
def _recv_hidden_states(
self,
src: int,
process_group: GroupCoordinator,
tensor_metadata: TensorMetadata,
) -> torch.Tensor:
if not torch.distributed.is_initialized() or process_group.world_size == 1:
return {}, []
assert src < process_group.world_size, f"Invalid src rank ({src})"
hidden_states = torch.empty(
tensor_metadata.size,
dtype=tensor_metadata.dtype,
device=tensor_metadata.device,
)
torch.distributed.recv(
hidden_states,
src=process_group.ranks[src],
group=process_group.device_group,
)
return hidden_states
# -------------------------------------------------------------------------
# attn -> ffn
# -------------------------------------------------------------------------
def send_attn_output(
self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata
) -> None:
"""
Called by ATTN side to send intermediate tensors
generated by ATTN instances to FFN.
"""
try:
dst = (self.a2e_group.rank_in_group + 1) % self.a2e_group.world_size
if metadata.layer_idx == 0 and metadata.stage_idx == 0:
self._send_metadata(metadata, hidden_states, dst, self.a2e_group)
self._current_afd_connector_metadata = metadata
self._send_hidden_states(hidden_states, dst, self.a2e_group)
except Exception as e:
raise RuntimeError(f"Communication error: {e}")
def recv_ffn_output(self) -> torch.Tensor:
"""
Called by the ATTN side to receive MOE output intermediate tensors,
possibly dispatching from the receiver to other GPUs.
"""
src = (self.e2a_group.rank_in_group - 1) % self.e2a_group.world_size
stage_idx = (
self.recv_ffn_output_counter
% self._current_afd_connector_metadata.num_of_stages
)
hidden_states = self._recv_hidden_states(
src,
self.e2a_group,
self._tensor_metadata_list[stage_idx],
)
self.recv_ffn_output_counter = (
self.recv_ffn_output_counter + 1
) % self._current_afd_connector_metadata.num_of_stages
return hidden_states
# -------------------------------------------------------------------------
# ffn -> attn
# -------------------------------------------------------------------------
def send_ffn_output(
    self,
    hidden_states: torch.Tensor,
    metadata: AFDConnectorMetadata,
) -> None:
    """
    Called by FFN side to send intermediate tensors generated by FFN
    instances back to the sender (should be the same GPU as source).
    """
    # Ring neighbour: the ATTN peer sits one rank ahead in the e2a group.
    peer = (self.e2a_group.rank_in_group + 1) % self.e2a_group.world_size
    self._send_hidden_states(hidden_states, peer, self.e2a_group)
    self.recv_attn_output_counter += 1
    # One full forward pass consists of num_of_stages sends per layer for
    # every hidden layer. When a pass completes, fresh connector metadata
    # must be received before starting the next one.
    sends_per_pass = (
        self._current_afd_connector_metadata.num_of_stages
        * self.num_hidden_layers
    )
    if self.recv_attn_output_counter % sends_per_pass == 0:
        self._need_recv_metadata = True
        self.recv_attn_output_counter = 0
def recv_attn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]:
    """
    Called by the FFN side to receive intermediate tensors from ATTN.
    Handles receiving and possibly dispatching tensors.
    """
    peer = (self.a2e_group.rank_in_group - 1) % self.a2e_group.world_size
    # At the start of a new forward pass, refresh the connector metadata
    # before interpreting the counter.
    if self._need_recv_metadata:
        self._recv_metadata(peer, self.a2e_group)
        self._need_recv_metadata = False
    # The counter (advanced by send_ffn_output) enumerates (layer, stage)
    # pairs in stage-major order.
    num_stages = self._current_afd_connector_metadata.num_of_stages
    stage = self.recv_attn_output_counter % num_stages
    layer = self.recv_attn_output_counter // num_stages
    received = self._recv_hidden_states(
        peer,
        self.a2e_group,
        self._tensor_metadata_list[stage],
    )
    self._current_afd_connector_metadata.layer_idx = layer
    self._current_afd_connector_metadata.stage_idx = stage
    return received, self._current_afd_connector_metadata

View File

@ -41,6 +41,14 @@ import torch.distributed
import torch.distributed._functional_collectives as funcol
import torch.distributed._symmetric_memory
from torch.distributed import Backend, ProcessGroup
from torch.distributed.distributed_c10d import (
PrefixStore,
Store,
_new_process_group_helper,
_world,
default_pg_timeout,
rendezvous,
)
import vllm.envs as envs
from vllm.distributed.device_communicators.base_device_communicator import (
@ -322,7 +330,6 @@ class GroupCoordinator:
self_device_group = None
self_cpu_group = None
for ranks in group_ranks:
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
@ -1036,6 +1043,63 @@ _INNER_DP_WORLD: GroupCoordinator | None = None
_NODE_COUNT: int | None = None
def init_afd_process_group(
    backend: str | Backend | None = None,
    init_method: str | None = None,
    timeout: timedelta | None = None,
    world_size: int = -1,
    rank: int = -1,
    store: Store | None = None,
    group_name: str | None = None,
    pg_options: Any | None = None,
):
    """Create a standalone process group for AFD communication.

    Mirrors ``torch.distributed.init_process_group`` but registers the
    resulting group in the module-level ``_AFD`` handle instead of making
    it the default group, so AFD traffic does not interfere with the
    regular world group.

    Args:
        backend: distributed backend name ("gloo", "nccl", ...); falls back
            to torch's "undefined" backend sentinel when not given.
        init_method: rendezvous URL; mutually exclusive with ``store``.
            Defaults to "env://" when neither is supplied.
        timeout: collective timeout; defaults to torch's default_pg_timeout.
        world_size: required (> 0) when ``store`` is given.
        rank: required (>= 0) when ``store`` is given.
        store: pre-built key/value store used for rendezvous.
        group_name: name used to namespace the store keys.
        pg_options: backend-specific options forwarded to the group helper.

    Returns:
        The newly created ProcessGroup.
    """
    assert (store is None) or (init_method is None), (
        "Cannot specify both init_method and store."
    )

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    backend = Backend(backend) if backend else Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)
        store = PrefixStore(group_name, store)

    # BUG FIX: the original used a lexicographic string comparison
    # (str(torch.__version__) >= "2.6"), which misorders versions such as
    # "2.10" < "2.6". Compare numeric (major, minor) components instead.
    try:
        major, minor = (
            int(part)
            for part in torch.__version__.split("+", 1)[0].split(".")[:2]
        )
        uses_backend_options = (major, minor) >= (2, 6)
    except ValueError:
        # Unparseable version string: assume a modern torch.
        uses_backend_options = True
    pg_options_param_name = (
        "backend_options" if uses_backend_options else "pg_options"
    )

    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )

    global _AFD
    # Register the rank mapping so torch.distributed helpers can resolve
    # global/group ranks for this group.
    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
    _AFD = pg
    return pg
_WORLD: GroupCoordinator | None = None
_NODE_COUNT: int | None = None
def get_world_group() -> GroupCoordinator:
    """Return the global world GroupCoordinator (must be initialized)."""
    world = _WORLD
    assert world is not None, "world group is not initialized"
    return world

View File

@ -35,6 +35,7 @@ import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
AttentionConfig,
AFDConfig,
CacheConfig,
CompilationConfig,
ConfigType,
@ -73,7 +74,7 @@ from vllm.config.model import (
RunnerOption,
TokenizerMode,
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
from vllm.config.scheduler import SchedulerPolicy
@ -574,6 +575,8 @@ class EngineArgs:
CacheConfig.kv_offloading_backend
)
tokens_only: bool = False
# AFD config
afd_config: AFDConfig | None = None
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
@ -1153,6 +1156,8 @@ class EngineArgs:
"--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
)
vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"])
vllm_group.add_argument("--afd-config", **vllm_kwargs["afd_config"])
vllm_group.add_argument(
"--optimization-level", **vllm_kwargs["optimization_level"]
)
@ -1732,6 +1737,7 @@ class EngineArgs:
profiler_config=self.profiler_config,
additional_config=self.additional_config,
optimization_level=self.optimization_level,
afd_config=self.afd_config,
)
return config

View File

@ -0,0 +1,91 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""vLLM AFD FFN Server Entry Point
This script provides a standalone entry point for running FFN servers in an AFD
(Attention-FFN Disaggregation) setup. FFN servers handle remote FFN computation
for attention workers.
Usage:
python -m vllm.entrypoints.afd_ffn_server /path/to/model \
--tensor-parallel-size 8 \
--afd-config '{"afd_connector": "dummy", "afd_role": "ffn"}' \
"""
import threading
from typing import Any
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.logger import init_logger
from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__)
class AFDFFNServer:
    """AFD FFN Server main class.

    Hosts the FFN side of an Attention-FFN Disaggregation deployment: it
    builds a model executor from the engine configuration and keeps the
    FFN workers serving until the process is interrupted.
    """

    def __init__(self, args: Any):
        # Build the full vLLM engine configuration from the parsed CLI args.
        engine_args = AsyncEngineArgs.from_cli_args(args)
        self.vllm_config = engine_args.create_engine_config()
        logger.info("Start AFD FFN Server with vllm_config: %s", self.vllm_config)

    def start(self) -> None:
        """Start the AFD FFN server.

        Instantiates the executor (which spawns the workers) and then
        blocks in the server loop. Any startup failure is logged and
        re-raised.
        """
        try:
            # Import here to avoid circular imports
            from vllm.v1.executor.abstract import Executor

            # Create configurations
            executor_class = Executor.get_class(self.vllm_config)
            self.model_executor = executor_class(vllm_config=self.vllm_config)

            # Start the FFN server loop
            self._run_server_loop()
        except Exception as e:
            logger.error("Failed to start AFD FFN server: %s", e)
            raise

    def _run_server_loop(self) -> None:
        """Start FFN workers and wait for completion"""
        logger.info("AFD FFN Server started, workers running...")
        try:
            # Tell workers to start FFN server loops (one-time call)
            self.model_executor.collective_rpc("start_ffn_server_loop")
            # Main thread waits without busy polling
            shutdown_event = threading.Event()
            shutdown_event.wait()  # Block until interrupted
        except KeyboardInterrupt:
            # Ctrl-C: ask workers to exit their serving loops cleanly.
            logger.info("Server shutting down...")
            self.model_executor.collective_rpc("stop_ffn_server_loop")
        except Exception as e:
            logger.error("Server error: %s", e)
            raise
def main(args: Any) -> None:
    """Main entry point for AFD FFN server."""
    try:
        AFDFFNServer(args).start()
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
    except Exception as e:
        logger.error("Server error: %s", e)
        raise
if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    # Add model as positional argument (like vllm serve)
    parser.add_argument("model", type=str, help="Model name or path")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    # NOTE: the positional "model" argument is already stored on `args` by
    # argparse; the original no-op self-assignment (`args.model = args.model`)
    # has been removed.
    main(args)

View File

@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""vLLM AFD FFN Server CLI command."""
import argparse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.afd_ffn_server import main
from vllm.entrypoints.cli.types import CLISubcommand
class FServerCommand(CLISubcommand):
    """Command for running vLLM AFD FFN Server."""

    def __init__(self):
        # Subcommand name as invoked on the CLI: `vllm fserver ...`
        self.name = "fserver"
        super().__init__()

    def subparser_init(self, subparsers):
        """Initialize the fserver subparser."""
        parser = subparsers.add_parser(
            self.name,
            help="Start vLLM AFD FFN Server",
            description="Start vLLM AFD FFN Server for Attention-FFN Disaggregation",
            usage="vllm fserver MODEL --afd-config CONFIG [options]",
        )
        # Add model as positional argument (like vllm serve)
        parser.add_argument("model", type=str, help="Model name or path")
        # Use AsyncEngineArgs to add all vLLM engine arguments
        parser = AsyncEngineArgs.add_cli_args(parser)
        return parser

    def validate(self, args: argparse.Namespace) -> None:
        """Validate arguments for fserver command.

        Raises:
            ValueError: if --afd-config is missing or empty; the FFN server
                cannot run without an AFD connector configuration.
        """
        # Validate that afd_config is provided
        if not hasattr(args, "afd_config") or not args.afd_config:
            raise ValueError("--afd-config is required for FFN server")

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Run the fserver command."""
        # Call the main function from afd_ffn_server directly with parsed args
        main(args)
def cmd_init() -> list[CLISubcommand]:
    """Register the fserver subcommand with the vLLM CLI."""
    commands: list[CLISubcommand] = [FServerCommand()]
    return commands

View File

@ -16,6 +16,7 @@ logger = init_logger(__name__)
def main():
import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.collect_env
import vllm.entrypoints.cli.fserver
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.run_batch
import vllm.entrypoints.cli.serve
@ -25,6 +26,7 @@ def main():
CMD_MODULES = [
vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve,
vllm.entrypoints.cli.fserver,
vllm.entrypoints.cli.benchmark.main,
vllm.entrypoints.cli.collect_env,
vllm.entrypoints.cli.run_batch,

View File

@ -195,7 +195,6 @@ def run_multi_api_server(args: argparse.Namespace):
assert external_dp_lb or hybrid_dp_lb or dp_rank == 0
api_server_manager: APIServerProcessManager | None = None
with launch_core_engines(
vllm_config, executor_class, log_stats, num_api_servers
) as (local_engine_manager, coordinator, addresses):

View File

@ -1325,7 +1325,6 @@ async def run_server_worker(
listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
"""Run a single API server worker."""
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)

View File

@ -4,7 +4,7 @@
import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, NamedTuple
import torch
@ -12,7 +12,9 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.distributed.afd_transfer import AFDConnectorBase
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ubatch_utils import UBatchSlices
@ -94,6 +96,46 @@ class DPMetadata:
# NOTE: local_sizes should only be set by the chunked_sizes context manager
local_sizes: list[int] | None = None
@staticmethod
def num_stage_tokens_across_dp(
    num_stage_tokens: list[int], dp_size: int, dp_rank: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Gather the stage token counts across all DP ranks.

    Args:
        num_stage_tokens: list of token counts per stage for current rank
        dp_size: number of DP ranks
        dp_rank: current DP rank

    Returns:
        stage_tokens_across_dp_cpu: [num_stages, dp_size] tensor
        max_stage_tokens_across_dp_cpu: [num_stages] tensor with max
        tokens per stage
    """
    # Imported locally to avoid import cycles at module load time.
    import torch.distributed as dist

    from vllm.distributed.parallel_state import get_dp_group
    from vllm.platforms import current_platform

    device = current_platform.device_type
    group = get_dp_group().device_group
    num_stages = len(num_stage_tokens)

    # Each rank fills only its own column; the sum all-reduce below then
    # combines the columns from every rank into the full matrix.
    stage_tokens_across_dp = torch.zeros(
        (num_stages, dp_size), device=device, dtype=torch.int32
    )
    stage_tokens_across_dp[:, dp_rank] = torch.tensor(
        num_stage_tokens, device=device, dtype=torch.int32
    )

    # AllReduce to gather from all ranks
    dist.all_reduce(stage_tokens_across_dp, group=group)
    stage_tokens_across_dp_cpu = stage_tokens_across_dp.cpu()

    # Compute max tokens per stage
    max_stage_tokens_across_dp_cpu = torch.max(stage_tokens_across_dp_cpu, dim=1)[0]
    return stage_tokens_across_dp_cpu, max_stage_tokens_across_dp_cpu
@staticmethod
def make(
parallel_config: ParallelConfig,
@ -182,6 +224,23 @@ class DPMetadata:
return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
@dataclass
class AFDMetadata:
    """Per-forward-pass bookkeeping for Attention-FFN Disaggregation (AFD).

    Carries the stage partitioning of the batch plus the per-stage inputs
    and metadata needed to run decoder layers stage-by-stage through the
    AFD connector.
    """

    # Token / request offsets delimiting each AFD stage within the batch.
    afd_tokens_start_loc: list[int]
    afd_reqs_start_loc: list[int]
    # Index of the stage currently being executed.
    afd_stage_idx: int
    # Connector used to exchange tensors between attention and FFN sides.
    afd_connector: "AFDConnectorBase"
    afd_tokens_lens: list[int]  # padded lengths for tensor slicing
    num_of_stages: int
    # Per-stage model inputs, split according to the stage boundaries above.
    input_ids_list: list[torch.Tensor] = field(default_factory=list)
    positions_list: list[torch.Tensor] = field(default_factory=list)
    inputs_embeds_list: list[torch.Tensor] = field(default_factory=list)
    intermediate_tensors_list: list[IntermediateTensors] = field(default_factory=list)
    # Per-stage attention / data-parallel metadata used while that stage runs.
    attn_metadata_list: list[AttentionMetadata] = field(default_factory=list)
    dp_metadata_list: list[DPMetadata] = field(default_factory=list)
@dataclass
class ForwardContext:
# copy from vllm_config.compilation_config.static_forward_context
@ -198,6 +257,7 @@ class ForwardContext:
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
dp_metadata: DPMetadata | None = None
afd_metadata: AFDMetadata | None = None
# determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
# by default NONE, no cudagraph is used.
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
@ -235,6 +295,7 @@ def create_forward_context(
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None,
ubatch_slices: UBatchSlices | None = None,
afd_metadata: AFDMetadata | None = None,
):
return ForwardContext(
no_compile_layers=vllm_config.compilation_config.static_forward_context,
@ -244,6 +305,7 @@ def create_forward_context(
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices,
afd_metadata=afd_metadata,
)
@ -272,6 +334,7 @@ def set_forward_context(
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None,
ubatch_slices: UBatchSlices | None = None,
afd_metadata: AFDMetadata | None = None,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
@ -317,6 +380,7 @@ def set_forward_context(
cudagraph_runtime_mode,
batch_descriptor,
ubatch_slices,
afd_metadata=afd_metadata,
)
try:

View File

@ -594,7 +594,6 @@ class ColumnParallelLinear(LinearBase):
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)

View File

@ -45,6 +45,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
from vllm.distributed.afd_transfer.afd_connector.metadata import AFDConnectorMetadata
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
@ -101,6 +102,9 @@ elif current_platform.is_xpu():
logger = init_logger(__name__)
from vllm.forward_context import AFDMetadata
from vllm.v1.worker.ubatching import dbo_current_ubatch_id, dbo_enabled, dbo_yield
class DeepseekAttention(nn.Module):
"""Normal MHA implementation used by Deepseek v1."""
@ -383,7 +387,6 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim)
@ -1123,6 +1126,8 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
afd_config = vllm_config.afd_config
self.afd_role = afd_config.afd_role if afd_config is not None else None
self.hidden_size = config.hidden_size
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
moe_layer_freq = getattr(config, "moe_layer_freq", 1)
@ -1148,42 +1153,47 @@ class DeepseekV2DecoderLayer(nn.Module):
attn_cls = DeepseekV2MLAAttention
else:
attn_cls = DeepseekV2Attention
self.self_attn = attn_cls(
vllm_config=vllm_config,
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
kv_lora_rank=kv_lora_rank,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
topk_indices_buffer=topk_indices_buffer,
)
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % moe_layer_freq == 0
):
self.mlp = DeepseekV2MoE(
if self.afd_role is None or self.afd_role == "attention":
self.self_attn = attn_cls(
vllm_config=vllm_config,
config=config,
parallel_config=parallel_config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
q_lora_rank=config.q_lora_rank
if hasattr(config, "q_lora_rank")
else None,
kv_lora_rank=kv_lora_rank,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
prefix=f"{prefix}.self_attn",
topk_indices_buffer=topk_indices_buffer,
)
if self.afd_role is None or self.afd_role == "ffn":
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % moe_layer_freq == 0
):
self.mlp = DeepseekV2MoE(
config=config,
parallel_config=parallel_config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
@ -1227,6 +1237,9 @@ class DeepseekV2DecoderLayer(nn.Module):
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
if self.afd_role == "attention":
return hidden_states, residual
hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
@ -1239,6 +1252,49 @@ class DeepseekV2DecoderLayer(nn.Module):
return hidden_states, residual
def compute_attn_output(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:  # Self Attention
    """Run the attention half of the decoder layer (AFD attention role).

    Applies input layernorm, self-attention, the FP16 overflow rescaling,
    and post-attention layernorm; the MLP/MoE half runs remotely on the
    FFN side.

    Returns:
        Tuple of (hidden_states, residual) ready to be sent to FFN.
    """
    if residual is None:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
    else:
        hidden_states, residual = self.input_layernorm(hidden_states, residual)
    hidden_states = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
    )

    if hidden_states.dtype == torch.float16:
        # Fix FP16 overflow
        # We scale both hidden_states and residual before
        # rmsnorm, and rmsnorm result would not affect by scale.
        hidden_states *= 1.0 / self.routed_scaling_factor
        if self.layer_idx == 0:
            # The residual is shared by all layers, we only scale it on
            # first layer.
            residual *= 1.0 / self.routed_scaling_factor

    # Fully Connected
    hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
    return hidden_states, residual
def compute_ffn_output(self, hidden_states):
    """Run the MLP/MoE half of the decoder layer (AFD ffn role)."""
    assert self.afd_role == "ffn"
    out = self.mlp(hidden_states)
    needs_fp16_rescale = (
        isinstance(self.mlp, DeepseekV2MLP) and out.dtype == torch.float16
    )
    if needs_fp16_rescale:
        # Fix FP16 overflow: scale the dense-MLP output, which feeds the
        # input_layernorm of the next decoder layer. The MoE path applies
        # this scaling inside DeepseekV2MoE.forward instead.
        out *= 1.0 / self.routed_scaling_factor
    return out
@support_torch_compile
class DeepseekV2Model(nn.Module):
@ -1293,6 +1349,112 @@ class DeepseekV2Model(nn.Module):
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward_with_afd(
    self,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    afd_metadata: AFDMetadata,
    llama_4_scaling: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Layer loop for the attention role with DBO-style microbatching.

    For every layer: receive the previous layer's FFN output from the
    connector (layer 0 starts from the local embedding instead), run the
    attention half locally, then ship the result to the FFN side. The
    stage index tracks the current DBO ubatch id.
    """
    for layer in islice(self.layers, self.start_layer, self.end_layer):
        afd_connector = afd_metadata.afd_connector
        afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
        if layer.layer_idx > 0:
            # This layer's input is the FFN output of the previous layer.
            hidden_states = afd_connector.recv_ffn_output()
        current_hidden, residual = layer(
            positions, hidden_states, residual, llama_4_scaling
        )
        metadata = AFDConnectorMetadata.create_attention_metadata(
            layer_idx=layer.layer_idx,
            stage_idx=afd_metadata.afd_stage_idx,
            seq_len=current_hidden.shape[0],
            dtype=current_hidden.dtype,
            device=current_hidden.device,
            num_of_stages=afd_metadata.num_of_stages,
            afd_tokens_lens=afd_metadata.afd_tokens_lens,
        )
        afd_connector.send_attn_output(current_hidden, metadata)
        if dbo_enabled():
            # Hand control to the other ubatch while the FFN side works.
            dbo_yield()

    # Collect the FFN output of the final layer.
    hidden_states = afd_connector.recv_ffn_output()
    return hidden_states, residual
def forward_with_afd_v2(
    self,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    afd_metadata: AFDMetadata,
    llama_4_scaling: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Layer loop for the attention role with explicit stage pipelining.

    The flat batch is split into ``num_of_stages`` microbatches according
    to ``afd_metadata.positions_list``. For each layer, every stage is run
    through the attention half and streamed to the FFN side, so attention
    compute of one stage overlaps FFN compute of another.
    """
    forward_conext = get_forward_context()
    # Slice the flat batch into per-stage views; each stage's positions
    # tensor determines its token count.
    ubatch_hidden_states = []
    ubatch_residual = []
    start_idx = 0
    for pos in afd_metadata.positions_list:
        # MROPE-style positions are (3, num_tokens); 1-D positions are
        # (num_tokens,). Pick the token dimension accordingly.
        num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0]
        end_idx = start_idx + num_tokens
        ubatch_hidden_states.append(hidden_states[start_idx:end_idx])
        ubatch_residual.append(
            residual[start_idx:end_idx] if residual is not None else None
        )
        start_idx = end_idx

    for layer in islice(self.layers, self.start_layer, self.end_layer):
        for stage_i in range(forward_conext.afd_metadata.num_of_stages):
            afd_connector = afd_metadata.afd_connector
            # Point the shared forward context at this stage's attention
            # and DP metadata before running the layer.
            forward_conext.attn_metadata = afd_metadata.attn_metadata_list[stage_i]
            forward_conext.dp_metadata = afd_metadata.dp_metadata_list[stage_i]
            residual = ubatch_residual[stage_i]
            if layer.layer_idx > 0:
                # Input is the FFN output of this stage's previous layer.
                hidden_states = afd_connector.recv_ffn_output()
            else:
                hidden_states = ubatch_hidden_states[stage_i]
            current_positions = afd_metadata.positions_list[stage_i]
            hidden_states, residual = layer(
                current_positions, hidden_states, residual, llama_4_scaling
            )
            ubatch_hidden_states[stage_i] = hidden_states
            ubatch_residual[stage_i] = residual
            metadata = AFDConnectorMetadata.create_attention_metadata(
                layer_idx=layer.layer_idx,
                stage_idx=stage_i,
                seq_len=hidden_states.shape[0],
                dtype=hidden_states.dtype,
                device=hidden_states.device,
                num_of_stages=afd_metadata.num_of_stages,
                afd_tokens_lens=afd_metadata.afd_tokens_lens,
            )
            afd_connector.send_attn_output(hidden_states, metadata)

    # Recv last layer FFN output.
    for stage_i in range(afd_metadata.num_of_stages):
        ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output()

    # Re-assemble the batch
    hidden_states = torch.cat(ubatch_hidden_states, dim=0)
    if any(r is not None for r in ubatch_residual):
        residual = torch.cat(ubatch_residual, dim=0)
    else:
        residual = None
    return hidden_states, residual
def forward(
self,
input_ids: torch.Tensor,
@ -1325,10 +1487,18 @@ class DeepseekV2Model(nn.Module):
else:
llama_4_scaling = None
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions, hidden_states, residual, llama_4_scaling
forward_ctx = get_forward_context()
afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None
if afd_metadata != None:
hidden_states, residual = self.forward_with_afd_v2(
hidden_states, residual, positions, afd_metadata, llama_4_scaling
)
else:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
positions, hidden_states, residual, llama_4_scaling
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
@ -1338,6 +1508,12 @@ class DeepseekV2Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def compute_ffn_output(
    self, hidden_states, layer_idx
) -> torch.Tensor | IntermediateTensors:
    """Delegate FFN computation for one layer to that decoder layer."""
    target_layer = self.layers[layer_idx]
    return target_layer.compute_ffn_output(hidden_states)
class DeepseekV2MixtureOfExperts(MixtureOfExperts):
moe_mlp_layers: list[DeepseekV2MoE]
@ -1404,6 +1580,10 @@ class DeepseekV2ForCausalLM(
if self.use_mha:
self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
self.afd_config = vllm_config.afd_config
self.afd_role = (
self.afd_config.afd_role if self.afd_config is not None else None
)
# `packed_modules_mapping` needs to be modified before
# initializing DeepseekV2Model, as it is passed inplace to
# quantization config init and may be used to select the
@ -1452,12 +1632,16 @@ class DeepseekV2ForCausalLM(
continue
assert isinstance(layer, DeepseekV2DecoderLayer)
if isinstance(layer.mlp, DeepseekV2MoE):
if (self.afd_role is None or self.afd_role == "ffn") and isinstance(
layer.mlp, DeepseekV2MoE
):
# Pick last one layer since the first ones may be dense layers.
example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts)
if self.afd_role == "attention":
return
self.extract_moe_parameters(example_moe)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
@ -1475,6 +1659,12 @@ class DeepseekV2ForCausalLM(
)
return hidden_states
def compute_ffn_output(
    self, hidden_states, current_layer_idx
) -> torch.Tensor | IntermediateTensors:
    """Forward the FFN computation request to the underlying model."""
    return self.model.compute_ffn_output(hidden_states, current_layer_idx)
def compute_logits(
self,
hidden_states: torch.Tensor,
@ -1518,6 +1708,13 @@ class DeepseekV2ForCausalLM(
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
if self.afd_role == "attention":
vllm_config = get_current_vllm_config()
num_redundant_experts = (
vllm_config.parallel_config.eplb_config.num_redundant_experts
)
else:
num_redundant_experts = self.num_redundant_experts
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
@ -1528,7 +1725,7 @@ class DeepseekV2ForCausalLM(
if rocm_aiter_moe_shared_expert_enabled
else 0
),
num_redundant_experts=self.num_redundant_experts,
num_redundant_experts=num_redundant_experts,
)
params_dict = dict(self.named_parameters())
@ -1537,6 +1734,9 @@ class DeepseekV2ForCausalLM(
if "rotary_emb.inv_freq" in name:
continue
if self.afd_role == "attention" and self.is_moe_weight(name):
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
@ -1640,7 +1840,8 @@ class DeepseekV2ForCausalLM(
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
if self.afd_role is not None and self.afd_role == "attention":
continue
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = chunk_name.replace(weight_name, param_name)
@ -1670,6 +1871,12 @@ class DeepseekV2ForCausalLM(
loaded_params.add(name_mapped)
break
else:
if (
self.afd_role == "ffn"
and not self.is_moe_weight(name)
and not self.is_common_weight(name)
):
continue
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
@ -1698,6 +1905,29 @@ class DeepseekV2ForCausalLM(
return loaded_params
def is_moe_weight(self, name):
    """Return True if `name` looks like an FFN/MoE-side parameter.

    NOTE(review): coarse substring match -- "gate"/"up"/"down" also match
    e.g. the router gate. Adequate for skipping MoE weights on the
    attention role; verify before reusing elsewhere.
    """
    moe_markers = ("shared_experts", "experts", "gate", "up", "down")
    return any(marker in name for marker in moe_markers)
def is_common_weight(self, name):
    """Return True for weights needed by both AFD roles.

    Covers the embedding, final norm, LM head and per-layer norms, which
    every role loads regardless of the attention/FFN split.
    """
    # NOTE: "model.layers.0.self_attn.o_proj.weight" was once listed here
    # (for kv cache init) and is intentionally left out.
    shared_markers = (
        "lm_head",
        "model.norm.weight",
        "embed_tokens",
        "input_layernorm",
        "post_attention_layernorm",
    )
    return any(marker in name for marker in shared_markers)
class DeepseekForCausalLM(DeepseekV2ForCausalLM):
pass

View File

@ -11,12 +11,16 @@ from torch import nn
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.config import AFDConfig, CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.distributed.afd_transfer.afd_connector.metadata import (
AFDConnectorMetadata,
)
from vllm.forward_context import AFDMetadata, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@ -37,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.step3_vl import Step3TextConfig
from vllm.v1.worker.ubatching import dbo_current_ubatch_id, dbo_enabled, dbo_yield
from .interfaces import SupportsPP
from .utils import (
@ -228,54 +233,59 @@ class Step3TextDecoderLayer(nn.Module):
config: Step3TextConfig,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
afd_config: AFDConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.afd_role = afd_config.afd_role if afd_config is not None else None
self.self_attn = Step3TextAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
norm_eps=config.rms_norm_eps,
max_position_embedding=config.max_position_embedding,
head_dim=config.head_dim,
share_q_dim=config.share_q_dim,
rope_parameters=config.rope_parameters,
prefix=f"{prefix}.self_attn",
)
layer_idx = int(prefix.split("layers.")[1].split(".")[0])
moe_layers_enum = getattr(config, "moe_layers_enum", None)
if moe_layers_enum is not None:
moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
else:
# Default to 1dense.
moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
if layer_idx in moe_layers_idx:
self.moe = FusedMoEBlock(
config=config, quant_config=quant_config, prefix=f"{prefix}.moe"
)
self.share_expert = Step3TextMLP(
if self.afd_role is None or self.afd_role == "attention":
self.self_attn = Step3TextAttention(
hidden_size=self.hidden_size,
intermediate_size=config.share_expert_dim,
hidden_act="silu",
num_heads=config.num_attention_heads,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.share_expert",
norm_eps=config.rms_norm_eps,
max_position_embedding=config.max_position_embedding,
head_dim=config.head_dim,
share_q_dim=config.share_q_dim,
rope_parameters=config.rope_parameters,
prefix=f"{prefix}.self_attn",
)
self.use_moe = True
else:
self.mlp = Step3TextMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act="silu",
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.use_moe = False
self.layer_idx = int(prefix.split("layers.")[1].split(".")[0])
if self.afd_role is None or self.afd_role == "ffn":
moe_layers_enum = getattr(config, "moe_layers_enum", None)
if moe_layers_enum is not None:
moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
else:
# Default to 1dense.
moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
if self.layer_idx in moe_layers_idx:
self.moe = FusedMoEBlock(
config=config, quant_config=quant_config, prefix=f"{prefix}.moe"
)
self.share_expert = Step3TextMLP(
hidden_size=self.hidden_size,
intermediate_size=config.share_expert_dim,
hidden_act="silu",
quant_config=quant_config,
prefix=f"{prefix}.share_expert",
)
self.use_moe = True
else:
self.mlp = Step3TextMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act="silu",
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.use_moe = False
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
@ -300,6 +310,9 @@ class Step3TextDecoderLayer(nn.Module):
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
if self.afd_role == "attention":
return hidden_states, residual
if self.use_moe:
share_output = self.share_expert(hidden_states)
moe_output = self.moe(hidden_states)
@ -309,6 +322,16 @@ class Step3TextDecoderLayer(nn.Module):
return hidden_states, residual
def compute_ffn_output(self, hidden_states):
    """Run only the FFN part of this decoder layer (AFD "ffn" role).

    MoE layers return the sum of the shared-expert MLP and the routed
    experts block; dense layers just apply the regular MLP.
    """
    assert self.afd_role == "ffn"
    if not self.use_moe:
        return self.mlp(hidden_states)
    return self.share_expert(hidden_states) + self.moe(hidden_states)
@support_torch_compile
class Step3TextModel(nn.Module):
@ -317,6 +340,7 @@ class Step3TextModel(nn.Module):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
afd_config = vllm_config.afd_config
self.vocab_size = config.vocab_size
self.config = config
@ -336,6 +360,7 @@ class Step3TextModel(nn.Module):
config=config,
cache_config=cache_config,
quant_config=quant_config,
afd_config=afd_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",
@ -352,6 +377,72 @@ class Step3TextModel(nn.Module):
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up token embeddings for ``input_ids`` via the embedding table."""
    embeddings = self.embed_tokens(input_ids)
    return embeddings
def forward_with_afd(
    self,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    afd_metadata: AFDMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run this rank's decoder layers in AFD (attention/FFN disaggregated) mode.

    The batch is split into per-stage micro-batches according to
    ``afd_metadata.positions_list``. For each layer, every stage's attention
    output is sent to the remote FFN server through the AFD connector and,
    for layers after the first, the previous layer's FFN result is received
    in its place — overlapping attention and FFN compute across stages.

    Args:
        hidden_states: Concatenated hidden states for all stages.
        residual: Concatenated residual stream, or None.
        positions: Flat positions tensor (per-stage positions are taken
            from ``afd_metadata.positions_list``, not from this argument).
        afd_metadata: Per-stage attention/DP metadata plus the connector.

    Returns:
        (hidden_states, residual) re-assembled across all stages.
    """
    # NOTE(review): "forward_conext" is a typo for "forward_context";
    # kept as-is in this comment-only pass (local name only).
    forward_conext = get_forward_context()
    # Split the flat batch into per-stage slices using per-stage token counts.
    ubatch_hidden_states = []
    ubatch_residual = []
    start_idx = 0
    for pos in afd_metadata.positions_list:
        # Positions may be 2-D or 1-D; token count is the last/only dim.
        num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0]
        end_idx = start_idx + num_tokens
        ubatch_hidden_states.append(hidden_states[start_idx:end_idx])
        ubatch_residual.append(
            residual[start_idx:end_idx] if residual is not None else None
        )
        start_idx = end_idx
    for layer in islice(self.layers, self.start_layer, self.end_layer):
        for stage_i in range(forward_conext.afd_metadata.num_of_stages):
            afd_connector = afd_metadata.afd_connector
            # Point the global forward context at this stage's metadata so
            # attention/DP code invoked by the layer sees the right slice.
            forward_conext.attn_metadata = afd_metadata.attn_metadata_list[stage_i]
            forward_conext.dp_metadata = afd_metadata.dp_metadata_list[stage_i]
            residual = ubatch_residual[stage_i]
            if layer.layer_idx > 0:
                # Input for layers > 0 is the FFN server's output for the
                # previous layer of this stage.
                hidden_states = afd_connector.recv_ffn_output()
            else:
                hidden_states = ubatch_hidden_states[stage_i]
            current_positions = afd_metadata.positions_list[stage_i]
            hidden_states, residual = layer(
                current_positions, hidden_states, residual
            )
            ubatch_hidden_states[stage_i] = hidden_states
            ubatch_residual[stage_i] = residual
            # Ship this stage's attention output to the FFN server.
            metadata = AFDConnectorMetadata.create_attention_metadata(
                layer_idx=layer.layer_idx,
                stage_idx=stage_i,
                seq_len=hidden_states.shape[0],
                dtype=hidden_states.dtype,
                device=hidden_states.device,
                num_of_stages=afd_metadata.num_of_stages,
                afd_tokens_lens=afd_metadata.afd_tokens_lens,
            )
            afd_connector.send_attn_output(hidden_states, metadata)
    # Recv last layer FFN output.
    # NOTE(review): relies on `afd_connector` bound in the loop above, so at
    # least one layer/stage iteration must have run — confirm callers
    # guarantee a non-empty layer range.
    for stage_i in range(afd_metadata.num_of_stages):
        ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output()
    # Re-assemble the batch
    hidden_states = torch.cat(ubatch_hidden_states, dim=0)
    if any(r is not None for r in ubatch_residual):
        residual = torch.cat(ubatch_residual, dim=0)
    else:
        residual = None
    return hidden_states, residual
def forward(
self,
input_ids: torch.Tensor,
@ -370,8 +461,19 @@ class Step3TextModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
forward_ctx = get_forward_context()
afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None
if afd_metadata is not None:
hidden_states, residual = self.forward_with_afd(
hidden_states,
residual,
positions,
afd_metadata,
)
else:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
@ -384,6 +486,14 @@ class Step3TextModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def compute_ffn_output(
    self,
    hidden_states,
    layer_idx,
) -> torch.Tensor | IntermediateTensors:
    """Delegate the FFN-only computation to the decoder layer at ``layer_idx``."""
    target_layer = self.layers[layer_idx]
    return target_layer.compute_ffn_output(hidden_states)
class Step3TextForCausalLM(nn.Module, SupportsPP):
def __init__(
@ -398,6 +508,11 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
self.config = config
self.vllm_config = vllm_config
self.afd_config = vllm_config.afd_config
self.afd_role = (
self.afd_config.afd_role if self.afd_config is not None else None
)
self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)
if get_pp_group().is_last_rank:
@ -429,6 +544,14 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
)
return hidden_states
def compute_ffn_output(
    self,
    hidden_states,
    current_layer_idx,
) -> torch.Tensor | IntermediateTensors:
    """Forward the FFN-only computation to the underlying text model."""
    return self.model.compute_ffn_output(hidden_states, current_layer_idx)
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Project hidden states to vocabulary logits via the LM head."""
    return self.logits_processor(self.lm_head, hidden_states)
@ -477,9 +600,13 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
disable_moe_stacked_params = [data[1] for data in expert_params_mapping]
for name, loaded_weight in weights:
if self.afd_role == "attention" and self.is_moe_weight(name):
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if any(
disable_moe_stacked_param in name
for disable_moe_stacked_param in disable_moe_stacked_params
@ -498,6 +625,10 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
param_name, weight_name, shard_id = mapping
if weight_name not in name:
continue
if self.afd_role is not None and self.afd_role == "attention":
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
@ -521,6 +652,12 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
loaded_params.add(name)
break
else:
if (
self.afd_role == "ffn"
and not self.is_moe_weight(name)
and not self.is_common_weight(name)
):
continue
for (
param_name,
weight_name,
@ -552,3 +689,25 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def is_moe_weight(self, name):
    """Heuristically classify checkpoint weight *name* as an FFN/MoE weight.

    NOTE(review): matching is by broad substrings ("up", "down", "gate"),
    so any weight name containing them counts as MoE — confirm this cannot
    misclassify attention-side weights in this checkpoint's naming scheme.
    """
    moe_markers = ("shared_expert", "experts", "gate", "up", "down")
    return any(marker in name for marker in moe_markers)
def is_common_weight(self, name):
    """Return True for weights shared by both AFD roles.

    Covers the LM head, final norm, embeddings, and the per-layer
    layernorms — everything that is neither attention- nor FFN-specific.
    """
    common_markers = (
        "lm_head",
        "model.norm.weight",
        "embed",
        "input_layernorm",
        "post_attention_layernorm",
    )
    return any(marker in name for marker in common_markers)

View File

@ -1126,6 +1126,16 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
return hidden_states
def compute_ffn_output(
    self,
    hidden_states,
    current_layer_idx,
) -> torch.Tensor | IntermediateTensors:
    """Forward the FFN-only computation to the wrapped language model."""
    return self.language_model.compute_ffn_output(
        hidden_states, current_layer_idx
    )
def compute_logits(
self,
hidden_states: torch.Tensor,

View File

@ -105,6 +105,10 @@ class EngineCore:
if executor_fail_callback is not None:
self.model_executor.register_failure_callback(executor_fail_callback)
self.afd_config = vllm_config.afd_config
if self.afd_config and self.afd_config.afd_role == "ffn":
return
self.available_gpu_memory_for_kv_cache = -1
# Setup KV Caches and update CacheConfig after profiling.
@ -614,6 +618,7 @@ class EngineCoreProc(EngineCore):
executor_fail_callback = lambda: self.input_queue.put_nowait(
(EngineCoreRequestType.EXECUTOR_FAILED, b"")
)
self.afd_config = vllm_config.afd_config
self.engine_index = engine_index
identity = self.engine_index.to_bytes(length=2, byteorder="little")
@ -868,7 +873,6 @@ class EngineCoreProc(EngineCore):
set_process_title("EngineCore")
decorate_logs()
engine_core = EngineCoreProc(*args, **kwargs)
engine_core.run_busy_loop()
except SystemExit:
@ -891,6 +895,23 @@ class EngineCoreProc(EngineCore):
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""
if self.afd_config and self.afd_config.afd_role == "ffn":
logger.info("AFD FFN Server started, workers running...")
try:
# Tell workers to start FFN server loops (one-time call)
self.model_executor.collective_rpc("start_ffn_server_loop")
# Main thread waits without busy polling
shutdown_event = threading.Event()
shutdown_event.wait() # Block until interrupted
except KeyboardInterrupt:
logger.info("Server shutting down...")
self.model_executor.collective_rpc("stop_ffn_server_loop")
except Exception as e:
logger.error("Server error: %s", e)
raise
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
@ -1205,6 +1226,7 @@ class DPEngineCoreProc(EngineCoreProc):
# Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
self.afd_config = vllm_config.afd_config
super().__init__(
vllm_config,
local_client,
@ -1287,6 +1309,22 @@ class DPEngineCoreProc(EngineCoreProc):
def run_busy_loop(self):
"""Core busy loop of the EngineCore for data parallel case."""
if self.afd_config and self.afd_config.afd_role == "ffn":
logger.info("AFD FFN Server started, workers running...")
try:
# Tell workers to start FFN server loops (one-time call)
self.model_executor.collective_rpc("start_ffn_server_loop")
# Main thread waits without busy polling
shutdown_event = threading.Event()
shutdown_event.wait() # Block until interrupted
except KeyboardInterrupt:
logger.info("Server shutting down...")
self.model_executor.collective_rpc("stop_ffn_server_loop")
except Exception as e:
logger.error("Server error: %s", e)
raise
# Loop until process is sent a SIGINT or SIGTERM
while True:

View File

@ -16,7 +16,7 @@ import msgspec
import zmq
from vllm import envs
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.config import AFDConfig, CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
@ -908,6 +908,7 @@ def launch_core_engines(
vllm_config.cache_config,
local_engine_manager,
coordinator.proc if coordinator else None,
vllm_config.afd_config,
)
@ -919,6 +920,7 @@ def wait_for_engine_startup(
cache_config: CacheConfig,
proc_manager: CoreEngineProcManager | None,
coord_process: Process | None,
afd_config: AFDConfig | None = None,
):
# Wait for engine core process(es) to send ready messages.
local_count = parallel_config.data_parallel_size_local
@ -1020,6 +1022,13 @@ def wait_for_engine_startup(
conn_pending[0 if local else 1] -= 1
start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED
elif (
status == "READY"
and engine.state == CoreEngineState.CONNECTED
and afd_config
and afd_config.afd_role == "ffn"
):
engine.state = CoreEngineState.READY
elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.

View File

@ -213,6 +213,9 @@ class MultiprocExecutor(Executor):
self.output_rank = self._get_output_rank()
self.afd_config = self.vllm_config.afd_config
self.afd_role = self.afd_config.afd_role if self.afd_config else None
def start_worker_monitor(self, inline=False) -> None:
workers = self.workers
self_ref = weakref.ref(self)
@ -565,6 +568,9 @@ class WorkerProc:
# environment variable overrides after this point)
enable_envs_cache()
self.afd_config = vllm_config.afd_config
self.afd_role = self.afd_config.afd_role if self.afd_config else None
@staticmethod
def make_worker_process(
vllm_config: VllmConfig,

View File

@ -0,0 +1,450 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import TYPE_CHECKING, Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed.afd_transfer.afd_connector.factory import AFDConnectorFactory
from vllm.distributed.afd_transfer.afd_connector.metadata import AFDConnectorMetadata
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_world_group,
graph_capture,
)
from vllm.forward_context import set_forward_context, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils.mem_utils import DeviceMemoryProfiler, GiB_bytes
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
logger = init_logger(__name__)
class GPUFFNModelRunner(LoRAModelRunnerMixin):
def __init__(self, vllm_config: VllmConfig, device: torch.device):
    """Model runner for the FFN side of attention/FFN disaggregation (AFD).

    Receives attention outputs from remote attention workers through an AFD
    connector, runs only the FFN/MoE part of each decoder layer, and sends
    the result back. Requires ``vllm_config.afd_config`` in FFN-server role.

    Args:
        vllm_config: Full vLLM configuration (must carry an AFD config).
        device: Device this runner executes on.

    Raises:
        ValueError: if no AFD config is present or it is not an FFN server.
    """
    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    self.device = device
    self.dtype = self.model_config.dtype
    self.load_config = vllm_config.load_config
    self.afd_config = vllm_config.afd_config
    if not self.afd_config or not self.afd_config.is_ffn_server:
        raise ValueError(
            "AFD config must be provided with afd_role='ffn' for FFN server"
        )
    # Counts (layer, stage) execution steps; see _get_current_layer_idx().
    self._counter = 0
    # Initialize torch.profile for performance monitoring
    # NOTE(review): the schedule's wait count (6000 * 27 + 4000 * 27 * 2)
    # looks hand-tuned for one specific workload — confirm before relying
    # on these traces elsewhere.
    self.profiler = torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(
            wait=6000 * 27 + 4000 * 27 * 2, warmup=1, active=30, repeat=1
        ),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            "./profiler_logs/ffn"
        ),
        record_shapes=True,
        profile_memory=False,
        with_stack=False,
    )
    # Initialize CUDA graph support
    self.use_cuda_graph = not self.model_config.enforce_eager
    # self.cudagraph_batch_sizes sorts in ascending order.
    # The batch sizes in the config are in descending order.
    self.cudagraph_batch_sizes = list(
        reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes)
    )
    # Storage for captured graphs
    self._cuda_graphs: dict[
        tuple[int, int], torch.cuda.CUDAGraph
    ] = {}  # {(layer_idx, num_tokens): CUDAGraph}
    self._graph_memory_pool = None
    assert self.afd_config.is_ffn_server
    self.connector = AFDConnectorFactory.create_connector(
        get_world_group().rank, get_world_group().local_rank, self.vllm_config
    )
    # Multimodal models nest the text config; fall back to the top-level
    # config for text-only models.
    if getattr(self.model_config.hf_config, "text_config", None) is not None:
        self.num_layers = self.model_config.hf_config.text_config.num_hidden_layers
    else:
        self.num_layers = self.model_config.hf_config.num_hidden_layers
def get_model(self) -> nn.Module:
    """Return the model instance owned by this runner."""
    loaded_model = self.model
    return loaded_model
def initialize_afd_connector(self) -> None:
    """Perform the one-time initialization of the AFD connector."""
    connector = self.connector
    connector.init_afd_connector()
def load_model(self, **kwargs) -> None:
    """Load the model (or reload weights in place) and log memory/time cost.

    If ``self.model`` already exists only the weights are refreshed;
    otherwise the full model is constructed via the configured loader.
    Records consumed device memory in ``self.model_memory_usage``.
    """
    logger.info("Starting to load model %s...", self.model_config.model)
    with DeviceMemoryProfiler() as m:  # noqa: SIM117
        time_before_load = time.perf_counter()
        model_loader = get_model_loader(self.load_config)
        if not hasattr(self, "model"):
            logger.info("Loading model from scratch...")
            self.model = model_loader.load_model(
                vllm_config=self.vllm_config, model_config=self.model_config
            )
        else:
            logger.info("Model was already initialized. Loading weights inplace...")
            model_loader.load_weights(self.model, model_config=self.model_config)
        time_after_load = time.perf_counter()
    self.model_memory_usage = m.consumed_memory
    logger.info(
        "Model loading took %.4f GiB and %.6f seconds",
        self.model_memory_usage / GiB_bytes,
        time_after_load - time_before_load,
    )
    logger.info("AFD FFN Model loaded successfully")
def _get_current_layer_idx(self) -> int:
return (self._counter // self.afd_config.num_afd_stages) % self.num_layers
@torch.inference_mode()
def execute_model(self, scheduler_output=None, intermediate_tensors=None):
    """Serve one FFN step: receive attention output, compute FFN, send back.

    Blocks on the AFD connector for the next attention output, runs the
    FFN of the layer named in the received metadata (via CUDA graph when a
    matching capture exists, eagerly otherwise), and returns the result to
    the sender. Always returns None — the FFN server produces no
    ModelRunnerOutput.
    """
    # scheduler_output and intermediate_tensors are unused in FFN server
    # mode
    self.profiler.step()
    try:
        hidden_states, recv_metadata = self.connector.recv_attn_output()
        # Connectors that track per-stage DP metadata expose it keyed by
        # stage index; others get None.
        if hasattr(self.connector, "dp_metadata_list"):
            dp_metadata = self.connector.dp_metadata_list.get(
                recv_metadata.stage_idx, None
            )
        else:
            dp_metadata = None
        current_layer_idx = recv_metadata.layer_idx
        num_tokens = hidden_states.shape[0]
        # Wait for any outstanding async recv handles before using the data.
        # NOTE(review): recv_metadata was already dereferenced above, so the
        # `is not None` guard here can never be the first None check —
        # confirm whether recv_attn_output() may return a None metadata.
        if recv_metadata is not None and recv_metadata.recv_handle_list is not None:
            for work in recv_metadata.recv_handle_list:
                work.wait()
        # Try to use CUDA graph if available
        cuda_graph_info = self._find_cuda_graph(current_layer_idx, num_tokens)
        if cuda_graph_info is not None:
            # Use captured CUDA graph for computation
            with set_forward_context(
                attn_metadata=None, vllm_config=self.vllm_config
            ):
                get_forward_context().dp_metadata = dp_metadata
                rank_ffn_output = self._execute_with_cuda_graph(
                    hidden_states, cuda_graph_info
                )
        else:
            # Fallback to eager mode
            with set_forward_context(
                attn_metadata=None, vllm_config=self.vllm_config
            ):
                get_forward_context().dp_metadata = dp_metadata
                rank_ffn_output = self._execute_eager_mode(
                    hidden_states, current_layer_idx
                )
        # Drop consumed handles before echoing the metadata back.
        recv_metadata.recv_handle_list = None
        self.connector.send_ffn_output(rank_ffn_output, recv_metadata)
    except Exception as e:
        raise ValueError(f"Error computing FFN: {e}") from e
    finally:
        # Advance the (layer, stage) step counter and wrap after a full
        # pass over all layers and stages.
        self._counter += 1
        if self._counter == self.num_layers * self.afd_config.num_afd_stages:
            self._counter = 0
    return None  # FFN server doesn't return ModelRunnerOutput
def _execute_with_cuda_graph(
    self, hidden_states: torch.Tensor, cuda_graph_info: dict
):
    """Execute FFN computation using captured CUDA graph.

    Args:
        hidden_states: Input activations; row count must not exceed the
            graph's capture size.
        cuda_graph_info: Dict with "graph", "input_hidden_states" (the
            static input buffer) and "output" (the static output buffer).

    Returns:
        A clone of the output rows corresponding to the actual input size.

    Raises:
        ValueError: if the input has more tokens than the graph capacity.
    """
    graph = cuda_graph_info["graph"]
    input_tensor = cuda_graph_info["input_hidden_states"]
    output_tensor = cuda_graph_info["output"]
    # Copy input data to graph's input tensor
    # Handle padding if necessary
    actual_tokens = hidden_states.shape[0]
    graph_tokens = input_tensor.shape[0]
    if actual_tokens <= graph_tokens:
        # Copy actual data and pad with zeros if needed
        input_tensor[:actual_tokens].copy_(hidden_states)
        if actual_tokens < graph_tokens:
            input_tensor[actual_tokens:].zero_()
    else:
        raise ValueError(
            f"Input size {actual_tokens} exceeds graph capacity {graph_tokens}"
        )
    # Replay the captured graph
    graph.replay()
    # Return only the actual output (without padding)
    # (clone detaches the result from the graph's static buffer before it
    # is overwritten by the next replay)
    return output_tensor[:actual_tokens].clone()
def _execute_eager_mode(
    self,
    hidden_states: torch.Tensor,
    current_layer_idx: int,
    recv_metadata: AFDConnectorMetadata = None,
):
    """Execute FFN computation in eager mode (fallback).

    Under tensor parallelism each rank holds only its shard of the tokens:
    shards are all-gathered, the FFN runs over the full batch, and this
    rank's slice of the result is returned.

    NOTE(review): `recv_metadata` is accepted but never read — confirm
    whether it can be dropped from the signature or is kept for interface
    symmetry.
    """
    # Step the profiler for performance monitoring
    # Handle TP case: all-gather tensors from all TP ranks
    tp_world_size = get_tensor_model_parallel_world_size()
    if tp_world_size > 1:
        # All-gather hidden states from all TP ranks
        gathered_hidden_states = tensor_model_parallel_all_gather(
            hidden_states, dim=0
        )
        ffn_output = self.model.compute_ffn_output(
            gathered_hidden_states, current_layer_idx
        )
        # Extract the output corresponding to current rank
        # (assumes every TP rank contributed the same number of tokens)
        start_idx = hidden_states.shape[0] * get_tensor_model_parallel_rank()
        end_idx = start_idx + hidden_states.shape[0]
        rank_ffn_output = ffn_output[start_idx:end_idx, :]
    else:
        # Single TP case
        rank_ffn_output = self.model.compute_ffn_output(
            hidden_states, current_layer_idx
        )
    return rank_ffn_output
# Methods required for interface compatibility with GPUModelRunner
def profile_run(self) -> None:
    """FFN servers don't need profiling."""
    pass

def get_kv_cache_spec(self) -> dict[str, "KVCacheSpec"]:
    """FFN servers don't use KV cache."""
    return {}

def initialize_kv_cache(self, kv_cache_config: "KVCacheConfig") -> None:
    """FFN servers don't use KV cache."""
    pass

def _dummy_run(self, num_tokens: int = 1, **kwargs) -> torch.Tensor:
    """FFN servers don't need dummy runs."""
    # Return a dummy tensor for interface compatibility
    return torch.zeros(
        num_tokens,
        self.model_config.hf_config.hidden_size,
        dtype=self.dtype,
        device=self.device,
    )
def capture_model(self) -> int:
    """Capture CUDA graphs for FFN operations.

    Captures one graph per (layer, batch size) pair and records them in
    ``self._cuda_graphs``. Returns the number of bytes of GPU memory the
    captures consumed (0 when CUDA graphs are disabled).
    """
    if not self.use_cuda_graph:
        logger.warning("Skipping CUDA graph capture.")
        return 0
    logger.info("Starting CUDA graph capture for FFN operations...")
    start_time = time.perf_counter()
    start_free_gpu_memory = torch.cuda.mem_get_info()[0]
    # Create memory pool for graphs
    if self._graph_memory_pool is None:
        self._graph_memory_pool = torch.cuda.graph_pool_handle()
    # Capture graphs for each layer and different batch sizes
    # Capture the large shapes first so that the smaller shapes
    # can reuse the memory pool allocated for the large shapes.
    with graph_capture(device=self.device):
        for layer_idx in range(self.num_layers):
            for num_tokens in reversed(self.cudagraph_batch_sizes):
                with set_forward_context(
                    attn_metadata=None, vllm_config=self.vllm_config
                ):
                    self._capture_graph_for_layer_and_size(layer_idx, num_tokens)
    end_time = time.perf_counter()
    end_free_gpu_memory = torch.cuda.mem_get_info()[0]
    elapsed_time = end_time - start_time
    # Estimate memory cost as the drop in free device memory.
    cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
    logger.info(
        "FFN CUDA graph capturing finished in %.0f secs, took %.2f GiB",
        elapsed_time,
        cuda_graph_size / (1 << 30),
    )
    return cuda_graph_size
def _capture_graph_for_layer_and_size(self, layer_idx: int, num_tokens: int):
    """Capture CUDA graph for specific layer and number of tokens.

    Warms up the FFN computation, records it into a graph drawing from the
    shared memory pool, and stores graph plus its static input/output
    buffers under the ``(layer_idx, num_tokens)`` key.
    """
    # Create dummy hidden states
    dummy_hidden_states = torch.randn(
        num_tokens,
        self.model_config.hf_config.hidden_size,
        dtype=self.dtype,
        device=self.device,
    )
    # Warm up the operations for this specific layer
    for _ in range(self.vllm_config.compilation_config.cudagraph_num_of_warmups):
        self._run_ffn_computation(
            dummy_hidden_states, layer_idx=layer_idx, capture_mode=True
        )
    # Create and capture the graph
    graph = torch.cuda.CUDAGraph()
    # Start graph capture
    with torch.cuda.graph(graph, pool=self._graph_memory_pool):
        output = self._run_ffn_computation(
            dummy_hidden_states, layer_idx=layer_idx, capture_mode=True
        )
    # Store the captured graph with layer and token count as key
    # (dummy_hidden_states/output stay alive as the graph's static buffers)
    self._cuda_graphs[(layer_idx, num_tokens)] = {
        "graph": graph,
        "input_hidden_states": dummy_hidden_states,
        "output": output,
    }
    logger.debug(
        "Captured CUDA graph for layer %s with %s tokens", layer_idx, num_tokens
    )
def _run_ffn_computation(
    self,
    hidden_states: torch.Tensor,
    layer_idx: int | None = None,
    capture_mode: bool = False,
):
    """Run FFN computation for graph capture or replay.

    An explicit ``layer_idx`` wins; otherwise the layer is derived from
    the step counter (or fixed to 0 while capturing). Under tensor
    parallelism the shards are all-gathered, the FFN runs over the full
    batch, and this rank's slice is returned.
    """
    if layer_idx is None:
        current_layer_idx = self._get_current_layer_idx() if not capture_mode else 0
    else:
        current_layer_idx = layer_idx
    tp_world_size = get_tensor_model_parallel_world_size()
    if tp_world_size > 1:
        # Handle TP case: all-gather tensors from all TP ranks
        gathered_hidden_states = tensor_model_parallel_all_gather(
            hidden_states, dim=0
        )
        ffn_output = self.model.compute_ffn_output(
            gathered_hidden_states, current_layer_idx
        )
        # Extract the output corresponding to current rank
        # (assumes every TP rank contributed the same number of tokens)
        start_idx = hidden_states.shape[0] * get_tensor_model_parallel_rank()
        end_idx = start_idx + hidden_states.shape[0]
        rank_ffn_output = ffn_output[start_idx:end_idx, :]
    else:
        # Single TP case
        rank_ffn_output = self.model.compute_ffn_output(
            hidden_states, current_layer_idx
        )
    return rank_ffn_output
def _find_cuda_graph(self, layer_idx: int, num_tokens: int):
"""Find the smallest graph that can handle the given layer and
number of tokens."""
if not self.use_cuda_graph:
return None
# Find the minimum capture size that can handle num_tokens for this
# layer
for capture_size in self.cudagraph_batch_sizes:
if num_tokens <= capture_size:
return self._cuda_graphs.get((layer_idx, capture_size))
return None
def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
    """FFN servers don't use samplers."""
    pass

def update_config(self, overrides: dict[str, Any]) -> None:
    """Update configuration for FFN model runner.

    Only ``load_config`` and ``model_config`` may be overridden; any other
    key triggers an assertion.
    """
    allowed_config_names = {"load_config", "model_config"}
    for config_name, config_overrides in overrides.items():
        assert config_name in allowed_config_names, (
            f"Config `{config_name}` not supported. "
            f"Allowed configs: {allowed_config_names}"
        )
        config = getattr(self, config_name)
        from vllm.config import update_config

        new_config = update_config(config, config_overrides)
        setattr(self, config_name, new_config)

def reload_weights(self) -> None:
    """Reload model weights for FFN model runner."""
    assert getattr(self, "model", None) is not None, (
        "Cannot reload weights before model is loaded."
    )
    model_loader = get_model_loader(self.load_config)
    logger.info("Reloading weights inplace...")
    model = self.get_model()
    model_loader.load_weights(model, model_config=self.model_config)

@property
def lora_config(self):
    """FFN servers don't support LoRA."""
    return None

@property
def is_pooling_model(self) -> bool:
    """FFN servers are not pooling models."""
    return False

def _dummy_pooler_run(self, hidden_states: torch.Tensor):
    """FFN servers don't have poolers."""
    pass

def get_supported_tasks(self):
    """Get supported tasks for FFN model runner (always empty)."""
    return []

def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
    """Get number of input tokens for FFN model runner (identity)."""
    return num_scheduled_tokens

def take_draft_token_ids(self, **kwargs):
    """FFN servers don't support draft tokens."""
    pass

@property
def eplb_state(self):
    """FFN servers don't have EPLB state."""
    return None

def ensure_kv_transfer_shutdown(self):
    """FFN servers don't need KV transfer shutdown."""
    pass

def save_tensorized_model(
    self,
    tensorizer_config: "TensorizerConfig",
) -> None:
    """FFN servers don't support tensorized model saving."""
    raise NotImplementedError("FFN servers don't support tensorized model saving")

View File

@ -37,6 +37,7 @@ from vllm.config import (
get_layers_from_vllm_config,
update_config,
)
from vllm.distributed.afd_transfer import AFDConnectorFactory
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
@ -45,11 +46,13 @@ from vllm.distributed.parallel_state import (
get_dcp_group,
get_pp_group,
get_tp_group,
get_world_group,
graph_capture,
is_global_first_rank,
prepare_communication_buffer_for_model,
)
from vllm.forward_context import (
AFDMetadata,
BatchDescriptor,
set_forward_context,
)
@ -547,6 +550,16 @@ class GPUModelRunner(
# means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {}
# init AFD config
self.afd_config = vllm_config.afd_config
if self.afd_config and self.afd_config.afd_role == "attention":
self.afd_connector = AFDConnectorFactory.create_connector(
get_world_group().rank, get_world_group().local_rank, vllm_config
)
self.afd_connector.init_afd_connector()
self.num_stages = self.afd_config.num_afd_stages
self.kv_sharing_fast_prefill_eligible_layers: set[str] = set()
self.kv_sharing_fast_prefill_logits_indices = None
@ -606,6 +619,25 @@ class GPUModelRunner(
self.kv_connector_output: KVConnectorOutput | None = None
self.layerwise_nvtx_hooks_registered = False
profile_dir = (
"./profiler_logs/attn"
if self.afd_config is not None and self.afd_config.afd_role == "attention"
else "./profiler_logs/normal"
)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=6000 + 4000, warmup=1, active=30, repeat=1
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir),
record_shapes=True,
profile_memory=False,
with_stack=False,
)
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
@ -1303,7 +1335,7 @@ class GPUModelRunner(
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
logger.info(f"nums_reqs: {num_reqs}")
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit_block_table(num_reqs)
@ -2877,6 +2909,7 @@ class GPUModelRunner(
# decoder.
allow_dp_padding = (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
or self.afd_config is not None
)
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
@ -2958,6 +2991,38 @@ class GPUModelRunner(
pyt_hooks.register_hooks(self.model, self.model.__class__.__name__)
self.layerwise_nvtx_hooks_registered = True
def _build_afd_metadata(
    self, ubatch_slices: UBatchSlices | None, num_tokens_unpadded: int
):
    """Build per-step AFD metadata from the micro-batch layout.

    With two or more ubatch slices, per-stage token/request offsets come
    from the slices; otherwise the whole batch is one stage starting at 0.
    Returns None when AFD is not configured.
    """
    afd_metadata = None
    if self.afd_config:
        # For prefill, compute tokens per stage based on actual token
        # counts
        # NOTE(review): these two initializers are dead — both branches
        # below reassign them unconditionally.
        afd_tokens_start_loc = [0]
        afd_tokens_lens = []
        if ubatch_slices and len(ubatch_slices) > 1:
            afd_tokens_start_loc = [ub.token_slice.start for ub in ubatch_slices]
            afd_reqs_start_loc = [ub.request_slice.start for ub in ubatch_slices]
            # NOTE(review): info-level f-string logging on the per-step hot
            # path — consider logger.debug with lazy %-args before merge.
            logger.info(
                f"afd_tokens_start_loc: {afd_tokens_start_loc} "
                f"afd_reqs_start_loc: {afd_reqs_start_loc} "
                f"ubatch_slices: {ubatch_slices}"
            )
            afd_tokens_lens = [ub.num_tokens for ub in ubatch_slices]
        else:
            afd_tokens_start_loc = [0]
            afd_reqs_start_loc = [0]
            afd_tokens_lens = [num_tokens_unpadded]
        afd_metadata = AFDMetadata(
            afd_tokens_start_loc=afd_tokens_start_loc,
            afd_reqs_start_loc=afd_reqs_start_loc,
            afd_stage_idx=0,
            afd_connector=self.afd_connector,
            afd_tokens_lens=afd_tokens_lens,
            num_of_stages=len(ubatch_slices) if ubatch_slices else 1,
        )
    return afd_metadata
@torch.inference_mode()
def execute_model(
self,
@ -3130,6 +3195,11 @@ class GPUModelRunner(
# Mark KV scales as calculated after the first forward pass
self.calculate_kv_scales = False
afd_metadata = self._build_afd_metadata(
ubatch_slices_padded, num_tokens_unpadded
)
self.profiler.step()
# Run the model.
# Use persistent buffers for CUDA graphs.
with (
@ -3141,10 +3211,15 @@ class GPUModelRunner(
cudagraph_runtime_mode=cudagraph_mode,
batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded,
afd_metadata=afd_metadata,
),
record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
):
if input_ids is not None:
logger.info(f"input_ids: {input_ids.shape}")
if inputs_embeds is not None:
logger.info(f"inputs_embeds: {inputs_embeds.shape}")
model_output = self._model_forward(
input_ids=input_ids,
positions=positions,
@ -4269,6 +4344,10 @@ class GPUModelRunner(
if num_tokens_across_dp is not None:
num_tokens_across_dp[:] = num_tokens_padded
afd_metadata = self._build_afd_metadata(
ubatch_slices_padded, num_tokens_unpadded
)
with (
self.maybe_randomize_inputs(input_ids, inputs_embeds),
set_forward_context(
@ -4279,6 +4358,7 @@ class GPUModelRunner(
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded,
afd_metadata=afd_metadata,
),
):
outputs = self.model(
@ -4814,7 +4894,6 @@ class GPUModelRunner(
kv_cache_spec,
kv_cache_group_id,
)
attn_groups.append(attn_group)
return attn_groups
@ -5491,6 +5570,11 @@ class GPUModelRunner(
kv_transfer_group.register_kv_caches(kv_caches)
kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
def initialize_afd_connector(self) -> None:
    """Initialize the AFD connector when this runner has one (attention role)."""
    connector = getattr(self, "afd_connector", None)
    if connector:
        connector.init_afd_connector()
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
"""
Add encoder-only layers to the KV cache config.

View File

@ -14,6 +14,7 @@ from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import (
AFDMetadata,
DPMetadata,
create_forward_context,
get_forward_context,
@ -126,7 +127,10 @@ class UBatchWrapper:
comm_sms: int = envs.VLLM_DBO_COMM_SMS
set_comm_sms = lambda sms: None
if vllm_config.parallel_config.enable_expert_parallel:
if (
vllm_config.parallel_config.enable_expert_parallel
and not vllm_config.afd_config
):
# Currently only DeepEP highthroughput supports SM control so this
# only affects that case.
ep_group = get_ep_group()
@ -303,6 +307,7 @@ class UBatchWrapper:
dp_metadata,
batch_descriptor,
cudagraph_runtime_mode,
afd_metadata,
) -> list[UbatchMetadata]:
# Create one forward context per ubatch
forward_contexts = []
@ -314,6 +319,7 @@ class UBatchWrapper:
dp_metadata=dp_metadata[i],
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=cudagraph_runtime_mode,
afd_metadata=afd_metadata,
)
)
@ -353,6 +359,62 @@ class UBatchWrapper:
return ubatch_metadata
def _make_afd_ubatch_metadata(
    self,
    ubatch_slices,
    attn_metadata,
    input_ids,
    positions,
    inputs_embeds,
    intermediate_tensors,
    dp_metadata,
    afd_metadata,
) -> AFDMetadata:
    """Populate ``afd_metadata``'s per-microbatch lists with model inputs.

    With no ubatching (``ubatch_slices is None``) the whole batch is
    appended as a single entry.  Otherwise each ubatch slice is cut out of
    the model inputs via ``self._slice_model_inputs`` and appended together
    with a freshly built per-slice ``DPMetadata``.

    Returns the same ``afd_metadata`` object, mutated in place.
    """
    if ubatch_slices is None:
        # Single entry: forward the unsliced inputs and the caller's
        # attn/dp metadata as-is.
        afd_metadata.input_ids_list.append(input_ids)
        afd_metadata.positions_list.append(positions)
        afd_metadata.inputs_embeds_list.append(inputs_embeds)
        afd_metadata.intermediate_tensors_list.append(intermediate_tensors)
        afd_metadata.attn_metadata_list.append(attn_metadata)
        afd_metadata.dp_metadata_list.append(dp_metadata)
    else:
        # NOTE(review): in this branch the ``dp_metadata`` argument is
        # ignored; a per-slice DPMetadata is rebuilt below instead —
        # confirm that is intentional.
        for i, ubatch_slice in enumerate(ubatch_slices):
            (
                sliced_input_ids,
                sliced_positions,
                sliced_inputs_embeds,
                sliced_intermediate_tensors,
            ) = self._slice_model_inputs(
                ubatch_slice.token_slice,
                input_ids,
                positions,
                inputs_embeds,
                intermediate_tensors,
            )
            # Assumes every DP rank runs the same number of tokens for
            # this ubatch (the tensor repeats ubatch_slice.num_tokens
            # across all ranks) — TODO confirm for asymmetric DP loads.
            dp_size = self.vllm_config.parallel_config.data_parallel_size
            ubatch_num_tokens_across_dp = torch.tensor(
                [ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32
            )
            ubatch_dp_metadata = DPMetadata.make(
                self.vllm_config.parallel_config,
                ubatch_slice.num_tokens,
                ubatch_num_tokens_across_dp,
            )
            afd_metadata.input_ids_list.append(sliced_input_ids)
            afd_metadata.positions_list.append(sliced_positions)
            afd_metadata.inputs_embeds_list.append(sliced_inputs_embeds)
            afd_metadata.intermediate_tensors_list.append(
                sliced_intermediate_tensors
            )
            # attn_metadata is expected to be indexable per ubatch when
            # present (one entry per slice).
            afd_metadata.attn_metadata_list.append(
                attn_metadata[i] if attn_metadata is not None else None
            )
            afd_metadata.dp_metadata_list.append(ubatch_dp_metadata)
    return afd_metadata
def _slice_model_inputs(
self,
tokens_slice: slice,
@ -385,6 +447,34 @@ class UBatchWrapper:
batch_descriptor = forward_context.batch_descriptor
ubatch_slices = forward_context.ubatch_slices
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
afd_metadata = forward_context.afd_metadata
attn_metadata = forward_context.attn_metadata
input_ids = kwargs["input_ids"]
positions = kwargs["positions"]
intermediate_tensors = kwargs["intermediate_tensors"]
inputs_embeds = kwargs["inputs_embeds"]
compute_stream = torch.cuda.current_stream()
dp_metadata = forward_context.dp_metadata
if self.vllm_config.afd_config:
afd_metadata = self._make_afd_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
dp_metadata=dp_metadata,
afd_metadata=afd_metadata,
)
forward_context.afd_metadata = afd_metadata
if cudagraph_runtime_mode is CUDAGraphMode.NONE:
return self.runnable(*args, **kwargs)
else:
assert self.cudagraph_wrapper is not None
return self.cudagraph_wrapper(*args, **kwargs)
# If there's no ubatching, just run the runnable object
if ubatch_slices is None:
@ -394,6 +484,7 @@ class UBatchWrapper:
# num_tokens, we don't have a non-ubatched one. Without this
# check, the cudagraph wrapper will try to capture a cudagraph
# for this shape during a normal run.
if cudagraph_runtime_mode is CUDAGraphMode.FULL:
assert batch_descriptor is not None
if batch_descriptor.num_tokens in self.cudagraphs:
@ -405,18 +496,9 @@ class UBatchWrapper:
assert self.cudagraph_wrapper is not None
return self.cudagraph_wrapper(*args, **kwargs)
attn_metadata = forward_context.attn_metadata
num_tokens = (
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
) * 2
input_ids = kwargs["input_ids"]
positions = kwargs["positions"]
intermediate_tensors = kwargs["intermediate_tensors"]
inputs_embeds = kwargs["inputs_embeds"]
compute_stream = torch.cuda.current_stream()
dp_metadata = forward_context.dp_metadata
# We shouldn't be here unless we are running with multiple DP ranks
assert dp_metadata is not None
ubatch_dp_metadata = []
@ -448,6 +530,7 @@ class UBatchWrapper:
dp_metadata=ubatch_dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
afd_metadata=afd_metadata,
)
with self.sm_control:
return self._capture_ubatches(ubatch_metadata, self.model)
@ -470,6 +553,7 @@ class UBatchWrapper:
dp_metadata=ubatch_dp_metadata,
batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
afd_metadata=afd_metadata,
)
with self.sm_control:
return self._run_ubatches(ubatch_metadata, self.model)

View File

@ -6,7 +6,7 @@ import gc
import os
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Optional, cast
import numpy as np
import torch
@ -52,6 +52,8 @@ from vllm.v1.outputs import (
ModelRunnerOutput,
)
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_ffn_model_runner import GPUFFNModelRunner
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager
@ -248,8 +250,11 @@ class Worker(WorkerBase):
num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
init_workspace_manager(self.device, num_ubatches)
self.model_runner: GPUModelRunner | GPUFFNModelRunner | GPUModelRunnerV2
if self.vllm_config.afd_config and self.vllm_config.afd_config.is_ffn_server:
self.model_runner = GPUFFNModelRunner(self.vllm_config, self.device)
# Construct the model runner
if self.use_v2_model_runner:
elif self.use_v2_model_runner:
from vllm.v1.worker.gpu.model_runner import (
GPUModelRunner as GPUModelRunnerV2,
)
@ -574,8 +579,16 @@ class Worker(WorkerBase):
@torch.inference_mode()
def execute_model(
self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | None:
self, scheduler_output: Optional["SchedulerOutput"] = None
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
# FFN server mode: direct execution without pipeline parallelism
if self.vllm_config.afd_config and self.vllm_config.afd_config.is_ffn_server:
return self.model_runner.execute_model(scheduler_output)
if scheduler_output is None:
raise ValueError("scheduler_output is required in normal inference mode")
# Normal inference mode
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@ -682,6 +695,53 @@ class Worker(WorkerBase):
# worker will always be healthy as long as it's running.
return
def start_ffn_server_loop(self) -> None:
    """Start the background FFN execution loop for AFD FFN workers.

    No-op unless this worker is configured as an AFD FFN server.  Captures
    the model, initializes the AFD connector, optionally runs a fixed
    number of profiled iterations, then spawns a daemon thread that keeps
    calling ``execute_model(scheduler_output=None)`` until
    ``stop_ffn_server_loop`` sets the shutdown event.
    """
    if not (
        self.vllm_config.afd_config and self.vllm_config.afd_config.is_ffn_server
    ):
        return

    # Guard against double-start: a second call would leak the previous
    # thread and overwrite its shutdown event.
    if getattr(self, "_ffn_thread", None) is not None and self._ffn_thread.is_alive():
        logger.warning("FFN server loop already running; ignoring start request")
        return

    self.model_runner.capture_model()
    self.model_runner.initialize_afd_connector()

    if self.profiler:
        # FIXME: hardcoded profiler iteration count — make configurable.
        profiler_iters = 1000
        self.profiler.start()
        for _ in range(profiler_iters):
            self.model_runner.execute_model(scheduler_output=None)
        torch.cuda.synchronize()  # Ensure GPU operations complete
        self.profiler.stop()
        # Use the logger rather than print() so the table lands in the
        # configured log sink.
        logger.info(
            "%s", self.profiler.key_averages().table(sort_by="self_cuda_time_total")
        )

    import threading

    self._ffn_shutdown_event = threading.Event()

    def ffn_worker_loop():
        # Set CUDA device for this thread (thread-local context)
        torch.cuda.set_device(self.device)
        logger.info("FFN worker loop started")
        try:
            while not self._ffn_shutdown_event.is_set():
                # Execute FFN computation
                self.model_runner.execute_model(scheduler_output=None)
        except Exception as e:
            logger.error("FFN worker loop error: %s", e)
            raise

    self._ffn_thread = threading.Thread(target=ffn_worker_loop, daemon=True)
    self._ffn_thread.start()
    logger.info("FFN server loop started in worker")
def stop_ffn_server_loop(self) -> None:
    """Signal the FFN worker thread to exit and wait briefly for it.

    Safe to call when the loop was never started (the attribute checks
    make it a no-op).  Warns instead of claiming success if the thread
    fails to exit within the join timeout.
    """
    if hasattr(self, "_ffn_shutdown_event"):
        self._ffn_shutdown_event.set()
    if hasattr(self, "_ffn_thread"):
        self._ffn_thread.join(timeout=5)
        if self._ffn_thread.is_alive():
            # join() timed out — don't report a clean shutdown.
            logger.warning("FFN worker thread did not exit within 5s")
            return
    logger.info("FFN server loop stopped")
def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
from vllm.distributed.parallel_state import get_ep_group