From d306d01dd776ca19744c83a0875fb05a0da1bace Mon Sep 17 00:00:00 2001
From: i-yuanyukun
Date: Thu, 18 Dec 2025 14:30:55 +0800
Subject: [PATCH] [Feat] Adapt Step3 text model for AFD (attention-FFN disaggregation)

---
 .../afd_connector/p2p_connector.py         | 310 ++++++++++++++++++
 vllm/model_executor/models/deepseek_v2.py  |   2 +-
 vllm/model_executor/models/step3_text.py   | 242 +++++++++++---
 vllm/model_executor/models/step3_vl.py     |  10 +
 vllm/v1/worker/gpu_model_runner.py         |  56 ++++
 vllm/v1/worker/gpu_ubatch_wrapper.py       |   4 +
 6 files changed, 580 insertions(+), 44 deletions(-)

diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
index 6dc1f317b87d5..3c3359fae96f5 100644
--- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
+++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
@@ -7,6 +7,316 @@
 from datetime import timedelta
 
 import torch
 from torch.distributed.distributed_c10d import _get_default_group, _update_default_pg
+
 from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import (
     GroupCoordinator,
+    TensorMetadata,
+    init_afd_process_group,
+    init_model_parallel_group,
 )
+from vllm.logger import init_logger
+
+from .base import AFDConnectorBase
+from .metadata import AFDConnectorMetadata
+
+logger = init_logger(__name__)
+
+
+class DefaultProcessGroupSwitcher:
+    """Context manager that temporarily swaps the default process group."""
+
+    def __init__(self, default_group, new_default_group):
+        self.default_group = default_group
+        self.new_default_group = new_default_group
+
+    def __enter__(self):
+        _update_default_pg(self.new_default_group)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        _update_default_pg(self.default_group)
+
+
+class P2PAFDConnector(AFDConnectorBase):
+    """P2P NCCL connector pairing each attention rank with one FFN rank."""
+
+    def __init__(
+        self,
+        rank: int,
+        local_rank: int,
+        config: "VllmConfig",
+    ) -> None:
+        self.rank = rank
+        self.local_rank = local_rank
+        self.config = config
+        self._initialized: bool = False
+        self._need_recv_metadata: bool = True
+        self._tensor_metadata_list: dict[int, TensorMetadata] = {}
+        self._current_afd_connector_metadata: AFDConnectorMetadata | None = None
+        if getattr(self.config.model_config.hf_config, "text_config", None) is not None:
+            self.num_hidden_layers: int = (
+                self.config.model_config.hf_config.text_config.num_hidden_layers
+            )
+        else:
+            self.num_hidden_layers: int = (
+                self.config.model_config.hf_config.num_hidden_layers
+            )
+
+        self.recv_attn_output_counter: int = 0
+        self.recv_ffn_output_counter: int = 0
+
+    def close(self) -> None:
+        """Close the connector and release resources."""
+        # TODO: Implement proper resource cleanup if needed.
+        pass
+
+    def init_afd_connector(self) -> None:
+        """Initialize the AFD connector."""
+        afd_size = self.config.afd_config.afd_extra_config.get("afd_size")
+        role = self.config.afd_config.afd_role
+        attn_size, ffn_size = map(int, re.match(r"(\d+)\D+(\d+)", afd_size).groups())
+        # Attention ranks occupy [0, attn_size); FFN ranks follow them.
+        world_rank = self.rank if role == "attention" else self.rank + attn_size
+        afd_pg = init_afd_process_group(
+            backend="nccl",
+            init_method=(
+                f"tcp://{self.config.afd_config.afd_host}"
+                f":{self.config.afd_config.afd_port}"
+            ),
+            world_size=ffn_size + attn_size,
+            rank=world_rank,
+            group_name="afd",
+            timeout=timedelta(minutes=2),
+        )
+
+        # Construct rank lists for the sub-groups.
+        # Each sub-group pairs one attention rank with one FFN rank.
+        ffn_ranks = list(range(attn_size, attn_size + ffn_size))
+        attn_ranks = list(range(attn_size))
+        assert len(ffn_ranks) == len(attn_ranks), (
+            "ffn_ranks and attn_ranks must have the same length"
+        )
+        default_pg_switcher = DefaultProcessGroupSwitcher(_get_default_group(), afd_pg)
+        with default_pg_switcher:
+            sub_group_ranks = []
+            for i in range(len(ffn_ranks)):
+                ranks = [attn_ranks[i], ffn_ranks[i]]
+                sub_group_ranks.append(ranks)
+            # Create two independent groups:
+            #   a2e_group: attention -> expert/FFN traffic (send_attn, recv_attn)
+            #   e2a_group: expert/FFN -> attention traffic (send_ffn, recv_ffn)
+            # The rank ranges are identical, but a different group_name yields
+            # an independent communicator for each direction.
+            self.a2e_group = init_model_parallel_group(
+                sub_group_ranks,
+                self.local_rank,
+                backend="nccl",
+                group_name="a2e",
+            )
+            self.e2a_group = init_model_parallel_group(
+                sub_group_ranks,
+                self.local_rank,
+                backend="nccl",
+                group_name="e2a",
+            )
+
+        self._initialized = True
+
+    def is_initialized(self) -> bool:
+        """Check if the connector is initialized and ready to use.
+
+        Returns:
+            bool: True if the connector is initialized, False otherwise.
+        """
+        return self._initialized
+
+    def _build_tensor_metadata_list(
+        self,
+        tensor_metadata: TensorMetadata,
+        connector_metadata: AFDConnectorMetadata,
+    ) -> dict[int, TensorMetadata]:
+        tensor_metadata_list = {}
+        num_of_stages = connector_metadata.num_of_stages
+        for idx in range(num_of_stages):
+            if idx == 0:
+                tensor_metadata_list[0] = tensor_metadata
+            else:
+                # Later stages reuse the dtype/device but carry their own
+                # token count in the leading dimension.
+                new_size = list(tensor_metadata.size)
+                new_size[0] = connector_metadata.afd_tokens_lens[idx]
+                tensor_metadata_list[idx] = TensorMetadata(
+                    tensor_metadata.device,
+                    tensor_metadata.dtype,
+                    torch.Size(new_size),
+                )
+        return tensor_metadata_list
+
+    def _send_metadata(
+        self,
+        metadata: AFDConnectorMetadata,
+        hidden_states: torch.Tensor,
+        dst: int,
+        process_group: GroupCoordinator,
+    ) -> None:
+        if not torch.distributed.is_initialized() or process_group.world_size == 1:
+            return
+        assert dst < process_group.world_size, f"Invalid dst rank ({dst})"
+
+        tensor_metadata = TensorMetadata(
+            hidden_states.device.type, hidden_states.dtype, hidden_states.size()
+        )
+        metadata_tuple = (metadata, tensor_metadata)
+        process_group.send_object(metadata_tuple, dst=dst)
+        self._tensor_metadata_list = self._build_tensor_metadata_list(
+            tensor_metadata, metadata
+        )
+
+    def _recv_metadata(
+        self,
+        src: int,
+        process_group: GroupCoordinator,
+    ) -> None:
+        (self._current_afd_connector_metadata, tensor_metadata) = (
+            process_group.recv_object(src=src)
+        )
+        self._tensor_metadata_list = self._build_tensor_metadata_list(
+            tensor_metadata, self._current_afd_connector_metadata
+        )
+
+    def _send_hidden_states(
+        self,
+        hidden_states: torch.Tensor,
+        dst: int,
+        process_group: GroupCoordinator,
+    ) -> None:
+        if not torch.distributed.is_initialized() or process_group.world_size == 1:
+            return
+        assert dst < process_group.world_size, f"Invalid dst rank ({dst})"
+        assert not hidden_states.is_cpu, "Hidden states must be on GPU"
+        torch.distributed.send(
+            hidden_states,
+            dst=process_group.ranks[dst],
+            group=process_group.device_group,
+        )
+
+    def _recv_hidden_states(
+        self,
+        src: int,
+        process_group: GroupCoordinator,
+        tensor_metadata: TensorMetadata,
+    ) -> tuple[torch.Tensor, list]:
+        if not torch.distributed.is_initialized() or process_group.world_size == 1:
+            # Single-process fallback: nothing to receive.
+            return torch.empty(0), []
+        assert src < process_group.world_size, f"Invalid src rank ({src})"
+
+        hidden_states = torch.empty(
+            tensor_metadata.size,
+            dtype=tensor_metadata.dtype,
+            device=tensor_metadata.device,
+        )
+        torch.distributed.recv(
+            hidden_states,
+            src=process_group.ranks[src],
+            group=process_group.device_group,
+        )
+        return hidden_states, []
+
+    # -------------------------------------------------------------------------
+    # attn -> ffn
+    # -------------------------------------------------------------------------
+
+    def send_attn_output(
+        self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata
+    ) -> None:
+        """
+        Called by the ATTN side to send intermediate tensors
+        generated by ATTN instances to FFN.
+        """
+        try:
+            dst = (self.a2e_group.rank_in_group + 1) % self.a2e_group.world_size
+            # Metadata only changes per batch, so it is sent once at the
+            # first layer/stage; subsequent sends carry tensors only.
+            if metadata.layer_idx == 0 and metadata.stage_idx == 0:
+                self._send_metadata(metadata, hidden_states, dst, self.a2e_group)
+            self._current_afd_connector_metadata = metadata
+            self._send_hidden_states(hidden_states, dst, self.a2e_group)
+        except Exception as e:
+            raise RuntimeError(f"Communication error: {e}") from e
+
+    def recv_ffn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]:
+        """
+        Called by the ATTN side to receive MoE output intermediate tensors,
+        possibly dispatching from the receiver to other GPUs.
+        """
+        src = (self.e2a_group.rank_in_group - 1) % self.e2a_group.world_size
+        stage_idx = (
+            self.recv_ffn_output_counter
+            % self._current_afd_connector_metadata.num_of_stages
+        )
+        hidden_states, work_list = self._recv_hidden_states(
+            src,
+            self.e2a_group,
+            self._tensor_metadata_list[stage_idx],
+        )
+        self._current_afd_connector_metadata.recv_handle_list = work_list
+        self.recv_ffn_output_counter = (
+            self.recv_ffn_output_counter + 1
+        ) % self._current_afd_connector_metadata.num_of_stages
+        return hidden_states, self._current_afd_connector_metadata
+
+    # -------------------------------------------------------------------------
+    # ffn -> attn
+    # -------------------------------------------------------------------------
+
+    def send_ffn_output(
+        self,
+        hidden_states: torch.Tensor,
+        metadata: AFDConnectorMetadata,
+    ) -> None:
+        """
+        Called by the FFN side to send intermediate tensors generated by FFN
+        instances back to the sender (should be the same GPU as the source).
+        """
+        dst = (self.e2a_group.rank_in_group + 1) % self.e2a_group.world_size
+        self._send_hidden_states(hidden_states, dst, self.e2a_group)
+        self.recv_attn_output_counter += 1
+        # After num_of_stages * num_hidden_layers transfers, a full batch has
+        # been processed and fresh metadata must be received.
+        if (
+            self.recv_attn_output_counter
+            % (
+                self._current_afd_connector_metadata.num_of_stages
+                * self.num_hidden_layers
+            )
+            == 0
+        ):
+            self._need_recv_metadata = True
+            self.recv_attn_output_counter = 0
+
+    def recv_attn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]:
+        """
+        Called by the FFN side to receive intermediate tensors from ATTN.
+        Handles receiving and possibly dispatching tensors.
+        """
+        src = (self.a2e_group.rank_in_group - 1) % self.a2e_group.world_size
+        if self._need_recv_metadata:
+            self._recv_metadata(src, self.a2e_group)
+            self._need_recv_metadata = False
+
+        stage_idx = (
+            self.recv_attn_output_counter
+            % self._current_afd_connector_metadata.num_of_stages
+        )
+        layer_idx = (
+            self.recv_attn_output_counter
+            // self._current_afd_connector_metadata.num_of_stages
+        )
+        hidden_states, work_list = self._recv_hidden_states(
+            src,
+            self.a2e_group,
+            self._tensor_metadata_list[stage_idx],
+        )
+        self._current_afd_connector_metadata.recv_handle_list = work_list
+        self._current_afd_connector_metadata.layer_idx = layer_idx
+        return hidden_states, self._current_afd_connector_metadata
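The pairing logic in `init_afd_connector()` is easiest to sanity-check as a pure function over the `afd_size` string. A minimal sketch, in plain Python with no process groups; the `"4x4"` value is an illustrative assumption, not taken from the patch:

```python
# Minimal sketch of the rank topology that init_afd_connector() builds.
# Assumption (not in the patch): afd_size is a string like "4x4", and
# attention ranks come first in the joint world.
import re

def afd_topology(afd_size: str) -> list[list[int]]:
    attn_size, ffn_size = map(int, re.match(r"(\d+)\D+(\d+)", afd_size).groups())
    attn_ranks = list(range(attn_size))                       # world ranks 0..A-1
    ffn_ranks = list(range(attn_size, attn_size + ffn_size))  # world ranks A..A+F-1
    # Each attention rank is paired 1:1 with an FFN rank; every pair gets
    # its own two-rank communicator per direction (a2e and e2a).
    return [[a, f] for a, f in zip(attn_ranks, ffn_ranks)]

print(afd_topology("4x4"))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Giving each pair two communicators means sends in one direction never contend with receives in the other, at the cost of one extra NCCL group per pair.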
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 7938cff98c354..6b7842b00f54a 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -1670,7 +1670,7 @@ class DeepseekV2ForCausalLM(
         return hidden_states
 
     def compute_ffn_output(
-        self, current_layer_idx, hidden_states
+        self, hidden_states, current_layer_idx
     ) -> torch.Tensor | IntermediateTensors:
         hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx)
         return hidden_states
diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py
index 7077f1a22e8d7..c012c26f83afa 100644
--- a/vllm/model_executor/models/step3_text.py
+++ b/vllm/model_executor/models/step3_text.py
@@ -11,12 +11,16 @@
 from torch import nn
 
 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.config import AFDConfig, CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (
     get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from vllm.distributed.afd_transfer.afd_connector.metadata import (
+    AFDConnectorMetadata,
+)
+from vllm.forward_context import AFDMetadata, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -37,6 +41,7 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.step3_vl import Step3TextConfig
+from vllm.v1.worker.ubatching import dbo_current_ubatch_id, dbo_enabled, dbo_yield
 
 from .interfaces import SupportsPP
 from .utils import (
@@ -228,54 +233,59 @@ class Step3TextDecoderLayer(nn.Module):
         config: Step3TextConfig,
         cache_config: CacheConfig | None = None,
         quant_config: QuantizationConfig | None = None,
+        afd_config: AFDConfig | None = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.afd_role = afd_config.afd_role if afd_config is not None else None
 
-        self.self_attn = Step3TextAttention(
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=1,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            norm_eps=config.rms_norm_eps,
-            max_position_embedding=config.max_position_embedding,
-            head_dim=config.head_dim,
-            share_q_dim=config.share_q_dim,
-            rope_parameters=config.rope_parameters,
-            prefix=f"{prefix}.self_attn",
-        )
-
-        layer_idx = int(prefix.split("layers.")[1].split(".")[0])
-        moe_layers_enum = getattr(config, "moe_layers_enum", None)
-        if moe_layers_enum is not None:
-            moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
-        else:
-            # Default to 1dense.
-            moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
-
-        if layer_idx in moe_layers_idx:
-            self.moe = FusedMoEBlock(
-                config=config, quant_config=quant_config, prefix=f"{prefix}.moe"
-            )
-            self.share_expert = Step3TextMLP(
-                hidden_size=self.hidden_size,
-                intermediate_size=config.share_expert_dim,
-                hidden_act="silu",
-                quant_config=quant_config,
-                prefix=f"{prefix}.share_expert",
-            )
-            self.use_moe = True
-        else:
-            self.mlp = Step3TextMLP(
-                hidden_size=config.hidden_size,
-                intermediate_size=config.intermediate_size,
-                hidden_act="silu",
-                quant_config=quant_config,
-                prefix=f"{prefix}.mlp",
-            )
-            self.use_moe = False
+        # Only build the attention stack on attention hosts (or when AFD
+        # is disabled).
+        if self.afd_role is None or self.afd_role == "attention":
+            self.self_attn = Step3TextAttention(
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                num_kv_heads=1,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                norm_eps=config.rms_norm_eps,
+                max_position_embedding=config.max_position_embedding,
+                head_dim=config.head_dim,
+                share_q_dim=config.share_q_dim,
+                rope_parameters=config.rope_parameters,
+                prefix=f"{prefix}.self_attn",
+            )
+
+        self.layer_idx = int(prefix.split("layers.")[1].split(".")[0])
+
+        # Only build the FFN/MoE stack on FFN hosts (or when AFD is disabled).
+        if self.afd_role is None or self.afd_role == "ffn":
+            moe_layers_enum = getattr(config, "moe_layers_enum", None)
+            if moe_layers_enum is not None:
+                moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
+            else:
+                # Default to a single dense layer (layer 0).
+                moe_layers_idx = list(range(1, config.num_hidden_layers))
+
+            if self.layer_idx in moe_layers_idx:
+                self.moe = FusedMoEBlock(
+                    config=config, quant_config=quant_config, prefix=f"{prefix}.moe"
+                )
+                self.share_expert = Step3TextMLP(
+                    hidden_size=self.hidden_size,
+                    intermediate_size=config.share_expert_dim,
+                    hidden_act="silu",
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.share_expert",
+                )
+                self.use_moe = True
+            else:
+                self.mlp = Step3TextMLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=config.intermediate_size,
+                    hidden_act="silu",
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.mlp",
+                )
+                self.use_moe = False
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -300,6 +310,9 @@ class Step3TextDecoderLayer(nn.Module):
 
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
 
+        # Attention hosts stop here; the FFN half of this layer runs remotely.
+        if self.afd_role == "attention":
+            return hidden_states, residual
+
         if self.use_moe:
             share_output = self.share_expert(hidden_states)
             moe_output = self.moe(hidden_states)
@@ -309,6 +322,25 @@ class Step3TextDecoderLayer(nn.Module):
 
         return hidden_states, residual
 
+    def compute_attn_output(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+    ):
+        # TODO: not implemented; attention hosts currently take the early
+        # return in forward() above.
+        pass
+
+    def compute_ffn_output(self, hidden_states):
+        assert self.afd_role == "ffn"
+        if self.use_moe:
+            share_output = self.share_expert(hidden_states)
+            moe_output = self.moe(hidden_states)
+            hidden_states = share_output + moe_output
+        else:
+            hidden_states = self.mlp(hidden_states)
+        return hidden_states
+
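`compute_ffn_output` is the entire FFN half of a layer: the dense shared expert and the routed experts see the same input, and their outputs are summed. A toy sketch with `nn.Linear` stand-ins; the shapes are assumed for illustration, and the real modules are `Step3TextMLP` and `FusedMoEBlock`:

```python
# Toy sketch of the MoE-layer math in compute_ffn_output.
import torch

def ffn_block(hidden, share_expert, routed_experts):
    # Shared expert runs densely; routed-expert output is added on top.
    return share_expert(hidden) + routed_experts(hidden)

h = torch.randn(8, 16)                 # (num_tokens, hidden_size), assumed
share = torch.nn.Linear(16, 16)        # stand-in for Step3TextMLP
routed = torch.nn.Linear(16, 16)       # stand-in for FusedMoEBlock
out = ffn_block(h, share, routed)
assert out.shape == h.shape            # FFN preserves the hidden shape
```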
 
 @support_torch_compile
 class Step3TextModel(nn.Module):
@@ -317,6 +349,8 @@ class Step3TextModel(nn.Module):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        afd_config = vllm_config.afd_config
 
         self.vocab_size = config.vocab_size
         self.config = config
@@ -336,6 +370,7 @@ class Step3TextModel(nn.Module):
                 config=config,
                 cache_config=cache_config,
                 quant_config=quant_config,
+                afd_config=afd_config,
                 prefix=prefix,
             ),
             prefix=f"{prefix}.layers",
@@ -352,6 +387,51 @@ class Step3TextModel(nn.Module):
 
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
+    def forward_with_afd(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        positions: torch.Tensor,
+        afd_metadata: AFDMetadata,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        recv_handle = None
+        logger.debug("forward_with_afd: entering AFD loop; recv calls may block")
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
+            afd_connector = afd_metadata.afd_connector
+            afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
+
+            # Receive the previous layer's FFN output before running this
+            # layer's attention.
+            if layer.layer_idx > 0:
+                hidden_states, recv_metadata = afd_connector.recv_ffn_output()
+                if recv_metadata.recv_handle_list is not None:
+                    recv_handle = recv_metadata.recv_handle_list
+
+            if recv_handle is not None:
+                for work in recv_handle:
+                    work.wait()
+            current_hidden, residual = layer(positions, hidden_states, residual)
+            metadata = AFDConnectorMetadata.create_attention_metadata(
+                layer_idx=layer.layer_idx,
+                stage_idx=afd_metadata.afd_stage_idx,
+                seq_len=current_hidden.shape[0],
+                dtype=current_hidden.dtype,
+                device=current_hidden.device,
+                num_of_stages=afd_metadata.num_of_stages,
+                afd_tokens_lens=afd_metadata.afd_tokens_lens,
+            )
+            afd_connector.send_attn_output(current_hidden, metadata)
+
+            if dbo_enabled():
+                dbo_yield()
+
+        # Drain the FFN output of the last layer.
+        hidden_states, recv_metadata = afd_connector.recv_ffn_output()
+        if recv_metadata.recv_handle_list is not None:
+            recv_handle = recv_metadata.recv_handle_list
+        if recv_handle is not None:
+            for work in recv_handle:
+                work.wait()
+
+        return hidden_states, residual
+
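The loop above pipelines communication against compute: the attention for layer `i` runs while the remote FFN for layer `i-1` is still in flight, and a single trailing receive drains the last layer. A pure-Python trace of that schedule for a hypothetical 3-layer model:

```python
# Trace of the attention-side send/recv schedule in forward_with_afd,
# assuming 3 layers (illustrative; the real layer count comes from config).
num_layers = 3
events = []
for layer_idx in range(num_layers):
    if layer_idx > 0:
        events.append(f"recv_ffn_output(layer {layer_idx - 1})")
    events.append(f"attn(layer {layer_idx})")
    events.append(f"send_attn_output(layer {layer_idx})")
events.append(f"recv_ffn_output(layer {num_layers - 1})")
print(" -> ".join(events))
# attn(layer 0) -> send_attn_output(layer 0) -> recv_ffn_output(layer 0)
#   -> attn(layer 1) -> ... -> recv_ffn_output(layer 2)
```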
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -370,8 +450,19 @@ class Step3TextModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
-            hidden_states, residual = layer(positions, hidden_states, residual)
+        forward_ctx = get_forward_context()
+        afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None
+
+        if afd_metadata is not None:
+            hidden_states, residual = self.forward_with_afd(
+                hidden_states,
+                residual,
+                positions,
+                afd_metadata,
+            )
+        else:
+            for layer in islice(self.layers, self.start_layer, self.end_layer):
+                hidden_states, residual = layer(positions, hidden_states, residual)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -384,6 +475,15 @@ class Step3TextModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def compute_ffn_output(
+        self,
+        hidden_states,
+        layer_idx,
+    ) -> torch.Tensor | IntermediateTensors:
+        hidden_states = self.layers[layer_idx].compute_ffn_output(hidden_states)
+        return hidden_states
+
 
 class Step3TextForCausalLM(nn.Module, SupportsPP):
     def __init__(
@@ -398,6 +498,11 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
         self.config = config
         self.vllm_config = vllm_config
 
+        self.afd_config = vllm_config.afd_config
+        self.afd_role = (
+            self.afd_config.afd_role if self.afd_config is not None else None
+        )
+
         self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)
 
         if get_pp_group().is_last_rank:
@@ -429,11 +534,20 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
         )
         return hidden_states
 
+    def compute_ffn_output(
+        self,
+        hidden_states,
+        current_layer_idx,
+    ) -> torch.Tensor | IntermediateTensors:
+        hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx)
+        return hidden_states
+
     def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
         logits = self.logits_processor(self.lm_head, hidden_states)
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         qkv_params_mapping = [
             # (param_name, shard_name, relative_start_idx, relative_end_idx)
             (
@@ -477,9 +592,17 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
         disable_moe_stacked_params = [data[1] for data in expert_params_mapping]
 
         for name, loaded_weight in weights:
+            # Attention hosts never own MoE weights; skip them early.
+            if self.afd_role == "attention" and self.is_moe_weight(name):
+                continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+
                 if any(
                     disable_moe_stacked_param in name
                     for disable_moe_stacked_param in disable_moe_stacked_params
@@ -498,6 +621,10 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
                 param_name, weight_name, shard_id = mapping
                 if weight_name not in name:
                     continue
+
+                # Expert weights are never materialized on attention hosts.
+                if self.afd_role == "attention":
+                    continue
+
                 name = name.replace(weight_name, param_name)
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
@@ -521,12 +648,19 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
                     loaded_params.add(name)
                     break
             else:
+                # FFN hosts only need MoE weights plus the shared
+                # (embedding/norm/head) weights.
+                if (
+                    self.afd_role == "ffn"
+                    and not self.is_moe_weight(name)
+                    and not self.is_common_weight(name)
+                ):
+                    continue
                 for (
                     param_name,
                     weight_name,
                     start_idx,
                     end_idx,
                 ) in qkv_params_mapping:
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
@@ -552,3 +686,25 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
                     weight_loader(param, loaded_weight)
                     loaded_params.add(name)
         return loaded_params
+
+    def is_moe_weight(self, name: str) -> bool:
+        return any(
+            key in name
+            for key in ("shared_expert", "experts", "gate", "up", "down")
+        )
+
+    def is_common_weight(self, name: str) -> bool:
+        return any(
+            key in name
+            for key in (
+                "lm_head",
+                "model.norm.weight",
+                "embed",
+                "input_layernorm",
+                "post_attention_layernorm",
+            )
+        )
diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py
index e5038e56a2708..e16cb53e9f194 100644
--- a/vllm/model_executor/models/step3_vl.py
+++ b/vllm/model_executor/models/step3_vl.py
@@ -1126,6 +1126,16 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
 
         return hidden_states
 
+    def compute_ffn_output(
+        self,
+        hidden_states,
+        current_layer_idx,
+    ) -> torch.Tensor | IntermediateTensors:
+        hidden_states = self.language_model.compute_ffn_output(
+            hidden_states, current_layer_idx
+        )
+        return hidden_states
+
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
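The `is_moe_weight` / `is_common_weight` predicates at the end of the `step3_text.py` changes above decide which checkpoint tensors each role keeps during `load_weights`. A standalone sketch of the routing decision; the weight names are illustrative, not taken from a real Step3 checkpoint:

```python
# Sketch of per-role weight routing, using the same substring predicates
# as is_moe_weight / is_common_weight in the patch.
def is_moe(name):
    return any(k in name for k in ("shared_expert", "experts", "gate", "up", "down"))

def is_common(name):
    return any(k in name for k in ("lm_head", "model.norm.weight", "embed",
                                   "input_layernorm", "post_attention_layernorm"))

def keep(name, role):
    if role == "attention":
        return not is_moe(name)                  # attention hosts skip MoE weights
    if role == "ffn":
        return is_moe(name) or is_common(name)   # FFN hosts keep MoE + shared weights
    return True                                  # no AFD: keep everything

for n in ("model.layers.1.moe.experts.w13_weight",      # hypothetical names
          "model.layers.1.self_attn.qkv_proj.weight",
          "model.embed_tokens.weight"):
    print(n, {r: keep(n, r) for r in ("attention", "ffn")})
```

Note that the substring checks are deliberately broad (`"up"`, `"down"`, `"gate"`), so any weight whose name contains those fragments is treated as FFN-side.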
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 82dd80d267182..fff775a9f8241 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -642,6 +642,25 @@ class GPUModelRunner(
             with_stack=False,
         )
 
+        # NOTE: this replaces the profiler created above so that AFD
+        # attention hosts write traces to their own directory.
+        profile_dir = (
+            "./profiler_logs/attn"
+            if self.afd_config is not None and self.afd_config.afd_role == "attention"
+            else "./profiler_logs/normal"
+        )
+        self.profiler = torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            # Skip the first 10000 steps, then record 30 steps once.
+            schedule=torch.profiler.schedule(
+                wait=6000 + 4000, warmup=1, active=30, repeat=1
+            ),
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir),
+            record_shapes=True,
+            profile_memory=False,
+            with_stack=False,
+        )
+
     def reset_mm_cache(self) -> None:
         if self.mm_budget:
             self.mm_budget.reset_cache()
@@ -2969,6 +2988,38 @@ class GPUModelRunner(
         )
         return afd_metadata
 
+    def _build_afd_metadata(
+        self, ubatch_slices: UBatchSlices | None, num_tokens_unpadded: int
+    ):
+        afd_metadata = None
+        if self.afd_config:
+            # For prefill, compute tokens per stage based on actual token
+            # counts.
+            if ubatch_slices and len(ubatch_slices) > 1:
+                afd_tokens_start_loc = [ub.token_slice.start for ub in ubatch_slices]
+                afd_reqs_start_loc = [ub.request_slice.start for ub in ubatch_slices]
+                logger.debug(
+                    "afd_tokens_start_loc: %s afd_reqs_start_loc: %s "
+                    "ubatch_slices: %s",
+                    afd_tokens_start_loc,
+                    afd_reqs_start_loc,
+                    ubatch_slices,
+                )
+                afd_tokens_lens = [ub.num_tokens for ub in ubatch_slices]
+            else:
+                afd_tokens_start_loc = [0]
+                afd_reqs_start_loc = [0]
+                afd_tokens_lens = [num_tokens_unpadded]
+            afd_metadata = AFDMetadata(
+                afd_tokens_start_loc=afd_tokens_start_loc,
+                afd_reqs_start_loc=afd_reqs_start_loc,
+                afd_stage_idx=0,
+                afd_connector=self.afd_connector,
+                afd_tokens_lens=afd_tokens_lens,
+                num_of_stages=len(ubatch_slices) if ubatch_slices else 1,
+            )
+        return afd_metadata
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -5517,6 +5568,11 @@ class GPUModelRunner(
         if hasattr(self, "afd_connector") and self.afd_connector:
             self.afd_connector.init_afd_connector()
 
+    def initialize_afd_connector(self) -> None:
+        """Initialize the AFD connector if one is configured."""
+        if hasattr(self, "afd_connector") and self.afd_connector:
+            self.afd_connector.init_afd_connector()
+
     def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
         """
         Add encoder-only layers to the KV cache config.
diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py
index 9e17c718c5513..8a1c9d90abbbf 100644
--- a/vllm/v1/worker/gpu_ubatch_wrapper.py
+++ b/vllm/v1/worker/gpu_ubatch_wrapper.py
@@ -127,6 +127,10 @@ class UBatchWrapper:
         comm_sms: int = envs.VLLM_DBO_COMM_SMS
         set_comm_sms = lambda sms: None
 
-        if vllm_config.parallel_config.enable_expert_parallel:
+        if (
+            vllm_config.parallel_config.enable_expert_parallel
+            and not vllm_config.afd_config
+        ):
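For reference, the values `_build_afd_metadata` produces are easy to reproduce by hand. A sketch assuming two micro-batch slices over 10 unpadded tokens split 6/4; `Slice` here is a simplified stand-in for the real `UBatchSlice` (which exposes separate `token_slice` and `request_slice` fields):

```python
# Sketch of the token accounting in _build_afd_metadata.
from dataclasses import dataclass

@dataclass
class Slice:  # illustrative stand-in for UBatchSlice
    start: int
    stop: int

    @property
    def num_tokens(self) -> int:
        return self.stop - self.start

ubatch_slices = [Slice(0, 6), Slice(6, 10)]
afd_tokens_start_loc = [s.start for s in ubatch_slices]   # [0, 6]
afd_tokens_lens = [s.num_tokens for s in ubatch_slices]   # [6, 4]
num_of_stages = len(ubatch_slices)                        # 2
print(afd_tokens_start_loc, afd_tokens_lens, num_of_stages)
```

With a single slice (no micro-batching), the fallback branch applies: one stage covering all `num_tokens_unpadded` tokens starting at offset 0.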