mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-17 08:37:08 +08:00
Merge e7254d8994a4caf49e4cd08b604657b7ee8ae418 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
92d38e41c8
19
examples/online_serving/afd_deepseek_v2/README.md
Normal file
19
examples/online_serving/afd_deepseek_v2/README.md
Normal file
@ -0,0 +1,19 @@
|
||||
# P2P Connector
|
||||
The P2P connector is used for testing the AFD implementation with DeepSeek-V2-Lite models. It uses torch.distributed to send/recv intermediate tensors between attention (attn) and FFN instances.
|
||||
|
||||
When the --enable-dbo flag is enabled, the num_stage parameter is ignored and the actual number of microbatches is fixed at 2.
|
||||
|
||||
Currently, the P2PConnector only supports symmetric deployments in which the number of attention (A) dies equals the number of FFN (F) dies. Asymmetric configurations will be supported in future updates.
|
||||
|
||||
1. Attn
|
||||
|
||||
```
|
||||
vllm serve "/path/to/DeepSeek-V2-Lite" --data_parallel_size=2 --enable_expert_parallel --enforce_eager --enable-dbo --dbo-prefill-token-threshold 12 --dbo-decode-token-threshold 2 --afd-config '{"afd_connector":"p2pconnector", "afd_role": "attention", "afd_host":"127.0.0.1", "afd_port":"29500","num_afd_stages":"2","afd_extra_config":{"afd_size":"2A2F"}}'
|
||||
|
||||
```
|
||||
|
||||
2. FFN
|
||||
|
||||
```
|
||||
vllm fserver "/path/to/DeepSeek-V2-Lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --afd-config '{"afd_connector":"p2pconnector", "num_afd_stages":"2", "afd_role": "ffn", "afd_host":"127.0.0.1", "afd_port":"29500", "afd_extra_config":{"afd_size":"2A2F"}}'
|
||||
```
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.config.attention import AttentionConfig
|
||||
from vllm.config.afd import AFDConfig
|
||||
from vllm.config.cache import CacheConfig
|
||||
from vllm.config.compilation import (
|
||||
CompilationConfig,
|
||||
@ -66,6 +67,8 @@ __all__ = [
|
||||
"KVEventsConfig",
|
||||
# From vllm.config.kv_transfer
|
||||
"KVTransferConfig",
|
||||
# AFD (Attention FFN Disaggregation) configuration
|
||||
"AFDConfig",
|
||||
# From vllm.config.load
|
||||
"LoadConfig",
|
||||
# From vllm.config.lora
|
||||
|
||||
79
vllm/config/afd.py
Normal file
79
vllm/config/afd.py
Normal file
@ -0,0 +1,79 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
|
||||
@config
@dataclass
class AFDConfig:
    """Configuration for AFD (Attention FFN Disaggregation) distributed
    computation.

    An AFD deployment splits a model across two kinds of instances:
    'attention' workers that run attention locally, and 'ffn' servers that
    run the FFN/MoE layers remotely. The fields below select the transport
    (connector), the role of this instance, and the rendezvous endpoint.
    """

    afd_connector: str = "dummy"
    """The AFD connector for vLLM to communicate between attention and FFN
    nodes. Available connectors: 'dummy', 'p2pconnector'"""

    afd_role: Literal["attention", "ffn"] = "attention"
    """Role of this vLLM instance in AFD. 'attention' for attention workers,
    'ffn' for FFN servers."""

    afd_port: int = 1239
    """Port number for stepmesh parameter server communication."""

    afd_host: str = "127.0.0.1"
    """Host address for stepmesh parameter server communication."""

    num_afd_stages: int = 3
    """Number of pipeline stages for stage parallelism."""

    num_attention_servers: int = 1
    """Number of attention servers."""

    num_ffn_servers: int = 1
    """Number of FFN servers."""

    afd_server_rank: int = 0
    """Rank of this AFD server."""

    afd_extra_config: dict[str, Any] = field(default_factory=dict)
    """Extra configuration for specific AFD connectors."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # AFD configuration affects the computation graph structure
        # as it changes how FFN computation is performed
        # NOTE(review): afd_host/afd_port/afd_server_rank/afd_extra_config are
        # deliberately omitted -- presumably they only affect the transport
        # endpoint, not the graph structure. Confirm when adding new fields.
        factors: list[Any] = [
            self.afd_connector,
            self.afd_role,
            self.num_afd_stages,
            self.num_attention_servers,
            self.num_ffn_servers,
        ]
        return hashlib.sha256(str(factors).encode()).hexdigest()

    @property
    def is_attention_server(self) -> bool:
        """Check if this instance is configured as an attention server."""
        return self.afd_role == "attention"

    @property
    def is_ffn_server(self) -> bool:
        """Check if this instance is configured as an FFN server."""
        return self.afd_role == "ffn"
|
||||
@ -553,7 +553,7 @@ class ParallelConfig:
|
||||
if self.distributed_executor_backend is None and self.world_size > 1:
|
||||
# We use multiprocessing by default if world_size fits on the
|
||||
# current node and we aren't in a ray placement group.
|
||||
|
||||
|
||||
from vllm.v1.executor import ray_utils
|
||||
|
||||
backend: DistributedExecutorBackend = "mp"
|
||||
|
||||
@ -28,6 +28,7 @@ from vllm.utils import random_uuid
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
from .attention import AttentionConfig
|
||||
from .afd import AFDConfig
|
||||
from .cache import CacheConfig
|
||||
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
|
||||
from .device import DeviceConfig
|
||||
@ -227,6 +228,8 @@ class VllmConfig:
|
||||
"""The configurations for event publishing."""
|
||||
ec_transfer_config: ECTransferConfig | None = None
|
||||
"""The configurations for distributed EC cache transfer."""
|
||||
afd_config: AFDConfig | None = None
|
||||
"""AFD (Attention FFN Disaggregation) configuration."""
|
||||
# some opaque config, only used to provide additional information
|
||||
# for the hash computation, mainly used for testing, debugging or out of
|
||||
# tree config registration.
|
||||
@ -318,6 +321,10 @@ class VllmConfig:
|
||||
vllm_factors.append(self.ec_transfer_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.afd_config:
|
||||
vllm_factors.append(self.afd_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.additional_config:
|
||||
if isinstance(additional_config := self.additional_config, dict):
|
||||
additional_config_hash = safe_hash(
|
||||
@ -872,16 +879,17 @@ class VllmConfig:
|
||||
|
||||
if self.parallel_config.use_ubatching:
|
||||
a2a_backend = self.parallel_config.all2all_backend
|
||||
assert a2a_backend in [
|
||||
"deepep_low_latency",
|
||||
"deepep_high_throughput",
|
||||
], (
|
||||
"Microbatching currently only supports the deepep_low_latency and "
|
||||
f"deepep_high_throughput all2all backend. {a2a_backend} is not "
|
||||
"supported. To fix use --all2all-backend=deepep_low_latency or "
|
||||
"--all2all-backend=deepep_high_throughput and install the DeepEP"
|
||||
" kernels."
|
||||
)
|
||||
if self.afd_config is None:
|
||||
assert a2a_backend in [
|
||||
"deepep_low_latency",
|
||||
"deepep_high_throughput",
|
||||
], (
|
||||
"Microbatching currently only supports the deepep_low_latency and "
|
||||
f"deepep_high_throughput all2all backend. {a2a_backend} is not "
|
||||
"supported. To fix use --all2all-backend=deepep_low_latency or "
|
||||
"--all2all-backend=deepep_high_throughput and install the DeepEP"
|
||||
" kernels."
|
||||
)
|
||||
|
||||
if not self.model_config.disable_cascade_attn:
|
||||
self.model_config.disable_cascade_attn = True
|
||||
|
||||
12
vllm/distributed/afd_transfer/__init__.py
Normal file
12
vllm/distributed/afd_transfer/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""AFD (Attention-FFN Disaggregation) transfer components.
|
||||
|
||||
This module provides the distributed infrastructure for AFD, enabling
|
||||
disaggregated FFN computation across different machines while keeping
|
||||
attention computation local.
|
||||
"""
|
||||
|
||||
from .afd_connector import AFDConnectorBase, AFDConnectorFactory, AFDConnectorMetadata
|
||||
|
||||
__all__ = ["AFDConnectorBase", "AFDConnectorMetadata", "AFDConnectorFactory"]
|
||||
13
vllm/distributed/afd_transfer/afd_connector/__init__.py
Normal file
13
vllm/distributed/afd_transfer/afd_connector/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""AFD connector implementations for different transport backends."""
|
||||
|
||||
from .base import AFDConnectorBase
|
||||
from .factory import AFDConnectorFactory
|
||||
from .metadata import AFDConnectorMetadata
|
||||
|
||||
__all__ = [
|
||||
"AFDConnectorBase",
|
||||
"AFDConnectorFactory",
|
||||
"AFDConnectorMetadata",
|
||||
]
|
||||
139
vllm/distributed/afd_transfer/afd_connector/base.py
Normal file
139
vllm/distributed/afd_transfer/afd_connector/base.py
Normal file
@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
AFDConnectorBase Class for Distributed AFD FFN computation
|
||||
|
||||
The class provides the four core AFD communication interfaces:
|
||||
1. send_attn_output(): Send attention output to FFN servers (Attention Worker)
|
||||
2. recv_ffn_output(): Receive FFN computation result (Attention Worker)
|
||||
3. recv_attn_output(): Receive attention output from workers (FFN Server)
|
||||
4. send_ffn_output(): Send FFN computation result back (FFN Server)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
from .metadata import AFDConnectorMetadata
|
||||
|
||||
|
||||
class AFDConnectorBase(ABC):
    """
    Abstract base class for AFD connectors.

    This provides the four core interfaces for AFD communication between
    attention workers and FFN servers:

    - ``send_attn_output`` / ``recv_ffn_output`` are called on the
      attention-worker side.
    - ``recv_attn_output`` / ``send_ffn_output`` are called on the
      FFN-server side.
    """

    @abstractmethod
    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: "VllmConfig",
    ):
        """Initialize the AFD connector.

        Args:
            rank: Global rank of this process
            local_rank: Local rank within the node
            config: VllmConfig containing AFDConfig
        """
        raise NotImplementedError

    @abstractmethod
    def close(self) -> None:
        """Close the connector and release resources."""
        raise NotImplementedError

    @abstractmethod
    def init_afd_connector(self) -> None:
        """Initialize the AFD connector."""
        raise NotImplementedError

    @property
    @abstractmethod
    def is_initialized(self) -> bool:
        """Check if the connector is initialized and ready to use.

        Returns:
            bool: True if the connector is initialized, False otherwise.
        """
        raise NotImplementedError

    def get_connector_rank(self) -> int:
        """Get the rank of this connector."""
        # Falls back to 0 when a concrete connector never set ``self.rank``.
        return getattr(self, "rank", 0)

    def get_connector_local_rank(self) -> int:
        """Get the local rank of this connector."""
        # Falls back to 0 when a concrete connector never set
        # ``self.local_rank``.
        return getattr(self, "local_rank", 0)

    @abstractmethod
    def send_attn_output(
        self,
        hidden_states: torch.Tensor,
        metadata: "AFDConnectorMetadata",
    ) -> Any:
        """Send attention output to FFN servers.

        Args:
            hidden_states: Attention output tensor
            metadata: AFD metadata containing layer_idx, stage_idx, seq_len info

        Returns:
            Any: Handle for tracking this request (backend-specific)
        """
        raise NotImplementedError

    @abstractmethod
    def recv_ffn_output(
        self,
        handle: Any,
    ) -> torch.Tensor:
        """Wait for and receive FFN computation result.

        NOTE(review): DummyAFDConnector implements this with a
        ``timeout_ms`` parameter instead of ``handle`` -- the signatures
        should be aligned; confirm which contract callers rely on.

        Args:
            handle: Handle returned by send_attn_output()

        Returns:
            torch.Tensor: FFN computation result
        """
        raise NotImplementedError

    @abstractmethod
    def recv_attn_output(
        self,
        timeout_ms: int | None = None,
    ) -> tuple[torch.Tensor, "AFDConnectorMetadata"]:
        """Receive attention output from attention workers.

        Args:
            timeout_ms: Optional timeout in milliseconds

        Returns:
            tuple: (hidden_states, metadata)
                - hidden_states: Concatenated attention outputs
                - metadata: Inferred AFD metadata containing
                  seq_lens and other info
        """
        raise NotImplementedError

    @abstractmethod
    def send_ffn_output(
        self,
        ffn_output: torch.Tensor,
        metadata: "AFDConnectorMetadata",
    ) -> None:
        """Send FFN computation result back to attention workers.

        Args:
            ffn_output: Computed FFN result
            metadata: AFD metadata containing seq_lens
                for splitting and routing info
        """
        raise NotImplementedError
|
||||
211
vllm/distributed/afd_transfer/afd_connector/dummy_connector.py
Normal file
211
vllm/distributed/afd_transfer/afd_connector/dummy_connector.py
Normal file
@ -0,0 +1,211 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Dummy AFD Connector for testing and local development.
|
||||
|
||||
This connector provides a no-op AFDConnectorBase interface,
|
||||
useful for testing and development scenarios where actual
|
||||
distributed FFN computation is not needed.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base import AFDConnectorBase
|
||||
from .metadata import AFDConnectorMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DummyAFDConnector(AFDConnectorBase):
    """Dummy AFD connector that returns zero tensors.

    This connector is useful for:
    1. Testing AFD infrastructure without actual remote computation
    2. Development scenarios where FFN computation should be disabled
    3. Fallback behavior when remote FFN servers are unavailable
    """

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: "VllmConfig",
    ):
        """Initialize the dummy AFD connector.

        Args:
            rank: Global rank of this process
            local_rank: Local rank within the node
            config: VllmConfig containing AFDConfig
        """
        self.afd_config = config.afd_config
        self.rank = rank
        self.local_rank = local_rank
        self._is_initialized = False
        self.hidden_size = config.model_config.hf_config.hidden_size
        self.num_stages = config.afd_config.num_afd_stages

        # One in-flight request per pipeline stage: send_attn_output()
        # appends, recv_ffn_output() pops in FIFO order.
        self.events: deque = deque(maxlen=self.num_stages)

        logger.info("DummyAFDConnector initialized for rank %s", rank)

        self.init_afd_connector()

    def init_afd_connector(self) -> None:
        """Initialize the dummy connector.

        This is a no-op for the dummy connector; idempotent.
        """
        if self._is_initialized:
            return

        logger.info("Initializing DummyAFDConnector (no-op)")
        self._is_initialized = True

    def close(self) -> None:
        """Close the dummy connector.

        This is a no-op for the dummy connector; idempotent.
        """
        if not self._is_initialized:
            return

        logger.info("Closing DummyAFDConnector (no-op)")
        self._is_initialized = False

    @property
    def is_initialized(self) -> bool:
        """Check if the connector is initialized.

        Returns:
            bool: True if initialized, False otherwise
        """
        return self._is_initialized

    def send_attn_output(
        self,
        hidden_states: torch.Tensor,
        metadata: AFDConnectorMetadata,
    ) -> Any:
        """Send attention output to FFN servers (dummy implementation).

        Raises:
            ValueError: If the tensor shape disagrees with the metadata, or
                if the metadata describes more than one sequence.
        """
        logger.debug(
            "DummyAFDConnector: send_attn_output layer=%s, stage=%s",
            metadata.layer_idx,
            metadata.stage_idx,
        )

        # Validate metadata consistency.
        # BUGFIX: the original passed logging-style %-args to the ValueError
        # constructor, which produced an unformatted tuple as the message.
        if not metadata.validate_tensor_shape(hidden_states.shape):
            raise ValueError(
                f"Tensor shape {hidden_states.shape} doesn't match "
                f"metadata {metadata}"
            )

        if not metadata.is_single_sequence:
            raise ValueError("Attention side should have single sequence")

        self.events.append((None, metadata))

        return None

    def recv_ffn_output(
        self,
        timeout_ms: float | None = None,
    ) -> torch.Tensor:
        """Receive FFN computation result (dummy implementation).

        Pops the oldest pending metadata queued by send_attn_output() and
        returns an all-zero tensor of the matching shape.

        NOTE(review): the abstract base declares this method with a
        ``handle`` parameter; ``timeout_ms`` is kept here to avoid breaking
        existing keyword callers -- confirm and align the signatures.
        """
        logger.debug("DummyAFDConnector: recv_ffn_output timeout_ms=%s", timeout_ms)

        _, metadata = self.events.popleft()
        seq_len = metadata.seq_lens[0]  # Single sequence for attention side
        return torch.zeros(
            seq_len,
            self.hidden_size,
            dtype=metadata.dtype,
            device=metadata.device,
        )

    def recv_attn_output(
        self,
        timeout_ms: int | None = None,
    ) -> tuple[torch.Tensor, AFDConnectorMetadata]:
        """
        Receive attention output from attention workers (dummy implementation).

        Returns a zero tensor and fabricated metadata that simulate a batch
        gathered from several attention workers.
        """
        logger.debug("DummyAFDConnector: recv_attn_output timeout_ms=%s", timeout_ms)

        # Generate dummy data that simulates multiple attention workers
        dummy_seq_lens = [
            2,
            2,
            2,
        ]  # Variable sequence lengths from different workers
        total_tokens = sum(dummy_seq_lens)

        dummy_tensor = torch.zeros(
            total_tokens, self.hidden_size, dtype=torch.bfloat16, device="cuda"
        )

        # Create dummy metadata
        dummy_metadata = AFDConnectorMetadata.create_ffn_metadata(
            layer_idx=0,  # Dummy layer
            stage_idx=0,  # Dummy stage
            dtype=torch.bfloat16,
            device=torch.device("cuda"),
            seq_lens=dummy_seq_lens,
            request_id=f"dummy_ffn_batch_{time.time()}",
        )

        # Cache metadata for send_ffn_output
        self._current_metadata = dummy_metadata
        # Simulate remote communication latency.
        time.sleep(1)

        return dummy_tensor, dummy_metadata

    def send_ffn_output(
        self,
        ffn_output: torch.Tensor,
        metadata: AFDConnectorMetadata,
    ) -> None:
        """Send FFN computation result back (dummy implementation)."""
        logger.debug(
            "DummyAFDConnector: send_ffn_output layer=%s, stage=%s",
            metadata.layer_idx,
            metadata.stage_idx,
        )

        # Validate that ffn_output shape matches metadata; a mismatch is only
        # warned about here (unlike send_attn_output) since nothing is sent.
        if not metadata.validate_tensor_shape(ffn_output.shape):
            logger.warning(
                "FFN output shape %s doesn't match metadata %s",
                ffn_output.shape,
                metadata,
            )

        # Log the splitting information for debugging
        logger.debug(
            "DummyAFDConnector: Split FFN output into %s parts with lengths %s",
            metadata.num_sequences,
            metadata.seq_lens,
        )

        # Simulate splitting (for logging purposes)
        if metadata.get_split_indices():
            split_outputs = torch.split(ffn_output, metadata.seq_lens, dim=0)
            logger.debug(
                "DummyAFDConnector: Split shapes: %s",
                [s.shape for s in split_outputs],
            )

        # Simulate remote communication latency.
        time.sleep(1)
        # No-op for dummy connector - just log the operation
|
||||
95
vllm/distributed/afd_transfer/afd_connector/factory.py
Normal file
95
vllm/distributed/afd_transfer/afd_connector/factory.py
Normal file
@ -0,0 +1,95 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Factory for creating AFD connectors based on configuration."""
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base import AFDConnectorBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AFDConnectorFactory:
    """Lazy-loading registry of AFD connector implementations.

    Connectors are registered by name together with the module path and
    class name of their implementation; the module is imported only when a
    connector is actually requested, so unused transports are never loaded.
    """

    _registry: dict[str, Callable[[], type[AFDConnectorBase]]] = {}

    @classmethod
    def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
        """Register a connector with a lazy-loading module and class name."""
        if name in cls._registry:
            raise ValueError(f"Connector '{name}' is already registered.")

        def _lazy_load() -> type[AFDConnectorBase]:
            # Deferred import: only runs on first use of this connector.
            return getattr(importlib.import_module(module_path), class_name)

        cls._registry[name] = _lazy_load

    @classmethod
    def create_connector(
        cls, rank: int, local_rank: int, config: "VllmConfig"
    ) -> AFDConnectorBase:
        """Create an AFD connector based on the configuration.

        Args:
            rank: Global rank of this process
            local_rank: Local rank within the node
            config: VllmConfig containing AFDConfig

        Returns:
            AFDConnectorBase: The created connector instance

        Raises:
            ValueError: If the transport backend is not supported
            ImportError: If required dependencies are not available
        """
        connector_name = config.afd_config.afd_connector
        loader = cls._registry.get(connector_name)
        if loader is None:
            raise ValueError(f"Unsupported connector type: {connector_name}")

        connector_cls = loader()
        assert issubclass(connector_cls, AFDConnectorBase)
        return connector_cls(rank, local_rank, config)

    @classmethod
    def get_connector_class(cls, connector_name: str) -> type[AFDConnectorBase]:
        """Get the connector class for a given connector name.

        Args:
            connector_name: The connector name

        Returns:
            type[AFDConnectorBase]: The connector class

        Raises:
            ValueError: If the connector name is not supported
        """
        loader = cls._registry.get(connector_name)
        if loader is None:
            raise ValueError(f"Unsupported connector type: {connector_name}")
        return loader()


# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.

AFDConnectorFactory.register_connector(
    "dummy",
    "vllm.distributed.afd_transfer.afd_connector.dummy_connector",
    "DummyAFDConnector",
)

AFDConnectorFactory.register_connector(
    "p2pconnector",
    "vllm.distributed.afd_transfer.afd_connector.p2p_connector",
    "P2PAFDConnector",
)
|
||||
175
vllm/distributed/afd_transfer/afd_connector/metadata.py
Normal file
175
vllm/distributed/afd_transfer/afd_connector/metadata.py
Normal file
@ -0,0 +1,175 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""AFD metadata definitions for communication between attention and
|
||||
FFN workers."""
|
||||
|
||||
import time
|
||||
import typing
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class FFNNeedForwardData:
|
||||
def __init__(
|
||||
self,
|
||||
moe_comm_method: typing.Any,
|
||||
num_input_tokens: int,
|
||||
with_prefill: bool,
|
||||
total_num_scheduled_tokens: int | None,
|
||||
is_dummy_run: bool = False,
|
||||
):
|
||||
self.moe_comm_method = moe_comm_method
|
||||
self.num_input_tokens = num_input_tokens
|
||||
self.with_prefill = with_prefill
|
||||
self.total_num_scheduled_tokens = total_num_scheduled_tokens
|
||||
self.is_dummy_run = is_dummy_run
|
||||
|
||||
|
||||
@dataclass
|
||||
class AFDConnectorMetadata:
|
||||
"""Lightweight AFD metadata containing core information needed for
|
||||
communication."""
|
||||
|
||||
layer_idx: int
|
||||
stage_idx: int
|
||||
seq_lens: list[int] # Length of each sequence, supports variable length and
|
||||
# multiple sequences
|
||||
dtype: torch.dtype
|
||||
device: torch.device
|
||||
topk_idx: torch.Tensor | None = None # indices token which expert to be sended
|
||||
topk_weights: torch.Tensor | None = None # the expert weights
|
||||
moe_expert_num: int | None = None # number of moe experts
|
||||
shared_expert_num: int | None = None # number of share experts
|
||||
scale: torch.Tensor | None = None # quant scale
|
||||
expertTokenNumsOut: torch.Tensor | None = (
|
||||
None # The number of tokens received by each expert is used as input for the subsequent GMM.
|
||||
)
|
||||
recv_handle_list: list[Any] | None = (
|
||||
None # the communication handles (list of Work objects returned by torch.distributed.irecv)
|
||||
)
|
||||
|
||||
# Optional fields for debugging and extensibility
|
||||
request_id: str | None = None
|
||||
timestamp: float | None = None
|
||||
"""ffn need forward data"""
|
||||
ffn_need_forward_data: FFNNeedForwardData | None = None
|
||||
num_of_stages: int = 1
|
||||
afd_tokens_lens: list = field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate data consistency."""
|
||||
if not self.seq_lens:
|
||||
raise ValueError("seq_lens cannot be empty")
|
||||
if any(length <= 0 for length in self.seq_lens):
|
||||
raise ValueError("All sequence lengths must be positive")
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Total number of tokens."""
|
||||
return sum(self.seq_lens)
|
||||
|
||||
@property
|
||||
def num_sequences(self) -> int:
|
||||
"""Number of sequences."""
|
||||
return len(self.seq_lens)
|
||||
|
||||
@property
|
||||
def is_single_sequence(self) -> bool:
|
||||
"""Whether this is a single sequence (attention side characteristic)."""
|
||||
return len(self.seq_lens) == 1
|
||||
|
||||
@property
|
||||
def is_multi_sequence(self) -> bool:
|
||||
"""Whether this is multiple sequences (FFN side characteristic)."""
|
||||
return len(self.seq_lens) > 1
|
||||
|
||||
@classmethod
|
||||
def create_attention_metadata(
|
||||
cls,
|
||||
layer_idx: int,
|
||||
stage_idx: int,
|
||||
seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
request_id: str | None = None,
|
||||
ffn_need_forward_data: FFNNeedForwardData | None = None,
|
||||
num_of_stages: int = 1,
|
||||
afd_tokens_lens: list[int] = [],
|
||||
) -> "AFDConnectorMetadata":
|
||||
"""Create metadata for attention side (single sequence)."""
|
||||
return cls(
|
||||
layer_idx=layer_idx,
|
||||
stage_idx=stage_idx,
|
||||
seq_lens=[seq_len],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
request_id=request_id,
|
||||
ffn_need_forward_data=ffn_need_forward_data,
|
||||
timestamp=time.time(),
|
||||
num_of_stages=num_of_stages,
|
||||
afd_tokens_lens=afd_tokens_lens,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_ffn_metadata(
|
||||
cls,
|
||||
layer_idx: int,
|
||||
stage_idx: int,
|
||||
seq_lens: list[int],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
request_id: str | None = None,
|
||||
) -> "AFDConnectorMetadata":
|
||||
"""Create metadata for FFN side (multiple sequences)."""
|
||||
return cls(
|
||||
layer_idx=layer_idx,
|
||||
stage_idx=stage_idx,
|
||||
seq_lens=seq_lens.copy(), # Prevent external modification
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
request_id=request_id,
|
||||
timestamp=time.time(),
|
||||
)
|
||||
|
||||
def get_split_indices(self) -> list[int]:
|
||||
"""Get tensor split indices for FFN side output splitting."""
|
||||
if len(self.seq_lens) <= 1:
|
||||
return []
|
||||
|
||||
indices = []
|
||||
cumsum = 0
|
||||
for length in self.seq_lens[:-1]: # Exclude the last one
|
||||
cumsum += length
|
||||
indices.append(cumsum)
|
||||
return indices
|
||||
|
||||
def validate_tensor_shape(self, tensor_shape: tuple[int, ...]) -> bool:
|
||||
"""Validate if tensor shape is consistent with metadata."""
|
||||
if len(tensor_shape) < 1:
|
||||
return False
|
||||
return tensor_shape[0] == self.total_tokens
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary format for serialization and debugging."""
|
||||
return {
|
||||
"layer_idx": self.layer_idx,
|
||||
"stage_idx": self.stage_idx,
|
||||
"seq_lens": self.seq_lens,
|
||||
"dtype": self.dtype,
|
||||
"device": self.device,
|
||||
"total_tokens": self.total_tokens,
|
||||
"num_sequences": self.num_sequences,
|
||||
"request_id": self.request_id,
|
||||
"timestamp": self.timestamp,
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Friendly string representation."""
|
||||
return (
|
||||
f"AFDConnectorMetadata(layer={self.layer_idx}, "
|
||||
f"stage={self.stage_idx}, seq_lens={self.seq_lens}, "
|
||||
f"total_tokens={self.total_tokens}, dtype={self.dtype}, "
|
||||
f"device={self.device}, request_id={self.request_id})"
|
||||
)
|
||||
337
vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
Normal file
337
vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
Normal file
@ -0,0 +1,337 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import re
|
||||
from datetime import timedelta
|
||||
|
||||
import torch
|
||||
from torch.distributed.distributed_c10d import _get_default_group, _update_default_pg
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.parallel_state import (
|
||||
GroupCoordinator,
|
||||
TensorMetadata,
|
||||
init_afd_process_group,
|
||||
init_model_parallel_group,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.forward_context import (
|
||||
DPMetadata,
|
||||
get_forward_context,
|
||||
)
|
||||
|
||||
from .base import AFDConnectorBase
|
||||
from .metadata import AFDConnectorMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DefaultProcessGroupSwitcher:
    """Context manager that temporarily swaps torch.distributed's default
    process group.

    On entry the new group is installed as the interpreter-global default
    (via the private ``_update_default_pg``); on exit the original default
    is restored, even if the body raised.
    """

    def __init__(self, default_group, new_default_group):
        # default_group: the group to restore on exit (the current default).
        self.default_group = default_group
        # new_default_group: the group to install while inside the context.
        self.new_default_group = new_default_group

    def __enter__(self):
        _update_default_pg(self.new_default_group)

    def __exit__(self, exc_type, exc_value, traceback):
        # Always restore; exceptions are not suppressed (implicit None return).
        _update_default_pg(self.default_group)
|
||||
|
||||
|
||||
class P2PAFDConnector(AFDConnectorBase):
    """Point-to-point AFD (attention-FFN disaggregation) connector.

    Pairs each attention rank with one FFN rank and exchanges intermediate
    hidden states via blocking ``torch.distributed`` send/recv over two
    dedicated process groups: ``a2e`` (attention -> expert/FFN) and
    ``e2a`` (expert/FFN -> attention).
    """

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: "VllmConfig",
    ) -> None:
        self.rank = rank
        self.local_rank = local_rank
        self.config = config
        self._initialized: bool = False
        # The FFN side must (re)receive metadata at the start of every
        # full pass over all layers and stages.
        self._need_recv_metadata: bool = True
        # Per-stage tensor metadata describing the hidden-state buffers.
        self._tensor_metadata_list: dict[int, TensorMetadata] = {}
        self._current_afd_connector_metadata: AFDConnectorMetadata | None = None
        # Multimodal models nest the text config; fall back to the flat
        # hf_config otherwise.
        text_config = getattr(self.config.model_config.hf_config, "text_config", None)
        if text_config is not None:
            self.num_hidden_layers: int = text_config.num_hidden_layers
        else:
            self.num_hidden_layers = (
                self.config.model_config.hf_config.num_hidden_layers
            )

        self.recv_attn_output_counter: int = 0
        self.recv_ffn_output_counter: int = 0
        self.dp_metadata_list: dict[int, DPMetadata] = {}

    def close(self) -> None:
        """Close the connector and release resources."""
        # TODO: Implement proper resource clean up if needed.

    def init_afd_connector(self) -> None:
        """Initialize the AFD process groups.

        Builds one process group spanning all attention + FFN ranks, then
        two per-pair sub-groups (one per communication direction).
        """
        afd_size = self.config.afd_config.afd_extra_config.get("afd_size")
        role = self.config.afd_config.afd_role
        # e.g. "2A2F" -> attn_size=2, ffn_size=2
        attn_size, ffn_size = map(int, re.match(r"(\d+)\D+(\d+)", afd_size).groups())
        # Attention ranks occupy [0, attn_size); FFN ranks follow.
        world_rank = self.rank if role == "attention" else self.rank + attn_size
        afd_pg = init_afd_process_group(
            backend="nccl",
            init_method=(
                f"tcp://{self.config.afd_config.afd_host}"
                f":{self.config.afd_config.afd_port}"
            ),
            world_size=ffn_size + attn_size,
            rank=world_rank,
            group_name="afd",
            timeout=timedelta(minutes=2),
        )

        # Construct rank lists for sub groups; each sub group pairs one
        # attention rank with one FFN rank.
        # FFN world ranks start at attn_size (see world_rank above), so the
        # FFN range must be offset by attn_size, not ffn_size. The two were
        # interchangeable only in the symmetric case.
        attn_ranks = list(range(attn_size))
        ffn_ranks = list(range(attn_size, attn_size + ffn_size))
        assert len(ffn_ranks) == len(attn_ranks), (
            "ffn_ranks and attn_ranks must have the same length"
        )
        default_pg_switcher = DefaultProcessGroupSwitcher(_get_default_group(), afd_pg)
        with default_pg_switcher:
            sub_group_ranks = []
            for i in range(len(ffn_ranks)):
                sub_group_ranks.append([attn_ranks[i], ffn_ranks[i]])
            # Create two independent groups over the same rank pairs:
            #   a2e_group: attention -> expert/ffn (send_attn, recv_attn)
            #   e2a_group: expert/ffn -> attention (send_ffn, recv_ffn)
            # The rank ranges are identical; distinct group_names yield
            # independent communicators.
            self.a2e_group = init_model_parallel_group(
                sub_group_ranks,
                self.local_rank,
                backend="nccl",
                group_name="a2e",
            )
            self.e2a_group = init_model_parallel_group(
                sub_group_ranks,
                self.local_rank,
                backend="nccl",
                group_name="e2a",
            )

        self._initialized = True

    def is_initialized(self) -> bool:
        """Check if the connector is initialized and ready to use.

        Returns:
            bool: True if the connector is initialized, False otherwise.
        """
        return self._initialized

    def _build_tensor_metadata_list(
        self,
        tensor_metadata: TensorMetadata,
        connector_metadata: AFDConnectorMetadata,
    ) -> dict[int, TensorMetadata]:
        """Derive per-stage tensor metadata from the stage-0 metadata.

        Stage 0 reuses ``tensor_metadata`` unchanged; later stages share
        its device/dtype but take their leading (token) dimension from
        ``connector_metadata.afd_tokens_lens``.
        """
        tensor_metadata_list: dict[int, TensorMetadata] = {}
        for idx in range(connector_metadata.num_of_stages):
            if idx == 0:
                tensor_metadata_list[0] = tensor_metadata
            else:
                new_size = list(tensor_metadata.size)
                new_size[0] = connector_metadata.afd_tokens_lens[idx]
                tensor_metadata_list[idx] = TensorMetadata(
                    tensor_metadata.device,
                    tensor_metadata.dtype,
                    torch.Size(new_size),
                )
        return tensor_metadata_list

    def _send_metadata(
        self,
        metadata: AFDConnectorMetadata,
        hidden_states: torch.Tensor,
        dst: int,
        process_group: GroupCoordinator,
    ) -> None:
        """Send connector + tensor metadata to ``dst`` (rank within group)
        and cache the derived per-stage tensor metadata locally."""
        if not torch.distributed.is_initialized() or process_group.world_size == 1:
            # Single-process setup: nothing to send. (Previously returned
            # ``[]`` from a ``-> None`` method.)
            return
        assert dst < process_group.world_size, f"Invalid dst rank ({dst})"

        tensor_metadata = TensorMetadata(
            hidden_states.device.type, hidden_states.dtype, hidden_states.size()
        )
        process_group.send_object((metadata, tensor_metadata), dst=dst)
        self._tensor_metadata_list = self._build_tensor_metadata_list(
            tensor_metadata, metadata
        )

    def _recv_metadata(
        self,
        src: int,
        process_group: GroupCoordinator,
    ) -> None:
        """Receive connector + tensor metadata from ``src``, rebuild the
        per-stage tensor metadata and, when DP > 1, the per-stage
        DPMetadata."""
        (self._current_afd_connector_metadata, tensor_metadata) = (
            process_group.recv_object(src=src)
        )
        self._tensor_metadata_list = self._build_tensor_metadata_list(
            tensor_metadata, self._current_afd_connector_metadata
        )
        if self.config.parallel_config.data_parallel_size > 1:
            logger.debug(
                "recv_metadata num_of_stages: %s",
                self._current_afd_connector_metadata.num_of_stages,
            )
            for stage_idx in range(self._current_afd_connector_metadata.num_of_stages):
                num_tokens_per_ubatch = self._tensor_metadata_list[stage_idx].size[0]
                # NOTE(review): assumes every DP rank runs the same token
                # count per ubatch -- confirm for uneven DP batches.
                self.dp_metadata_list[stage_idx] = DPMetadata.make(
                    self.config.parallel_config,
                    num_tokens_per_ubatch,
                    torch.tensor(
                        [num_tokens_per_ubatch]
                        * self.config.parallel_config.data_parallel_size,
                        device="cpu",
                        dtype=torch.int32,
                    ),
                )
            logger.debug("recv_metadata dp_metadata_list: %s", self.dp_metadata_list)

    def _send_hidden_states(
        self,
        hidden_states: torch.Tensor,
        dst: int,
        process_group: GroupCoordinator,
    ) -> None:
        """Blocking send of ``hidden_states`` to ``dst`` (rank within group)."""
        if not torch.distributed.is_initialized() or process_group.world_size == 1:
            return
        assert dst < process_group.world_size, f"Invalid dst rank ({dst})"
        assert not hidden_states.is_cpu, "Hidden states must be on GPU"
        torch.distributed.send(
            hidden_states,
            dst=process_group.ranks[dst],
            group=process_group.device_group,
        )

    def _recv_hidden_states(
        self,
        src: int,
        process_group: GroupCoordinator,
        tensor_metadata: TensorMetadata,
    ) -> torch.Tensor:
        """Blocking recv of one hidden-state tensor whose shape/dtype/device
        are described by ``tensor_metadata``."""
        if not torch.distributed.is_initialized() or process_group.world_size == 1:
            # There is no peer to receive from. The old code returned an
            # unusable ``({}, [])`` placeholder from a ``-> torch.Tensor``
            # method; fail loudly instead.
            raise RuntimeError(
                "P2PAFDConnector requires an initialized multi-rank process group"
            )
        assert src < process_group.world_size, f"Invalid src rank ({src})"

        hidden_states = torch.empty(
            tensor_metadata.size,
            dtype=tensor_metadata.dtype,
            device=tensor_metadata.device,
        )
        torch.distributed.recv(
            hidden_states,
            src=process_group.ranks[src],
            group=process_group.device_group,
        )
        return hidden_states

    # -------------------------------------------------------------------------
    # attn -> ffn
    # -------------------------------------------------------------------------

    def send_attn_output(
        self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata
    ) -> None:
        """
        Called by ATTN side to send intermediate tensors
        generated by ATTN instances to FFN.

        Metadata is sent once per full pass (layer 0, stage 0 only).
        """
        try:
            dst = (self.a2e_group.rank_in_group + 1) % self.a2e_group.world_size
            if metadata.layer_idx == 0 and metadata.stage_idx == 0:
                self._send_metadata(metadata, hidden_states, dst, self.a2e_group)
                self._current_afd_connector_metadata = metadata
            self._send_hidden_states(hidden_states, dst, self.a2e_group)
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise RuntimeError(f"Communication error: {e}") from e

    def recv_ffn_output(self) -> torch.Tensor:
        """
        Called by the ATTN side to receive MOE output intermediate tensors,
        possibly dispatching from the receiver to other GPUs.
        """
        src = (self.e2a_group.rank_in_group - 1) % self.e2a_group.world_size
        num_stages = self._current_afd_connector_metadata.num_of_stages
        stage_idx = self.recv_ffn_output_counter % num_stages
        hidden_states = self._recv_hidden_states(
            src,
            self.e2a_group,
            self._tensor_metadata_list[stage_idx],
        )
        # Advance to the next stage, wrapping per pass.
        self.recv_ffn_output_counter = (self.recv_ffn_output_counter + 1) % num_stages
        return hidden_states

    # -------------------------------------------------------------------------
    # ffn -> attn
    # -------------------------------------------------------------------------

    def send_ffn_output(
        self,
        hidden_states: torch.Tensor,
        metadata: AFDConnectorMetadata,
    ) -> None:
        """
        Called by FFN side to send intermediate tensors generated by FFN
        instances back to the sender (should be the same GPU as source).

        Also advances pass bookkeeping so metadata is re-received at the
        start of the next full pass over all layers/stages.
        """
        dst = (self.e2a_group.rank_in_group + 1) % self.e2a_group.world_size
        self._send_hidden_states(hidden_states, dst, self.e2a_group)
        self.recv_attn_output_counter += 1
        steps_per_pass = (
            self._current_afd_connector_metadata.num_of_stages
            * self.num_hidden_layers
        )
        if self.recv_attn_output_counter % steps_per_pass == 0:
            # One full pass over all layers/stages completed.
            self._need_recv_metadata = True
            self.recv_attn_output_counter = 0

    def recv_attn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]:
        """
        Called by the FFN side to receive intermediate tensors from ATTN.
        Returns the tensor together with connector metadata updated with
        the current layer/stage indices.
        """
        src = (self.a2e_group.rank_in_group - 1) % self.a2e_group.world_size
        if self._need_recv_metadata:
            self._recv_metadata(src, self.a2e_group)
            self._need_recv_metadata = False

        num_stages = self._current_afd_connector_metadata.num_of_stages
        stage_idx = self.recv_attn_output_counter % num_stages
        layer_idx = self.recv_attn_output_counter // num_stages
        hidden_states = self._recv_hidden_states(
            src,
            self.a2e_group,
            self._tensor_metadata_list[stage_idx],
        )
        self._current_afd_connector_metadata.layer_idx = layer_idx
        self._current_afd_connector_metadata.stage_idx = stage_idx
        return hidden_states, self._current_afd_connector_metadata
|
||||
@ -41,6 +41,14 @@ import torch.distributed
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
import torch.distributed._symmetric_memory
|
||||
from torch.distributed import Backend, ProcessGroup
|
||||
from torch.distributed.distributed_c10d import (
|
||||
PrefixStore,
|
||||
Store,
|
||||
_new_process_group_helper,
|
||||
_world,
|
||||
default_pg_timeout,
|
||||
rendezvous,
|
||||
)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
@ -322,7 +330,6 @@ class GroupCoordinator:
|
||||
|
||||
self_device_group = None
|
||||
self_cpu_group = None
|
||||
|
||||
for ranks in group_ranks:
|
||||
device_group = torch.distributed.new_group(
|
||||
ranks, backend=torch_distributed_backend
|
||||
@ -1036,6 +1043,63 @@ _INNER_DP_WORLD: GroupCoordinator | None = None
|
||||
_NODE_COUNT: int | None = None
|
||||
|
||||
|
||||
def init_afd_process_group(
    backend: str | Backend | None = None,
    init_method: str | None = None,
    timeout: timedelta | None = None,
    world_size: int = -1,
    rank: int = -1,
    store: Store | None = None,
    group_name: str | None = None,
    pg_options: Any | None = None,
):
    """Create the global AFD process group spanning all attention + FFN ranks.

    Mirrors ``torch.distributed.init_process_group`` but builds a named,
    non-default group via the private c10d helper, registers its rank
    mapping with ``_world`` and stores it in the module-level ``_AFD``.

    Args:
        backend: communication backend name (e.g. "nccl"); "undefined"
            lets c10d pick per-device defaults.
        init_method: rendezvous URL (e.g. "tcp://host:port"); defaults to
            "env://" when neither it nor ``store`` is given.
        timeout: collective timeout; defaults to c10d's default.
        world_size / rank: required (positive / non-negative) when a
            ``store`` is passed; otherwise resolved via rendezvous.
        store: pre-built key/value store, mutually exclusive with
            ``init_method``.
        group_name: name used both for the store prefix and the group.
        pg_options: backend-specific options forwarded to c10d.

    Returns:
        The newly created ProcessGroup.
    """
    assert (store is None) or (init_method is None), (
        "Cannot specify both init_method and store."
    )

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    backend = Backend(backend) if backend else Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)
        # Namespace the store so this group cannot collide with others.
        store = PrefixStore(group_name, store)

    # torch >= 2.6 renamed the keyword from "pg_options" to
    # "backend_options". Compare numerically: the old string comparison
    # broke for versions like "2.10" ("2.10" < "2.6" lexicographically).
    try:
        major, minor = (int(part) for part in str(torch.__version__).split(".")[:2])
        use_backend_options = (major, minor) >= (2, 6)
    except ValueError:
        # Unparsable (unusual dev build): assume a modern torch.
        use_backend_options = True
    pg_options_param_name = "backend_options" if use_backend_options else "pg_options"

    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )
    global _AFD
    # Identity rank mapping: group rank i == global rank i for this group.
    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
    _AFD = pg
    return pg
|
||||
|
||||
|
||||
_WORLD: GroupCoordinator | None = None
|
||||
_NODE_COUNT: int | None = None
|
||||
|
||||
|
||||
def get_world_group() -> GroupCoordinator:
    """Return the process-wide world ``GroupCoordinator``.

    Raises:
        AssertionError: if the world group has not been initialized yet.
    """
    assert _WORLD is not None, "world group is not initialized"
    return _WORLD
|
||||
|
||||
@ -35,6 +35,7 @@ import vllm.envs as envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import (
|
||||
AttentionConfig,
|
||||
AFDConfig,
|
||||
CacheConfig,
|
||||
CompilationConfig,
|
||||
ConfigType,
|
||||
@ -73,7 +74,7 @@ from vllm.config.model import (
|
||||
RunnerOption,
|
||||
TokenizerMode,
|
||||
)
|
||||
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
|
||||
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
|
||||
from vllm.config.observability import DetailedTraceModules
|
||||
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
|
||||
from vllm.config.scheduler import SchedulerPolicy
|
||||
@ -574,6 +575,8 @@ class EngineArgs:
|
||||
CacheConfig.kv_offloading_backend
|
||||
)
|
||||
tokens_only: bool = False
|
||||
# AFD config
|
||||
afd_config: AFDConfig | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
# support `EngineArgs(compilation_config={...})`
|
||||
@ -1153,6 +1156,8 @@ class EngineArgs:
|
||||
"--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
|
||||
)
|
||||
vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"])
|
||||
vllm_group.add_argument("--afd-config", **vllm_kwargs["afd_config"])
|
||||
|
||||
vllm_group.add_argument(
|
||||
"--optimization-level", **vllm_kwargs["optimization_level"]
|
||||
)
|
||||
@ -1732,6 +1737,7 @@ class EngineArgs:
|
||||
profiler_config=self.profiler_config,
|
||||
additional_config=self.additional_config,
|
||||
optimization_level=self.optimization_level,
|
||||
afd_config=self.afd_config,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
91
vllm/entrypoints/afd_ffn_server.py
Normal file
91
vllm/entrypoints/afd_ffn_server.py
Normal file
@ -0,0 +1,91 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""vLLM AFD FFN Server Entry Point
|
||||
|
||||
This script provides a standalone entry point for running FFN servers in an AFD
|
||||
(Attention-FFN Disaggregation) setup. FFN servers handle remote FFN computation
|
||||
for attention workers.
|
||||
|
||||
Usage:
|
||||
python -m vllm.entrypoints.afd_ffn_server /path/to/model \
|
||||
--tensor-parallel-size 8 \
|
||||
--afd-config '{"afd_connector": "dummy", "afd_role": "ffn"}' \
|
||||
"""
|
||||
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AFDFFNServer:
    """AFD FFN Server main class.

    Builds the engine config from CLI args, spins up the executor's FFN
    workers, and parks the main thread until interrupted.
    """

    def __init__(self, args: Any):
        engine_args = AsyncEngineArgs.from_cli_args(args)
        self.vllm_config = engine_args.create_engine_config()
        logger.info("Start AFD FFN Server with vllm_config: %s", self.vllm_config)

    def start(self) -> None:
        """Start the AFD FFN server."""
        try:
            # Imported lazily to avoid circular imports at module load.
            from vllm.v1.executor.abstract import Executor

            # Build the executor for this config and enter the serving loop.
            executor_cls = Executor.get_class(self.vllm_config)
            self.model_executor = executor_cls(vllm_config=self.vllm_config)
            self._run_server_loop()
        except Exception as e:
            logger.error("Failed to start AFD FFN server: %s", e)
            raise

    def _run_server_loop(self) -> None:
        """Start FFN workers and wait for completion."""
        logger.info("AFD FFN Server started, workers running...")
        try:
            # One-time RPC telling every worker to enter its FFN loop.
            self.model_executor.collective_rpc("start_ffn_server_loop")

            # Block without busy polling: an Event that is never set
            # simply parks the main thread until interrupted.
            threading.Event().wait()

        except KeyboardInterrupt:
            logger.info("Server shutting down...")
            self.model_executor.collective_rpc("stop_ffn_server_loop")
        except Exception as e:
            logger.error("Server error: %s", e)
            raise
||||
|
||||
|
||||
def main(args: Any) -> None:
    """Main entry point for AFD FFN server."""
    try:
        AFDFFNServer(args).start()
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
    except Exception as e:
        logger.error("Server error: %s", e)
        raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    # Model is a positional argument, mirroring `vllm serve`.
    parser.add_argument("model", type=str, help="Model name or path")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    # (Removed the no-op `args.model = args.model` self-assignment:
    # argparse already populated `args.model` from the positional arg.)
    main(args)
|
||||
51
vllm/entrypoints/cli/fserver.py
Normal file
51
vllm/entrypoints/cli/fserver.py
Normal file
@ -0,0 +1,51 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""vLLM AFD FFN Server CLI command."""
|
||||
|
||||
import argparse
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.entrypoints.afd_ffn_server import main
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
|
||||
|
||||
class FServerCommand(CLISubcommand):
    """Command for running vLLM AFD FFN Server."""

    def __init__(self):
        self.name = "fserver"
        super().__init__()

    def subparser_init(self, subparsers):
        """Initialize the fserver subparser."""
        fserver_parser = subparsers.add_parser(
            self.name,
            help="Start vLLM AFD FFN Server",
            description="Start vLLM AFD FFN Server for Attention-FFN Disaggregation",
            usage="vllm fserver MODEL --afd-config CONFIG [options]",
        )

        # Positional model argument, mirroring `vllm serve`.
        fserver_parser.add_argument("model", type=str, help="Model name or path")

        # Attach the full AsyncEngineArgs CLI surface and return the parser.
        return AsyncEngineArgs.add_cli_args(fserver_parser)

    def validate(self, args: argparse.Namespace) -> None:
        """Validate arguments for fserver command."""
        # An FFN server is meaningless without an AFD config.
        if not getattr(args, "afd_config", None):
            raise ValueError("--afd-config is required for FFN server")

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Run the fserver command."""
        # Delegate directly to the afd_ffn_server entry point.
        main(args)
|
||||
|
||||
|
||||
def cmd_init() -> list[CLISubcommand]:
    """Expose the fserver subcommand to the CLI registry."""
    fserver_command = FServerCommand()
    return [fserver_command]
|
||||
@ -16,6 +16,7 @@ logger = init_logger(__name__)
|
||||
def main():
|
||||
import vllm.entrypoints.cli.benchmark.main
|
||||
import vllm.entrypoints.cli.collect_env
|
||||
import vllm.entrypoints.cli.fserver
|
||||
import vllm.entrypoints.cli.openai
|
||||
import vllm.entrypoints.cli.run_batch
|
||||
import vllm.entrypoints.cli.serve
|
||||
@ -25,6 +26,7 @@ def main():
|
||||
CMD_MODULES = [
|
||||
vllm.entrypoints.cli.openai,
|
||||
vllm.entrypoints.cli.serve,
|
||||
vllm.entrypoints.cli.fserver,
|
||||
vllm.entrypoints.cli.benchmark.main,
|
||||
vllm.entrypoints.cli.collect_env,
|
||||
vllm.entrypoints.cli.run_batch,
|
||||
|
||||
@ -195,7 +195,6 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
assert external_dp_lb or hybrid_dp_lb or dp_rank == 0
|
||||
|
||||
api_server_manager: APIServerProcessManager | None = None
|
||||
|
||||
with launch_core_engines(
|
||||
vllm_config, executor_class, log_stats, num_api_servers
|
||||
) as (local_engine_manager, coordinator, addresses):
|
||||
|
||||
@ -1325,7 +1325,6 @@ async def run_server_worker(
|
||||
listen_address, sock, args, client_config=None, **uvicorn_kwargs
|
||||
) -> None:
|
||||
"""Run a single API server worker."""
|
||||
|
||||
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import torch
|
||||
@ -12,7 +12,9 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
||||
from vllm.distributed.afd_transfer import AFDConnectorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlices
|
||||
|
||||
@ -94,6 +96,46 @@ class DPMetadata:
|
||||
# NOTE: local_sizes should only be set by the chunked_sizes context manager
|
||||
local_sizes: list[int] | None = None
|
||||
|
||||
    @staticmethod
    def num_stage_tokens_across_dp(
        num_stage_tokens: list[int], dp_size: int, dp_rank: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Gather the stage token counts across all DP ranks.
        Args:
            num_stage_tokens: list of token counts per stage for current rank
            dp_size: number of DP ranks
            dp_rank: current DP rank
        Returns:
            stage_tokens_across_dp_cpu: [num_stages, dp_size] tensor
            max_stage_tokens_across_dp_cpu: [num_stages] tensor with max
                tokens per stage
        """
        # Local imports avoid import cycles at module load time.
        import torch.distributed as dist

        from vllm.distributed.parallel_state import get_dp_group
        from vllm.platforms import current_platform

        device = current_platform.device_type
        group = get_dp_group().device_group

        num_stages = len(num_stage_tokens)
        # Each rank fills only its own column; every other column is zero.
        stage_tokens_across_dp = torch.zeros(
            (num_stages, dp_size), device=device, dtype=torch.int32
        )
        stage_tokens_across_dp[:, dp_rank] = torch.tensor(
            num_stage_tokens, device=device, dtype=torch.int32
        )

        # AllReduce to gather from all ranks
        # (sum over one-hot columns == concatenation of per-rank counts).
        dist.all_reduce(stage_tokens_across_dp, group=group)
        stage_tokens_across_dp_cpu = stage_tokens_across_dp.cpu()

        # Compute max tokens per stage
        max_stage_tokens_across_dp_cpu = torch.max(stage_tokens_across_dp_cpu, dim=1)[0]

        return stage_tokens_across_dp_cpu, max_stage_tokens_across_dp_cpu
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
parallel_config: ParallelConfig,
|
||||
@ -182,6 +224,23 @@ class DPMetadata:
|
||||
return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
|
||||
|
||||
|
||||
@dataclass
class AFDMetadata:
    """Per-forward-pass metadata for AFD (attention-FFN disaggregation)."""

    # NOTE(review): the *_start_loc fields appear to be per-stage start
    # offsets (tokens / requests) into the flattened batch -- confirm
    # against the producer that fills them.
    afd_tokens_start_loc: list[int]
    afd_reqs_start_loc: list[int]
    # Index of the AFD stage currently being executed.
    afd_stage_idx: int
    # Connector used to move tensors between attention and FFN instances.
    afd_connector: "AFDConnectorBase"
    afd_tokens_lens: list[int]  # padded lengths for tensor slicing
    # Total number of AFD pipeline stages (microbatches).
    num_of_stages: int

    # Per-stage model inputs captured for staged execution; one entry per
    # stage, appended by the runner.
    input_ids_list: list[torch.Tensor] = field(default_factory=list)
    positions_list: list[torch.Tensor] = field(default_factory=list)
    inputs_embeds_list: list[torch.Tensor] = field(default_factory=list)
    intermediate_tensors_list: list[IntermediateTensors] = field(default_factory=list)
    attn_metadata_list: list[AttentionMetadata] = field(default_factory=list)
    dp_metadata_list: list[DPMetadata] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForwardContext:
|
||||
# copy from vllm_config.compilation_config.static_forward_context
|
||||
@ -198,6 +257,7 @@ class ForwardContext:
|
||||
virtual_engine: int # set dynamically for each forward pass
|
||||
# set dynamically for each forward pass
|
||||
dp_metadata: DPMetadata | None = None
|
||||
afd_metadata: AFDMetadata | None = None
|
||||
# determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
|
||||
# by default NONE, no cudagraph is used.
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
|
||||
@ -235,6 +295,7 @@ def create_forward_context(
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: BatchDescriptor | None = None,
|
||||
ubatch_slices: UBatchSlices | None = None,
|
||||
afd_metadata: AFDMetadata | None = None,
|
||||
):
|
||||
return ForwardContext(
|
||||
no_compile_layers=vllm_config.compilation_config.static_forward_context,
|
||||
@ -244,6 +305,7 @@ def create_forward_context(
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
ubatch_slices=ubatch_slices,
|
||||
afd_metadata=afd_metadata,
|
||||
)
|
||||
|
||||
|
||||
@ -272,6 +334,7 @@ def set_forward_context(
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: BatchDescriptor | None = None,
|
||||
ubatch_slices: UBatchSlices | None = None,
|
||||
afd_metadata: AFDMetadata | None = None,
|
||||
):
|
||||
"""A context manager that stores the current forward context,
|
||||
can be attention metadata, etc.
|
||||
@ -317,6 +380,7 @@ def set_forward_context(
|
||||
cudagraph_runtime_mode,
|
||||
batch_descriptor,
|
||||
ubatch_slices,
|
||||
afd_metadata=afd_metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -594,7 +594,6 @@ class ColumnParallelLinear(LinearBase):
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
output_parallel = self.quant_method.apply(self, input_, bias)
|
||||
|
||||
if self.gather_output and self.tp_size > 1:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
|
||||
@ -45,6 +45,7 @@ from vllm.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
from vllm.distributed.afd_transfer.afd_connector.metadata import AFDConnectorMetadata
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@ -101,6 +102,9 @@ elif current_platform.is_xpu():
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
from vllm.forward_context import AFDMetadata
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id, dbo_enabled, dbo_yield
|
||||
|
||||
|
||||
class DeepseekAttention(nn.Module):
|
||||
"""Normal MHA implementation used by Deepseek v1."""
|
||||
@ -383,7 +387,6 @@ class DeepseekV2MoE(nn.Module):
|
||||
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states
|
||||
)
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
@ -1123,6 +1126,8 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
quant_config = vllm_config.quant_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
afd_config = vllm_config.afd_config
|
||||
self.afd_role = afd_config.afd_role if afd_config is not None else None
|
||||
self.hidden_size = config.hidden_size
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
moe_layer_freq = getattr(config, "moe_layer_freq", 1)
|
||||
@ -1148,42 +1153,47 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
attn_cls = DeepseekV2MLAAttention
|
||||
else:
|
||||
attn_cls = DeepseekV2Attention
|
||||
self.self_attn = attn_cls(
|
||||
vllm_config=vllm_config,
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
)
|
||||
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % moe_layer_freq == 0
|
||||
):
|
||||
self.mlp = DeepseekV2MoE(
|
||||
if self.afd_role is None or self.afd_role == "attention":
|
||||
self.self_attn = attn_cls(
|
||||
vllm_config=vllm_config,
|
||||
config=config,
|
||||
parallel_config=parallel_config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
q_lora_rank=config.q_lora_rank
|
||||
if hasattr(config, "q_lora_rank")
|
||||
else None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
else:
|
||||
self.mlp = DeepseekV2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
prefix=f"{prefix}.self_attn",
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
)
|
||||
|
||||
if self.afd_role is None or self.afd_role == "ffn":
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % moe_layer_freq == 0
|
||||
):
|
||||
self.mlp = DeepseekV2MoE(
|
||||
config=config,
|
||||
parallel_config=parallel_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
else:
|
||||
self.mlp = DeepseekV2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
@ -1227,6 +1237,9 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
if self.afd_role == "attention":
|
||||
return hidden_states, residual
|
||||
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
|
||||
@ -1239,6 +1252,49 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def compute_attn_output(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
) -> torch.Tensor: # Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
if hidden_states.dtype == torch.float16:
|
||||
# Fix FP16 overflow
|
||||
# We scale both hidden_states and residual before
|
||||
# rmsnorm, and rmsnorm result would not affect by scale.
|
||||
hidden_states *= 1.0 / self.routed_scaling_factor
|
||||
if self.layer_idx == 0:
|
||||
# The residual is shared by all layers, we only scale it on
|
||||
# first layer.
|
||||
residual *= 1.0 / self.routed_scaling_factor
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def compute_ffn_output(self, hidden_states):
|
||||
assert self.afd_role == "ffn"
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
|
||||
# Fix FP16 overflow
|
||||
# Scaling the DeepseekV2MLP output, it is the input of
|
||||
# input_layernorm of next decoder layer.
|
||||
# The scaling of DeepseekV2MOE output would be done in the forward
|
||||
# of DeepseekV2MOE
|
||||
hidden_states *= 1.0 / self.routed_scaling_factor
|
||||
return hidden_states
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class DeepseekV2Model(nn.Module):
|
||||
@ -1293,6 +1349,112 @@ class DeepseekV2Model(nn.Module):
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward_with_afd(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
afd_metadata: AFDMetadata,
|
||||
llama_4_scaling: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
afd_connector = afd_metadata.afd_connector
|
||||
afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
|
||||
|
||||
if layer.layer_idx > 0:
|
||||
hidden_states = afd_connector.recv_ffn_output()
|
||||
|
||||
current_hidden, residual = layer(
|
||||
positions, hidden_states, residual, llama_4_scaling
|
||||
)
|
||||
metadata = AFDConnectorMetadata.create_attention_metadata(
|
||||
layer_idx=layer.layer_idx,
|
||||
stage_idx=afd_metadata.afd_stage_idx,
|
||||
seq_len=current_hidden.shape[0],
|
||||
dtype=current_hidden.dtype,
|
||||
device=current_hidden.device,
|
||||
num_of_stages=afd_metadata.num_of_stages,
|
||||
afd_tokens_lens=afd_metadata.afd_tokens_lens,
|
||||
)
|
||||
afd_connector.send_attn_output(current_hidden, metadata)
|
||||
|
||||
if dbo_enabled():
|
||||
dbo_yield()
|
||||
|
||||
hidden_states = afd_connector.recv_ffn_output()
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def forward_with_afd_v2(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
afd_metadata: AFDMetadata,
|
||||
llama_4_scaling: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
forward_conext = get_forward_context()
|
||||
|
||||
ubatch_hidden_states = []
|
||||
ubatch_residual = []
|
||||
|
||||
start_idx = 0
|
||||
for pos in afd_metadata.positions_list:
|
||||
# DeepSeekV2 uses MROPE with shape (3, num_tokens), so use shape[1] if ndim==2
|
||||
# Otherwise use shape[0] as requested
|
||||
num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0]
|
||||
end_idx = start_idx + num_tokens
|
||||
ubatch_hidden_states.append(hidden_states[start_idx:end_idx])
|
||||
ubatch_residual.append(
|
||||
residual[start_idx:end_idx] if residual is not None else None
|
||||
)
|
||||
start_idx = end_idx
|
||||
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
for stage_i in range(forward_conext.afd_metadata.num_of_stages):
|
||||
afd_connector = afd_metadata.afd_connector
|
||||
forward_conext.attn_metadata = afd_metadata.attn_metadata_list[stage_i]
|
||||
forward_conext.dp_metadata = afd_metadata.dp_metadata_list[stage_i]
|
||||
|
||||
residual = ubatch_residual[stage_i]
|
||||
|
||||
if layer.layer_idx > 0:
|
||||
hidden_states = afd_connector.recv_ffn_output()
|
||||
else:
|
||||
hidden_states = ubatch_hidden_states[stage_i]
|
||||
|
||||
current_positions = afd_metadata.positions_list[stage_i]
|
||||
hidden_states, residual = layer(
|
||||
current_positions, hidden_states, residual, llama_4_scaling
|
||||
)
|
||||
|
||||
ubatch_hidden_states[stage_i] = hidden_states
|
||||
ubatch_residual[stage_i] = residual
|
||||
|
||||
metadata = AFDConnectorMetadata.create_attention_metadata(
|
||||
layer_idx=layer.layer_idx,
|
||||
stage_idx=stage_i,
|
||||
seq_len=hidden_states.shape[0],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
num_of_stages=afd_metadata.num_of_stages,
|
||||
afd_tokens_lens=afd_metadata.afd_tokens_lens,
|
||||
)
|
||||
afd_connector.send_attn_output(hidden_states, metadata)
|
||||
|
||||
# Recv last layer FFN output.
|
||||
for stage_i in range(afd_metadata.num_of_stages):
|
||||
ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output()
|
||||
|
||||
# Re-assemble the batch
|
||||
hidden_states = torch.cat(ubatch_hidden_states, dim=0)
|
||||
if any(r is not None for r in ubatch_residual):
|
||||
residual = torch.cat(ubatch_residual, dim=0)
|
||||
else:
|
||||
residual = None
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -1325,10 +1487,18 @@ class DeepseekV2Model(nn.Module):
|
||||
else:
|
||||
llama_4_scaling = None
|
||||
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, residual, llama_4_scaling
|
||||
forward_ctx = get_forward_context()
|
||||
afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None
|
||||
|
||||
if afd_metadata != None:
|
||||
hidden_states, residual = self.forward_with_afd_v2(
|
||||
hidden_states, residual, positions, afd_metadata, llama_4_scaling
|
||||
)
|
||||
else:
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, residual, llama_4_scaling
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
@ -1338,6 +1508,12 @@ class DeepseekV2Model(nn.Module):
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def compute_ffn_output(
|
||||
self, hidden_states, layer_idx
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
hidden_states = self.layers[layer_idx].compute_ffn_output(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DeepseekV2MixtureOfExperts(MixtureOfExperts):
|
||||
moe_mlp_layers: list[DeepseekV2MoE]
|
||||
@ -1404,6 +1580,10 @@ class DeepseekV2ForCausalLM(
|
||||
if self.use_mha:
|
||||
self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
|
||||
|
||||
self.afd_config = vllm_config.afd_config
|
||||
self.afd_role = (
|
||||
self.afd_config.afd_role if self.afd_config is not None else None
|
||||
)
|
||||
# `packed_modules_mapping` needs to be modified before
|
||||
# initializing DeepseekV2Model, as it is passed inplace to
|
||||
# quantization config init and may be used to select the
|
||||
@ -1452,12 +1632,16 @@ class DeepseekV2ForCausalLM(
|
||||
continue
|
||||
|
||||
assert isinstance(layer, DeepseekV2DecoderLayer)
|
||||
if isinstance(layer.mlp, DeepseekV2MoE):
|
||||
if (self.afd_role is None or self.afd_role == "ffn") and isinstance(
|
||||
layer.mlp, DeepseekV2MoE
|
||||
):
|
||||
# Pick last one layer since the first ones may be dense layers.
|
||||
example_moe = layer.mlp
|
||||
self.moe_mlp_layers.append(layer.mlp)
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
if self.afd_role == "attention":
|
||||
return
|
||||
self.extract_moe_parameters(example_moe)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
@ -1475,6 +1659,12 @@ class DeepseekV2ForCausalLM(
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_ffn_output(
|
||||
self, hidden_states, current_layer_idx
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1518,6 +1708,13 @@ class DeepseekV2ForCausalLM(
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
if self.afd_role == "attention":
|
||||
vllm_config = get_current_vllm_config()
|
||||
num_redundant_experts = (
|
||||
vllm_config.parallel_config.eplb_config.num_redundant_experts
|
||||
)
|
||||
else:
|
||||
num_redundant_experts = self.num_redundant_experts
|
||||
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
@ -1528,7 +1725,7 @@ class DeepseekV2ForCausalLM(
|
||||
if rocm_aiter_moe_shared_expert_enabled
|
||||
else 0
|
||||
),
|
||||
num_redundant_experts=self.num_redundant_experts,
|
||||
num_redundant_experts=num_redundant_experts,
|
||||
)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
@ -1537,6 +1734,9 @@ class DeepseekV2ForCausalLM(
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if self.afd_role == "attention" and self.is_moe_weight(name):
|
||||
continue
|
||||
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
@ -1640,7 +1840,8 @@ class DeepseekV2ForCausalLM(
|
||||
# Anyway, this is an expert weight and should not be
|
||||
# attempted to load as other weights later
|
||||
is_expert_weight = True
|
||||
|
||||
if self.afd_role is not None and self.afd_role == "attention":
|
||||
continue
|
||||
# Do not modify `name` since the loop may continue here
|
||||
# Instead, create a new variable
|
||||
name_mapped = chunk_name.replace(weight_name, param_name)
|
||||
@ -1670,6 +1871,12 @@ class DeepseekV2ForCausalLM(
|
||||
loaded_params.add(name_mapped)
|
||||
break
|
||||
else:
|
||||
if (
|
||||
self.afd_role == "ffn"
|
||||
and not self.is_moe_weight(name)
|
||||
and not self.is_common_weight(name)
|
||||
):
|
||||
continue
|
||||
if is_expert_weight:
|
||||
# We've checked that this is an expert weight
|
||||
# However it's not mapped locally to this rank
|
||||
@ -1698,6 +1905,29 @@ class DeepseekV2ForCausalLM(
|
||||
|
||||
return loaded_params
|
||||
|
||||
def is_moe_weight(self, name):
|
||||
if (
|
||||
"shared_experts" in name
|
||||
or "experts" in name
|
||||
or "gate" in name
|
||||
or "up" in name
|
||||
or "down" in name
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_common_weight(self, name):
|
||||
if (
|
||||
"lm_head" in name
|
||||
or "model.norm.weight" in name
|
||||
or "embed_tokens" in name
|
||||
or "input_layernorm" in name
|
||||
or "post_attention_layernorm" in name
|
||||
):
|
||||
# or "model.layers.0.self_attn.o_proj.weight" in name:# for init kv cache
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class DeepseekForCausalLM(DeepseekV2ForCausalLM):
|
||||
pass
|
||||
|
||||
@ -11,12 +11,16 @@ from torch import nn
|
||||
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.config import AFDConfig, CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.distributed.afd_transfer.afd_connector.metadata import (
|
||||
AFDConnectorMetadata,
|
||||
)
|
||||
from vllm.forward_context import AFDMetadata, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
@ -37,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.step3_vl import Step3TextConfig
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id, dbo_enabled, dbo_yield
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (
|
||||
@ -228,54 +233,59 @@ class Step3TextDecoderLayer(nn.Module):
|
||||
config: Step3TextConfig,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
afd_config: AFDConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.afd_role = afd_config.afd_role if afd_config is not None else None
|
||||
|
||||
self.self_attn = Step3TextAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
norm_eps=config.rms_norm_eps,
|
||||
max_position_embedding=config.max_position_embedding,
|
||||
head_dim=config.head_dim,
|
||||
share_q_dim=config.share_q_dim,
|
||||
rope_parameters=config.rope_parameters,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
layer_idx = int(prefix.split("layers.")[1].split(".")[0])
|
||||
moe_layers_enum = getattr(config, "moe_layers_enum", None)
|
||||
if moe_layers_enum is not None:
|
||||
moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
|
||||
else:
|
||||
# Default to 1dense.
|
||||
moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
|
||||
|
||||
if layer_idx in moe_layers_idx:
|
||||
self.moe = FusedMoEBlock(
|
||||
config=config, quant_config=quant_config, prefix=f"{prefix}.moe"
|
||||
)
|
||||
self.share_expert = Step3TextMLP(
|
||||
if self.afd_role is None or self.afd_role == "attention":
|
||||
self.self_attn = Step3TextAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.share_expert_dim,
|
||||
hidden_act="silu",
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.share_expert",
|
||||
norm_eps=config.rms_norm_eps,
|
||||
max_position_embedding=config.max_position_embedding,
|
||||
head_dim=config.head_dim,
|
||||
share_q_dim=config.share_q_dim,
|
||||
rope_parameters=config.rope_parameters,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.use_moe = True
|
||||
else:
|
||||
self.mlp = Step3TextMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act="silu",
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.use_moe = False
|
||||
|
||||
self.layer_idx = int(prefix.split("layers.")[1].split(".")[0])
|
||||
|
||||
if self.afd_role is None or self.afd_role == "ffn":
|
||||
moe_layers_enum = getattr(config, "moe_layers_enum", None)
|
||||
if moe_layers_enum is not None:
|
||||
moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
|
||||
else:
|
||||
# Default to 1dense.
|
||||
moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
|
||||
|
||||
if self.layer_idx in moe_layers_idx:
|
||||
self.moe = FusedMoEBlock(
|
||||
config=config, quant_config=quant_config, prefix=f"{prefix}.moe"
|
||||
)
|
||||
self.share_expert = Step3TextMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.share_expert_dim,
|
||||
hidden_act="silu",
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.share_expert",
|
||||
)
|
||||
self.use_moe = True
|
||||
else:
|
||||
self.mlp = Step3TextMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act="silu",
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.use_moe = False
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
@ -300,6 +310,9 @@ class Step3TextDecoderLayer(nn.Module):
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
|
||||
if self.afd_role == "attention":
|
||||
return hidden_states, residual
|
||||
|
||||
if self.use_moe:
|
||||
share_output = self.share_expert(hidden_states)
|
||||
moe_output = self.moe(hidden_states)
|
||||
@ -309,6 +322,16 @@ class Step3TextDecoderLayer(nn.Module):
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def compute_ffn_output(self, hidden_states):
|
||||
assert self.afd_role == "ffn"
|
||||
if self.use_moe:
|
||||
share_output = self.share_expert(hidden_states)
|
||||
moe_output = self.moe(hidden_states)
|
||||
hidden_states = share_output + moe_output
|
||||
else:
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Step3TextModel(nn.Module):
|
||||
@ -317,6 +340,7 @@ class Step3TextModel(nn.Module):
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
afd_config = vllm_config.afd_config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.config = config
|
||||
|
||||
@ -336,6 +360,7 @@ class Step3TextModel(nn.Module):
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
afd_config=afd_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
@ -352,6 +377,72 @@ class Step3TextModel(nn.Module):
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward_with_afd(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
afd_metadata: AFDMetadata,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
forward_conext = get_forward_context()
|
||||
|
||||
ubatch_hidden_states = []
|
||||
ubatch_residual = []
|
||||
|
||||
start_idx = 0
|
||||
for pos in afd_metadata.positions_list:
|
||||
num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0]
|
||||
end_idx = start_idx + num_tokens
|
||||
ubatch_hidden_states.append(hidden_states[start_idx:end_idx])
|
||||
ubatch_residual.append(
|
||||
residual[start_idx:end_idx] if residual is not None else None
|
||||
)
|
||||
start_idx = end_idx
|
||||
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
for stage_i in range(forward_conext.afd_metadata.num_of_stages):
|
||||
afd_connector = afd_metadata.afd_connector
|
||||
forward_conext.attn_metadata = afd_metadata.attn_metadata_list[stage_i]
|
||||
forward_conext.dp_metadata = afd_metadata.dp_metadata_list[stage_i]
|
||||
|
||||
residual = ubatch_residual[stage_i]
|
||||
|
||||
if layer.layer_idx > 0:
|
||||
hidden_states = afd_connector.recv_ffn_output()
|
||||
else:
|
||||
hidden_states = ubatch_hidden_states[stage_i]
|
||||
|
||||
current_positions = afd_metadata.positions_list[stage_i]
|
||||
hidden_states, residual = layer(
|
||||
current_positions, hidden_states, residual
|
||||
)
|
||||
|
||||
ubatch_hidden_states[stage_i] = hidden_states
|
||||
ubatch_residual[stage_i] = residual
|
||||
metadata = AFDConnectorMetadata.create_attention_metadata(
|
||||
layer_idx=layer.layer_idx,
|
||||
stage_idx=stage_i,
|
||||
seq_len=hidden_states.shape[0],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
num_of_stages=afd_metadata.num_of_stages,
|
||||
afd_tokens_lens=afd_metadata.afd_tokens_lens,
|
||||
)
|
||||
afd_connector.send_attn_output(hidden_states, metadata)
|
||||
|
||||
# Recv last layer FFN output.
|
||||
for stage_i in range(afd_metadata.num_of_stages):
|
||||
ubatch_hidden_states[stage_i] = afd_connector.recv_ffn_output()
|
||||
|
||||
# Re-assemble the batch
|
||||
hidden_states = torch.cat(ubatch_hidden_states, dim=0)
|
||||
if any(r is not None for r in ubatch_residual):
|
||||
residual = torch.cat(ubatch_residual, dim=0)
|
||||
else:
|
||||
residual = None
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -370,8 +461,19 @@ class Step3TextModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
forward_ctx = get_forward_context()
|
||||
afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None
|
||||
|
||||
if afd_metadata is not None:
|
||||
hidden_states, residual = self.forward_with_afd(
|
||||
hidden_states,
|
||||
residual,
|
||||
positions,
|
||||
afd_metadata,
|
||||
)
|
||||
else:
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
@ -384,6 +486,14 @@ class Step3TextModel(nn.Module):
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def compute_ffn_output(
|
||||
self,
|
||||
hidden_states,
|
||||
layer_idx,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
hidden_states = self.layers[layer_idx].compute_ffn_output(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Step3TextForCausalLM(nn.Module, SupportsPP):
|
||||
def __init__(
|
||||
@ -398,6 +508,11 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
|
||||
self.config = config
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
self.afd_config = vllm_config.afd_config
|
||||
self.afd_role = (
|
||||
self.afd_config.afd_role if self.afd_config is not None else None
|
||||
)
|
||||
|
||||
self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
@ -429,6 +544,14 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_ffn_output(
|
||||
self,
|
||||
hidden_states,
|
||||
current_layer_idx,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
@ -477,9 +600,13 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
|
||||
disable_moe_stacked_params = [data[1] for data in expert_params_mapping]
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if self.afd_role == "attention" and self.is_moe_weight(name):
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
if any(
|
||||
disable_moe_stacked_param in name
|
||||
for disable_moe_stacked_param in disable_moe_stacked_params
|
||||
@ -498,6 +625,10 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
|
||||
param_name, weight_name, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
if self.afd_role is not None and self.afd_role == "attention":
|
||||
continue
|
||||
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
@ -521,6 +652,12 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
|
||||
loaded_params.add(name)
|
||||
break
|
||||
else:
|
||||
if (
|
||||
self.afd_role == "ffn"
|
||||
and not self.is_moe_weight(name)
|
||||
and not self.is_common_weight(name)
|
||||
):
|
||||
continue
|
||||
for (
|
||||
param_name,
|
||||
weight_name,
|
||||
@ -552,3 +689,25 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def is_moe_weight(self, name):
|
||||
if (
|
||||
"shared_expert" in name
|
||||
or "experts" in name
|
||||
or "gate" in name
|
||||
or "up" in name
|
||||
or "down" in name
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_common_weight(self, name):
|
||||
if (
|
||||
"lm_head" in name
|
||||
or "model.norm.weight" in name
|
||||
or "embed" in name
|
||||
or "input_layernorm" in name
|
||||
or "post_attention_layernorm" in name
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -1126,6 +1126,16 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def compute_ffn_output(
|
||||
self,
|
||||
hidden_states,
|
||||
current_layer_idx,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
hidden_states = self.language_model.compute_ffn_output(
|
||||
hidden_states, current_layer_idx
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@ -105,6 +105,10 @@ class EngineCore:
|
||||
if executor_fail_callback is not None:
|
||||
self.model_executor.register_failure_callback(executor_fail_callback)
|
||||
|
||||
self.afd_config = vllm_config.afd_config
|
||||
if self.afd_config and self.afd_config.afd_role == "ffn":
|
||||
return
|
||||
|
||||
self.available_gpu_memory_for_kv_cache = -1
|
||||
|
||||
# Setup KV Caches and update CacheConfig after profiling.
|
||||
@ -614,6 +618,7 @@ class EngineCoreProc(EngineCore):
|
||||
executor_fail_callback = lambda: self.input_queue.put_nowait(
|
||||
(EngineCoreRequestType.EXECUTOR_FAILED, b"")
|
||||
)
|
||||
self.afd_config = vllm_config.afd_config
|
||||
|
||||
self.engine_index = engine_index
|
||||
identity = self.engine_index.to_bytes(length=2, byteorder="little")
|
||||
@ -868,7 +873,6 @@ class EngineCoreProc(EngineCore):
|
||||
set_process_title("EngineCore")
|
||||
decorate_logs()
|
||||
engine_core = EngineCoreProc(*args, **kwargs)
|
||||
|
||||
engine_core.run_busy_loop()
|
||||
|
||||
except SystemExit:
|
||||
@ -891,6 +895,23 @@ class EngineCoreProc(EngineCore):
|
||||
def run_busy_loop(self):
|
||||
"""Core busy loop of the EngineCore."""
|
||||
|
||||
if self.afd_config and self.afd_config.afd_role == "ffn":
|
||||
logger.info("AFD FFN Server started, workers running...")
|
||||
try:
|
||||
# Tell workers to start FFN server loops (one-time call)
|
||||
self.model_executor.collective_rpc("start_ffn_server_loop")
|
||||
|
||||
# Main thread waits without busy polling
|
||||
shutdown_event = threading.Event()
|
||||
shutdown_event.wait() # Block until interrupted
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server shutting down...")
|
||||
self.model_executor.collective_rpc("stop_ffn_server_loop")
|
||||
except Exception as e:
|
||||
logger.error("Server error: %s", e)
|
||||
raise
|
||||
|
||||
# Loop until process is sent a SIGINT or SIGTERM
|
||||
while True:
|
||||
# 1) Poll the input queue until there is work to do.
|
||||
@ -1205,6 +1226,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
|
||||
# Initialize the engine.
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
self.afd_config = vllm_config.afd_config
|
||||
super().__init__(
|
||||
vllm_config,
|
||||
local_client,
|
||||
@ -1287,6 +1309,22 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
|
||||
def run_busy_loop(self):
|
||||
"""Core busy loop of the EngineCore for data parallel case."""
|
||||
if self.afd_config and self.afd_config.afd_role == "ffn":
|
||||
logger.info("AFD FFN Server started, workers running...")
|
||||
try:
|
||||
# Tell workers to start FFN server loops (one-time call)
|
||||
self.model_executor.collective_rpc("start_ffn_server_loop")
|
||||
|
||||
# Main thread waits without busy polling
|
||||
shutdown_event = threading.Event()
|
||||
shutdown_event.wait() # Block until interrupted
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server shutting down...")
|
||||
self.model_executor.collective_rpc("stop_ffn_server_loop")
|
||||
except Exception as e:
|
||||
logger.error("Server error: %s", e)
|
||||
raise
|
||||
|
||||
# Loop until process is sent a SIGINT or SIGTERM
|
||||
while True:
|
||||
|
||||
@ -16,7 +16,7 @@ import msgspec
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||
from vllm.config import AFDConfig, CacheConfig, ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.ray.ray_env import get_env_vars_to_copy
|
||||
@ -908,6 +908,7 @@ def launch_core_engines(
|
||||
vllm_config.cache_config,
|
||||
local_engine_manager,
|
||||
coordinator.proc if coordinator else None,
|
||||
vllm_config.afd_config,
|
||||
)
|
||||
|
||||
|
||||
@ -919,6 +920,7 @@ def wait_for_engine_startup(
|
||||
cache_config: CacheConfig,
|
||||
proc_manager: CoreEngineProcManager | None,
|
||||
coord_process: Process | None,
|
||||
afd_config: AFDConfig | None = None,
|
||||
):
|
||||
# Wait for engine core process(es) to send ready messages.
|
||||
local_count = parallel_config.data_parallel_size_local
|
||||
@ -1020,6 +1022,13 @@ def wait_for_engine_startup(
|
||||
conn_pending[0 if local else 1] -= 1
|
||||
start_pending[0 if local else 1] += 1
|
||||
engine.state = CoreEngineState.CONNECTED
|
||||
elif (
|
||||
status == "READY"
|
||||
and engine.state == CoreEngineState.CONNECTED
|
||||
and afd_config
|
||||
and afd_config.afd_role == "ffn"
|
||||
):
|
||||
engine.state = CoreEngineState.READY
|
||||
elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
|
||||
# Setup KV cache config with initialization state from
|
||||
# engine core process. Sum values from all engines in DP case.
|
||||
|
||||
@ -213,6 +213,9 @@ class MultiprocExecutor(Executor):
|
||||
|
||||
self.output_rank = self._get_output_rank()
|
||||
|
||||
self.afd_config = self.vllm_config.afd_config
|
||||
self.afd_role = self.afd_config.afd_role if self.afd_config else None
|
||||
|
||||
def start_worker_monitor(self, inline=False) -> None:
|
||||
workers = self.workers
|
||||
self_ref = weakref.ref(self)
|
||||
@ -565,6 +568,9 @@ class WorkerProc:
|
||||
# environment variable overrides after this point)
|
||||
enable_envs_cache()
|
||||
|
||||
self.afd_config = vllm_config.afd_config
|
||||
self.afd_role = self.afd_config.afd_role if self.afd_config else None
|
||||
|
||||
@staticmethod
|
||||
def make_worker_process(
|
||||
vllm_config: VllmConfig,
|
||||
|
||||
450
vllm/v1/worker/gpu_ffn_model_runner.py
Normal file
450
vllm/v1/worker/gpu_ffn_model_runner.py
Normal file
@ -0,0 +1,450 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.afd_transfer.afd_connector.factory import AFDConnectorFactory
|
||||
from vllm.distributed.afd_transfer.afd_connector.metadata import AFDConnectorMetadata
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_world_group,
|
||||
graph_capture,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler, GiB_bytes
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPUFFNModelRunner(LoRAModelRunnerMixin):
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.device = device
|
||||
self.dtype = self.model_config.dtype
|
||||
self.load_config = vllm_config.load_config
|
||||
|
||||
self.afd_config = vllm_config.afd_config
|
||||
if not self.afd_config or not self.afd_config.is_ffn_server:
|
||||
raise ValueError(
|
||||
"AFD config must be provided with afd_role='ffn' for FFN server"
|
||||
)
|
||||
|
||||
self._counter = 0
|
||||
|
||||
# Initialize torch.profile for performance monitoring
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
schedule=torch.profiler.schedule(
|
||||
wait=6000 * 27 + 4000 * 27 * 2, warmup=1, active=30, repeat=1
|
||||
),
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
"./profiler_logs/ffn"
|
||||
),
|
||||
record_shapes=True,
|
||||
profile_memory=False,
|
||||
with_stack=False,
|
||||
)
|
||||
|
||||
# Initialize CUDA graph support
|
||||
self.use_cuda_graph = not self.model_config.enforce_eager
|
||||
|
||||
# self.cudagraph_batch_sizes sorts in ascending order.
|
||||
# The batch sizes in the config are in descending order.
|
||||
self.cudagraph_batch_sizes = list(
|
||||
reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes)
|
||||
)
|
||||
|
||||
# Storage for captured graphs
|
||||
self._cuda_graphs: dict[
|
||||
tuple[int, int], torch.cuda.CUDAGraph
|
||||
] = {} # {(layer_idx, num_tokens): CUDAGraph}
|
||||
self._graph_memory_pool = None
|
||||
|
||||
assert self.afd_config.is_ffn_server
|
||||
self.connector = AFDConnectorFactory.create_connector(
|
||||
get_world_group().rank, get_world_group().local_rank, self.vllm_config
|
||||
)
|
||||
|
||||
if getattr(self.model_config.hf_config, "text_config", None) is not None:
|
||||
self.num_layers = self.model_config.hf_config.text_config.num_hidden_layers
|
||||
else:
|
||||
self.num_layers = self.model_config.hf_config.num_hidden_layers
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def initialize_afd_connector(self) -> None:
|
||||
self.connector.init_afd_connector()
|
||||
|
||||
def load_model(self, **kwargs) -> None:
|
||||
logger.info("Starting to load model %s...", self.model_config.model)
|
||||
with DeviceMemoryProfiler() as m: # noqa: SIM117
|
||||
time_before_load = time.perf_counter()
|
||||
model_loader = get_model_loader(self.load_config)
|
||||
|
||||
if not hasattr(self, "model"):
|
||||
logger.info("Loading model from scratch...")
|
||||
self.model = model_loader.load_model(
|
||||
vllm_config=self.vllm_config, model_config=self.model_config
|
||||
)
|
||||
else:
|
||||
logger.info("Model was already initialized. Loading weights inplace...")
|
||||
model_loader.load_weights(self.model, model_config=self.model_config)
|
||||
time_after_load = time.perf_counter()
|
||||
self.model_memory_usage = m.consumed_memory
|
||||
logger.info(
|
||||
"Model loading took %.4f GiB and %.6f seconds",
|
||||
self.model_memory_usage / GiB_bytes,
|
||||
time_after_load - time_before_load,
|
||||
)
|
||||
|
||||
logger.info("AFD FFN Model loaded successfully")
|
||||
|
||||
def _get_current_layer_idx(self) -> int:
|
||||
return (self._counter // self.afd_config.num_afd_stages) % self.num_layers
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(self, scheduler_output=None, intermediate_tensors=None):
|
||||
"""Execute FFN computation for a single request"""
|
||||
# scheduler_output and intermediate_tensors are unused in FFN server
|
||||
# mode
|
||||
self.profiler.step()
|
||||
|
||||
try:
|
||||
hidden_states, recv_metadata = self.connector.recv_attn_output()
|
||||
if hasattr(self.connector, "dp_metadata_list"):
|
||||
dp_metadata = self.connector.dp_metadata_list.get(
|
||||
recv_metadata.stage_idx, None
|
||||
)
|
||||
else:
|
||||
dp_metadata = None
|
||||
current_layer_idx = recv_metadata.layer_idx
|
||||
# logger.info(
|
||||
# f"layer {current_layer_idx} moe recv hidden states type:{type(hidden_states)}, shape:{hidden_states.shape}"
|
||||
# f" dp_metadata: {dp_metadata}"
|
||||
# )
|
||||
num_tokens = hidden_states.shape[0]
|
||||
if recv_metadata is not None and recv_metadata.recv_handle_list is not None:
|
||||
for work in recv_metadata.recv_handle_list:
|
||||
work.wait()
|
||||
# Try to use CUDA graph if available
|
||||
cuda_graph_info = self._find_cuda_graph(current_layer_idx, num_tokens)
|
||||
if cuda_graph_info is not None:
|
||||
# Use captured CUDA graph for computation
|
||||
with set_forward_context(
|
||||
attn_metadata=None, vllm_config=self.vllm_config
|
||||
):
|
||||
get_forward_context().dp_metadata = dp_metadata
|
||||
rank_ffn_output = self._execute_with_cuda_graph(
|
||||
hidden_states, cuda_graph_info
|
||||
)
|
||||
else:
|
||||
# Fallback to eager mode
|
||||
with set_forward_context(
|
||||
attn_metadata=None, vllm_config=self.vllm_config
|
||||
):
|
||||
get_forward_context().dp_metadata = dp_metadata
|
||||
rank_ffn_output = self._execute_eager_mode(
|
||||
hidden_states, current_layer_idx
|
||||
)
|
||||
|
||||
recv_metadata.recv_handle_list = None
|
||||
self.connector.send_ffn_output(rank_ffn_output, recv_metadata)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error computing FFN: {e}") from e
|
||||
finally:
|
||||
self._counter += 1
|
||||
if self._counter == self.num_layers * self.afd_config.num_afd_stages:
|
||||
self._counter = 0
|
||||
return None # FFN server doesn't return ModelRunnerOutput
|
||||
|
||||
def _execute_with_cuda_graph(
|
||||
self, hidden_states: torch.Tensor, cuda_graph_info: dict
|
||||
):
|
||||
"""Execute FFN computation using captured CUDA graph."""
|
||||
graph = cuda_graph_info["graph"]
|
||||
input_tensor = cuda_graph_info["input_hidden_states"]
|
||||
output_tensor = cuda_graph_info["output"]
|
||||
|
||||
# Copy input data to graph's input tensor
|
||||
# Handle padding if necessary
|
||||
actual_tokens = hidden_states.shape[0]
|
||||
graph_tokens = input_tensor.shape[0]
|
||||
|
||||
if actual_tokens <= graph_tokens:
|
||||
# Copy actual data and pad with zeros if needed
|
||||
input_tensor[:actual_tokens].copy_(hidden_states)
|
||||
if actual_tokens < graph_tokens:
|
||||
input_tensor[actual_tokens:].zero_()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Input size {actual_tokens} exceeds graph capacity {graph_tokens}"
|
||||
)
|
||||
|
||||
# Replay the captured graph
|
||||
graph.replay()
|
||||
|
||||
# Return only the actual output (without padding)
|
||||
return output_tensor[:actual_tokens].clone()
|
||||
|
||||
def _execute_eager_mode(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
current_layer_idx: int,
|
||||
recv_metadata: AFDConnectorMetadata = None,
|
||||
):
|
||||
"""Execute FFN computation in eager mode (fallback)."""
|
||||
# Step the profiler for performance monitoring
|
||||
|
||||
# Handle TP case: all-gather tensors from all TP ranks
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
if tp_world_size > 1:
|
||||
# All-gather hidden states from all TP ranks
|
||||
gathered_hidden_states = tensor_model_parallel_all_gather(
|
||||
hidden_states, dim=0
|
||||
)
|
||||
ffn_output = self.model.compute_ffn_output(
|
||||
gathered_hidden_states, current_layer_idx
|
||||
)
|
||||
# Extract the output corresponding to current rank
|
||||
start_idx = hidden_states.shape[0] * get_tensor_model_parallel_rank()
|
||||
end_idx = start_idx + hidden_states.shape[0]
|
||||
rank_ffn_output = ffn_output[start_idx:end_idx, :]
|
||||
else:
|
||||
# Single TP case
|
||||
rank_ffn_output = self.model.compute_ffn_output(
|
||||
hidden_states, current_layer_idx
|
||||
)
|
||||
|
||||
return rank_ffn_output
|
||||
|
||||
# Methods required for interface compatibility with GPUModelRunner
|
||||
def profile_run(self) -> None:
|
||||
"""FFN servers don't need profiling."""
|
||||
pass
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, "KVCacheSpec"]:
|
||||
"""FFN servers don't use KV cache."""
|
||||
return {}
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: "KVCacheConfig") -> None:
|
||||
"""FFN servers don't use KV cache."""
|
||||
pass
|
||||
|
||||
def _dummy_run(self, num_tokens: int = 1, **kwargs) -> torch.Tensor:
|
||||
"""FFN servers don't need dummy runs."""
|
||||
# Return a dummy tensor for interface compatibility
|
||||
return torch.zeros(
|
||||
num_tokens,
|
||||
self.model_config.hf_config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def capture_model(self) -> int:
|
||||
"""Capture CUDA graphs for FFN operations."""
|
||||
if not self.use_cuda_graph:
|
||||
logger.warning("Skipping CUDA graph capture.")
|
||||
return 0
|
||||
|
||||
logger.info("Starting CUDA graph capture for FFN operations...")
|
||||
start_time = time.perf_counter()
|
||||
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
# Create memory pool for graphs
|
||||
if self._graph_memory_pool is None:
|
||||
self._graph_memory_pool = torch.cuda.graph_pool_handle()
|
||||
|
||||
# Capture graphs for each layer and different batch sizes
|
||||
# Capture the large shapes first so that the smaller shapes
|
||||
# can reuse the memory pool allocated for the large shapes.
|
||||
with graph_capture(device=self.device):
|
||||
for layer_idx in range(self.num_layers):
|
||||
for num_tokens in reversed(self.cudagraph_batch_sizes):
|
||||
with set_forward_context(
|
||||
attn_metadata=None, vllm_config=self.vllm_config
|
||||
):
|
||||
self._capture_graph_for_layer_and_size(layer_idx, num_tokens)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
elapsed_time = end_time - start_time
|
||||
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
|
||||
|
||||
logger.info(
|
||||
"FFN CUDA graph capturing finished in %.0f secs, took %.2f GiB",
|
||||
elapsed_time,
|
||||
cuda_graph_size / (1 << 30),
|
||||
)
|
||||
return cuda_graph_size
|
||||
|
||||
def _capture_graph_for_layer_and_size(self, layer_idx: int, num_tokens: int):
|
||||
"""Capture CUDA graph for specific layer and number of tokens."""
|
||||
# Create dummy hidden states
|
||||
dummy_hidden_states = torch.randn(
|
||||
num_tokens,
|
||||
self.model_config.hf_config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Warm up the operations for this specific layer
|
||||
for _ in range(self.vllm_config.compilation_config.cudagraph_num_of_warmups):
|
||||
self._run_ffn_computation(
|
||||
dummy_hidden_states, layer_idx=layer_idx, capture_mode=True
|
||||
)
|
||||
|
||||
# Create and capture the graph
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
|
||||
# Start graph capture
|
||||
with torch.cuda.graph(graph, pool=self._graph_memory_pool):
|
||||
output = self._run_ffn_computation(
|
||||
dummy_hidden_states, layer_idx=layer_idx, capture_mode=True
|
||||
)
|
||||
|
||||
# Store the captured graph with layer and token count as key
|
||||
self._cuda_graphs[(layer_idx, num_tokens)] = {
|
||||
"graph": graph,
|
||||
"input_hidden_states": dummy_hidden_states,
|
||||
"output": output,
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
"Captured CUDA graph for layer %s with %s tokens", layer_idx, num_tokens
|
||||
)
|
||||
|
||||
def _run_ffn_computation(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
layer_idx: int | None = None,
|
||||
capture_mode: bool = False,
|
||||
):
|
||||
"""Run FFN computation for graph capture or replay."""
|
||||
if layer_idx is None:
|
||||
current_layer_idx = self._get_current_layer_idx() if not capture_mode else 0
|
||||
else:
|
||||
current_layer_idx = layer_idx
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
if tp_world_size > 1:
|
||||
# Handle TP case: all-gather tensors from all TP ranks
|
||||
gathered_hidden_states = tensor_model_parallel_all_gather(
|
||||
hidden_states, dim=0
|
||||
)
|
||||
ffn_output = self.model.compute_ffn_output(
|
||||
gathered_hidden_states, current_layer_idx
|
||||
)
|
||||
|
||||
# Extract the output corresponding to current rank
|
||||
start_idx = hidden_states.shape[0] * get_tensor_model_parallel_rank()
|
||||
end_idx = start_idx + hidden_states.shape[0]
|
||||
rank_ffn_output = ffn_output[start_idx:end_idx, :]
|
||||
else:
|
||||
# Single TP case
|
||||
rank_ffn_output = self.model.compute_ffn_output(
|
||||
hidden_states, current_layer_idx
|
||||
)
|
||||
|
||||
return rank_ffn_output
|
||||
|
||||
def _find_cuda_graph(self, layer_idx: int, num_tokens: int):
|
||||
"""Find the smallest graph that can handle the given layer and
|
||||
number of tokens."""
|
||||
if not self.use_cuda_graph:
|
||||
return None
|
||||
|
||||
# Find the minimum capture size that can handle num_tokens for this
|
||||
# layer
|
||||
for capture_size in self.cudagraph_batch_sizes:
|
||||
if num_tokens <= capture_size:
|
||||
return self._cuda_graphs.get((layer_idx, capture_size))
|
||||
return None
|
||||
|
||||
def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
|
||||
"""FFN servers don't use samplers."""
|
||||
pass
|
||||
|
||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
||||
"""Update configuration for FFN model runner."""
|
||||
allowed_config_names = {"load_config", "model_config"}
|
||||
for config_name, config_overrides in overrides.items():
|
||||
assert config_name in allowed_config_names, (
|
||||
f"Config `{config_name}` not supported. "
|
||||
f"Allowed configs: {allowed_config_names}"
|
||||
)
|
||||
config = getattr(self, config_name)
|
||||
from vllm.config import update_config
|
||||
|
||||
new_config = update_config(config, config_overrides)
|
||||
setattr(self, config_name, new_config)
|
||||
|
||||
def reload_weights(self) -> None:
|
||||
"""Reload model weights for FFN model runner."""
|
||||
assert getattr(self, "model", None) is not None, (
|
||||
"Cannot reload weights before model is loaded."
|
||||
)
|
||||
model_loader = get_model_loader(self.load_config)
|
||||
logger.info("Reloading weights inplace...")
|
||||
model = self.get_model()
|
||||
model_loader.load_weights(model, model_config=self.model_config)
|
||||
|
||||
@property
|
||||
def lora_config(self):
|
||||
"""FFN servers don't support LoRA."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_pooling_model(self) -> bool:
|
||||
"""FFN servers are not pooling models."""
|
||||
return False
|
||||
|
||||
def _dummy_pooler_run(self, hidden_states: torch.Tensor):
|
||||
"""FFN servers don't have poolers."""
|
||||
pass
|
||||
|
||||
def get_supported_tasks(self):
|
||||
"""Get supported tasks for FFN model runner."""
|
||||
return []
|
||||
|
||||
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
|
||||
"""Get number of input tokens for FFN model runner."""
|
||||
return num_scheduled_tokens
|
||||
|
||||
def take_draft_token_ids(self, **kwargs):
|
||||
"""FFN servers don't support draft tokens."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def eplb_state(self):
|
||||
"""FFN servers don't have EPLB state."""
|
||||
return None
|
||||
|
||||
def ensure_kv_transfer_shutdown(self):
|
||||
"""FFN servers don't need KV transfer shutdown."""
|
||||
pass
|
||||
|
||||
def save_tensorized_model(
|
||||
self,
|
||||
tensorizer_config: "TensorizerConfig",
|
||||
) -> None:
|
||||
"""FFN servers don't support tensorized model saving."""
|
||||
raise NotImplementedError("FFN servers don't support tensorized model saving")
|
||||
@ -37,6 +37,7 @@ from vllm.config import (
|
||||
get_layers_from_vllm_config,
|
||||
update_config,
|
||||
)
|
||||
from vllm.distributed.afd_transfer import AFDConnectorFactory
|
||||
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
|
||||
from vllm.distributed.eplb.eplb_state import EplbState
|
||||
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
|
||||
@ -45,11 +46,13 @@ from vllm.distributed.parallel_state import (
|
||||
get_dcp_group,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
get_world_group,
|
||||
graph_capture,
|
||||
is_global_first_rank,
|
||||
prepare_communication_buffer_for_model,
|
||||
)
|
||||
from vllm.forward_context import (
|
||||
AFDMetadata,
|
||||
BatchDescriptor,
|
||||
set_forward_context,
|
||||
)
|
||||
@ -547,6 +550,16 @@ class GPUModelRunner(
|
||||
# means this layer will perform attention using the keys and values
|
||||
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
|
||||
self.shared_kv_cache_layers: dict[str, str] = {}
|
||||
|
||||
# init AFD config
|
||||
self.afd_config = vllm_config.afd_config
|
||||
if self.afd_config and self.afd_config.afd_role == "attention":
|
||||
self.afd_connector = AFDConnectorFactory.create_connector(
|
||||
get_world_group().rank, get_world_group().local_rank, vllm_config
|
||||
)
|
||||
self.afd_connector.init_afd_connector()
|
||||
self.num_stages = self.afd_config.num_afd_stages
|
||||
|
||||
self.kv_sharing_fast_prefill_eligible_layers: set[str] = set()
|
||||
|
||||
self.kv_sharing_fast_prefill_logits_indices = None
|
||||
@ -606,6 +619,25 @@ class GPUModelRunner(
|
||||
self.kv_connector_output: KVConnectorOutput | None = None
|
||||
self.layerwise_nvtx_hooks_registered = False
|
||||
|
||||
profile_dir = (
|
||||
"./profiler_logs/attn"
|
||||
if self.afd_config is not None and self.afd_config.afd_role == "attention"
|
||||
else "./profiler_logs/normal"
|
||||
)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
schedule=torch.profiler.schedule(
|
||||
wait=6000 + 4000, warmup=1, active=30, repeat=1
|
||||
),
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir),
|
||||
record_shapes=True,
|
||||
profile_memory=False,
|
||||
with_stack=False,
|
||||
)
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
if self.mm_budget:
|
||||
self.mm_budget.reset_cache()
|
||||
@ -1303,7 +1335,7 @@ class GPUModelRunner(
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
|
||||
logger.info(f"nums_reqs: {num_reqs}")
|
||||
# OPTIMIZATION: Start copying the block table first.
|
||||
# This way, we can overlap the copy with the following CPU operations.
|
||||
self.input_batch.block_table.commit_block_table(num_reqs)
|
||||
@ -2877,6 +2909,7 @@ class GPUModelRunner(
|
||||
# decoder.
|
||||
allow_dp_padding = (
|
||||
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
or self.afd_config is not None
|
||||
)
|
||||
|
||||
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
|
||||
@ -2958,6 +2991,38 @@ class GPUModelRunner(
|
||||
pyt_hooks.register_hooks(self.model, self.model.__class__.__name__)
|
||||
self.layerwise_nvtx_hooks_registered = True
|
||||
|
||||
def _build_afd_metadata(
|
||||
self, ubatch_slices: UBatchSlices | None, num_tokens_unpadded: int
|
||||
):
|
||||
afd_metadata = None
|
||||
if self.afd_config:
|
||||
# For prefill, compute tokens per stage based on actual token
|
||||
# counts
|
||||
afd_tokens_start_loc = [0]
|
||||
afd_tokens_lens = []
|
||||
if ubatch_slices and len(ubatch_slices) > 1:
|
||||
afd_tokens_start_loc = [ub.token_slice.start for ub in ubatch_slices]
|
||||
afd_reqs_start_loc = [ub.request_slice.start for ub in ubatch_slices]
|
||||
logger.info(
|
||||
f"afd_tokens_start_loc: {afd_tokens_start_loc} "
|
||||
f"afd_reqs_start_loc: {afd_reqs_start_loc} "
|
||||
f"ubatch_slices: {ubatch_slices}"
|
||||
)
|
||||
afd_tokens_lens = [ub.num_tokens for ub in ubatch_slices]
|
||||
else:
|
||||
afd_tokens_start_loc = [0]
|
||||
afd_reqs_start_loc = [0]
|
||||
afd_tokens_lens = [num_tokens_unpadded]
|
||||
afd_metadata = AFDMetadata(
|
||||
afd_tokens_start_loc=afd_tokens_start_loc,
|
||||
afd_reqs_start_loc=afd_reqs_start_loc,
|
||||
afd_stage_idx=0,
|
||||
afd_connector=self.afd_connector,
|
||||
afd_tokens_lens=afd_tokens_lens,
|
||||
num_of_stages=len(ubatch_slices) if ubatch_slices else 1,
|
||||
)
|
||||
return afd_metadata
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
@ -3130,6 +3195,11 @@ class GPUModelRunner(
|
||||
# Mark KV scales as calculated after the first forward pass
|
||||
self.calculate_kv_scales = False
|
||||
|
||||
afd_metadata = self._build_afd_metadata(
|
||||
ubatch_slices_padded, num_tokens_unpadded
|
||||
)
|
||||
|
||||
self.profiler.step()
|
||||
# Run the model.
|
||||
# Use persistent buffers for CUDA graphs.
|
||||
with (
|
||||
@ -3141,10 +3211,15 @@ class GPUModelRunner(
|
||||
cudagraph_runtime_mode=cudagraph_mode,
|
||||
batch_descriptor=batch_desc,
|
||||
ubatch_slices=ubatch_slices_padded,
|
||||
afd_metadata=afd_metadata,
|
||||
),
|
||||
record_function_or_nullcontext("gpu_model_runner: forward"),
|
||||
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
|
||||
):
|
||||
if input_ids is not None:
|
||||
logger.info(f"input_ids: {input_ids.shape}")
|
||||
if inputs_embeds is not None:
|
||||
logger.info(f"inputs_embeds: {inputs_embeds.shape}")
|
||||
model_output = self._model_forward(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
@ -4269,6 +4344,10 @@ class GPUModelRunner(
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[:] = num_tokens_padded
|
||||
|
||||
afd_metadata = self._build_afd_metadata(
|
||||
ubatch_slices_padded, num_tokens_unpadded
|
||||
)
|
||||
|
||||
with (
|
||||
self.maybe_randomize_inputs(input_ids, inputs_embeds),
|
||||
set_forward_context(
|
||||
@ -4279,6 +4358,7 @@ class GPUModelRunner(
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=batch_desc,
|
||||
ubatch_slices=ubatch_slices_padded,
|
||||
afd_metadata=afd_metadata,
|
||||
),
|
||||
):
|
||||
outputs = self.model(
|
||||
@ -4814,7 +4894,6 @@ class GPUModelRunner(
|
||||
kv_cache_spec,
|
||||
kv_cache_group_id,
|
||||
)
|
||||
|
||||
attn_groups.append(attn_group)
|
||||
return attn_groups
|
||||
|
||||
@ -5491,6 +5570,11 @@ class GPUModelRunner(
|
||||
kv_transfer_group.register_kv_caches(kv_caches)
|
||||
kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
|
||||
|
||||
def initialize_afd_connector(self) -> None:
|
||||
"""Initialize AFD connector if available."""
|
||||
if hasattr(self, "afd_connector") and self.afd_connector:
|
||||
self.afd_connector.init_afd_connector()
|
||||
|
||||
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
|
||||
"""
|
||||
Add encoder-only layers to the KV cache config.
|
||||
|
||||
@ -14,6 +14,7 @@ from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
|
||||
from vllm.forward_context import (
|
||||
AFDMetadata,
|
||||
DPMetadata,
|
||||
create_forward_context,
|
||||
get_forward_context,
|
||||
@ -126,7 +127,10 @@ class UBatchWrapper:
|
||||
comm_sms: int = envs.VLLM_DBO_COMM_SMS
|
||||
|
||||
set_comm_sms = lambda sms: None
|
||||
if vllm_config.parallel_config.enable_expert_parallel:
|
||||
if (
|
||||
vllm_config.parallel_config.enable_expert_parallel
|
||||
and not vllm_config.afd_config
|
||||
):
|
||||
# Currently only DeepEP highthroughput supports SM control so this
|
||||
# only affects that case.
|
||||
ep_group = get_ep_group()
|
||||
@ -303,6 +307,7 @@ class UBatchWrapper:
|
||||
dp_metadata,
|
||||
batch_descriptor,
|
||||
cudagraph_runtime_mode,
|
||||
afd_metadata,
|
||||
) -> list[UbatchMetadata]:
|
||||
# Create one forward context per ubatch
|
||||
forward_contexts = []
|
||||
@ -314,6 +319,7 @@ class UBatchWrapper:
|
||||
dp_metadata=dp_metadata[i],
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
afd_metadata=afd_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
@ -353,6 +359,62 @@ class UBatchWrapper:
|
||||
|
||||
return ubatch_metadata
|
||||
|
||||
def _make_afd_ubatch_metadata(
|
||||
self,
|
||||
ubatch_slices,
|
||||
attn_metadata,
|
||||
input_ids,
|
||||
positions,
|
||||
inputs_embeds,
|
||||
intermediate_tensors,
|
||||
dp_metadata,
|
||||
afd_metadata,
|
||||
) -> AFDMetadata:
|
||||
if ubatch_slices is None:
|
||||
afd_metadata.input_ids_list.append(input_ids)
|
||||
afd_metadata.positions_list.append(positions)
|
||||
afd_metadata.inputs_embeds_list.append(inputs_embeds)
|
||||
afd_metadata.intermediate_tensors_list.append(intermediate_tensors)
|
||||
afd_metadata.attn_metadata_list.append(attn_metadata)
|
||||
afd_metadata.dp_metadata_list.append(dp_metadata)
|
||||
else:
|
||||
for i, ubatch_slice in enumerate(ubatch_slices):
|
||||
(
|
||||
sliced_input_ids,
|
||||
sliced_positions,
|
||||
sliced_inputs_embeds,
|
||||
sliced_intermediate_tensors,
|
||||
) = self._slice_model_inputs(
|
||||
ubatch_slice.token_slice,
|
||||
input_ids,
|
||||
positions,
|
||||
inputs_embeds,
|
||||
intermediate_tensors,
|
||||
)
|
||||
|
||||
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
||||
ubatch_num_tokens_across_dp = torch.tensor(
|
||||
[ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32
|
||||
)
|
||||
ubatch_dp_metadata = DPMetadata.make(
|
||||
self.vllm_config.parallel_config,
|
||||
ubatch_slice.num_tokens,
|
||||
ubatch_num_tokens_across_dp,
|
||||
)
|
||||
|
||||
afd_metadata.input_ids_list.append(sliced_input_ids)
|
||||
afd_metadata.positions_list.append(sliced_positions)
|
||||
afd_metadata.inputs_embeds_list.append(sliced_inputs_embeds)
|
||||
afd_metadata.intermediate_tensors_list.append(
|
||||
sliced_intermediate_tensors
|
||||
)
|
||||
afd_metadata.attn_metadata_list.append(
|
||||
attn_metadata[i] if attn_metadata is not None else None
|
||||
)
|
||||
afd_metadata.dp_metadata_list.append(ubatch_dp_metadata)
|
||||
|
||||
return afd_metadata
|
||||
|
||||
def _slice_model_inputs(
|
||||
self,
|
||||
tokens_slice: slice,
|
||||
@ -385,6 +447,34 @@ class UBatchWrapper:
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
ubatch_slices = forward_context.ubatch_slices
|
||||
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
afd_metadata = forward_context.afd_metadata
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
input_ids = kwargs["input_ids"]
|
||||
positions = kwargs["positions"]
|
||||
intermediate_tensors = kwargs["intermediate_tensors"]
|
||||
inputs_embeds = kwargs["inputs_embeds"]
|
||||
compute_stream = torch.cuda.current_stream()
|
||||
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
|
||||
if self.vllm_config.afd_config:
|
||||
afd_metadata = self._make_afd_ubatch_metadata(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
dp_metadata=dp_metadata,
|
||||
afd_metadata=afd_metadata,
|
||||
)
|
||||
forward_context.afd_metadata = afd_metadata
|
||||
if cudagraph_runtime_mode is CUDAGraphMode.NONE:
|
||||
return self.runnable(*args, **kwargs)
|
||||
else:
|
||||
assert self.cudagraph_wrapper is not None
|
||||
return self.cudagraph_wrapper(*args, **kwargs)
|
||||
|
||||
# If there's no ubatching, just run the runnable object
|
||||
if ubatch_slices is None:
|
||||
@ -394,6 +484,7 @@ class UBatchWrapper:
|
||||
# num_tokens, we don't have a non-ubatched one. Without this
|
||||
# check, the cudagraph wrapper will try to capture a cudagraph
|
||||
# for this shape during a normal run.
|
||||
|
||||
if cudagraph_runtime_mode is CUDAGraphMode.FULL:
|
||||
assert batch_descriptor is not None
|
||||
if batch_descriptor.num_tokens in self.cudagraphs:
|
||||
@ -405,18 +496,9 @@ class UBatchWrapper:
|
||||
assert self.cudagraph_wrapper is not None
|
||||
return self.cudagraph_wrapper(*args, **kwargs)
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
num_tokens = (
|
||||
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
|
||||
) * 2
|
||||
input_ids = kwargs["input_ids"]
|
||||
positions = kwargs["positions"]
|
||||
intermediate_tensors = kwargs["intermediate_tensors"]
|
||||
inputs_embeds = kwargs["inputs_embeds"]
|
||||
compute_stream = torch.cuda.current_stream()
|
||||
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
|
||||
# We shouldn't be here unless we are running with multiple DP ranks
|
||||
assert dp_metadata is not None
|
||||
ubatch_dp_metadata = []
|
||||
@ -448,6 +530,7 @@ class UBatchWrapper:
|
||||
dp_metadata=ubatch_dp_metadata,
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
afd_metadata=afd_metadata,
|
||||
)
|
||||
with self.sm_control:
|
||||
return self._capture_ubatches(ubatch_metadata, self.model)
|
||||
@ -470,6 +553,7 @@ class UBatchWrapper:
|
||||
dp_metadata=ubatch_dp_metadata,
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
afd_metadata=afd_metadata,
|
||||
)
|
||||
with self.sm_control:
|
||||
return self._run_ubatches(ubatch_metadata, self.model)
|
||||
|
||||
@ -6,7 +6,7 @@ import gc
|
||||
import os
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from types import NoneType
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -52,6 +52,8 @@ from vllm.v1.outputs import (
|
||||
ModelRunnerOutput,
|
||||
)
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
from vllm.v1.worker.gpu_ffn_model_runner import GPUFFNModelRunner
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
from vllm.v1.worker.utils import is_residual_scattered_for_sp
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
@ -248,8 +250,11 @@ class Worker(WorkerBase):
|
||||
num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
|
||||
init_workspace_manager(self.device, num_ubatches)
|
||||
|
||||
self.model_runner: GPUModelRunner | GPUFFNModelRunner | GPUModelRunnerV2
|
||||
if self.vllm_config.afd_config and self.vllm_config.afd_config.is_ffn_server:
|
||||
self.model_runner = GPUFFNModelRunner(self.vllm_config, self.device)
|
||||
# Construct the model runner
|
||||
if self.use_v2_model_runner:
|
||||
elif self.use_v2_model_runner:
|
||||
from vllm.v1.worker.gpu.model_runner import (
|
||||
GPUModelRunner as GPUModelRunnerV2,
|
||||
)
|
||||
@ -574,8 +579,16 @@ class Worker(WorkerBase):
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self, scheduler_output: "SchedulerOutput"
|
||||
) -> ModelRunnerOutput | None:
|
||||
self, scheduler_output: Optional["SchedulerOutput"] = None
|
||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
|
||||
# FFN server mode: direct execution without pipeline parallelism
|
||||
if self.vllm_config.afd_config and self.vllm_config.afd_config.is_ffn_server:
|
||||
return self.model_runner.execute_model(scheduler_output)
|
||||
|
||||
if scheduler_output is None:
|
||||
raise ValueError("scheduler_output is required in normal inference mode")
|
||||
|
||||
# Normal inference mode
|
||||
intermediate_tensors = None
|
||||
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
@ -682,6 +695,53 @@ class Worker(WorkerBase):
|
||||
# worker will always be healthy as long as it's running.
|
||||
return
|
||||
|
||||
def start_ffn_server_loop(self) -> None:
|
||||
"""Start FFN server loop for AFD FFN workers"""
|
||||
if not (
|
||||
self.vllm_config.afd_config and self.vllm_config.afd_config.is_ffn_server
|
||||
):
|
||||
return
|
||||
|
||||
self.model_runner.capture_model()
|
||||
self.model_runner.initialize_afd_connector()
|
||||
|
||||
if self.profiler:
|
||||
self.profiler.start()
|
||||
for _ in range(1000): # FIXME: hardcoded profiler iterations
|
||||
self.model_runner.execute_model(scheduler_output=None)
|
||||
torch.cuda.synchronize() # Ensure GPU operations complete
|
||||
self.profiler.stop()
|
||||
print(self.profiler.key_averages().table(sort_by="self_cuda_time_total"))
|
||||
|
||||
import threading
|
||||
|
||||
self._ffn_shutdown_event = threading.Event()
|
||||
|
||||
def ffn_worker_loop():
|
||||
# Set CUDA device for this thread (thread-local context)
|
||||
torch.cuda.set_device(self.device)
|
||||
logger.info("FFN worker loop started")
|
||||
|
||||
try:
|
||||
while not self._ffn_shutdown_event.is_set():
|
||||
# Execute FFN computation
|
||||
self.model_runner.execute_model(scheduler_output=None)
|
||||
except Exception as e:
|
||||
logger.error("FFN worker loop error: %s", e)
|
||||
raise
|
||||
|
||||
self._ffn_thread = threading.Thread(target=ffn_worker_loop, daemon=True)
|
||||
self._ffn_thread.start()
|
||||
logger.info("FFN server loop started in worker")
|
||||
|
||||
def stop_ffn_server_loop(self) -> None:
|
||||
"""Stop FFN server loop"""
|
||||
if hasattr(self, "_ffn_shutdown_event"):
|
||||
self._ffn_shutdown_event.set()
|
||||
if hasattr(self, "_ffn_thread"):
|
||||
self._ffn_thread.join(timeout=5)
|
||||
logger.info("FFN server loop stopped")
|
||||
|
||||
def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user