From d306d01dd776ca19744c83a0875fb05a0da1bace Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Thu, 18 Dec 2025 14:30:55 +0800 Subject: [PATCH 01/13] [Feat] adapt step3 text model --- .../afd_connector/p2p_connector.py | 310 ++++++++++++++++++ vllm/model_executor/models/deepseek_v2.py | 2 +- vllm/model_executor/models/step3_text.py | 242 +++++++++++--- vllm/model_executor/models/step3_vl.py | 10 + vllm/v1/worker/gpu_model_runner.py | 56 ++++ vllm/v1/worker/gpu_ubatch_wrapper.py | 4 + 6 files changed, 580 insertions(+), 44 deletions(-) diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py index 6dc1f317b87d5..3c3359fae96f5 100644 --- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -7,6 +7,316 @@ from datetime import timedelta import torch from torch.distributed.distributed_c10d import _get_default_group, _update_default_pg +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import ( + GroupCoordinator, + TensorMetadata, + init_afd_process_group, + init_model_parallel_group, +) +from vllm.logger import init_logger + +from .base import AFDConnectorBase +from .metadata import AFDConnectorMetadata + +logger = init_logger(__name__) + + +class DefaultProcessGroupSwitcher: + def __init__(self, default_group, new_default_group): + self.default_group = default_group + self.new_default_group = new_default_group + + def __enter__(self): + _update_default_pg(self.new_default_group) + + def __exit__(self, exc_type, exc_value, traceback): + _update_default_pg(self.default_group) + + +class P2PAFDConnector(AFDConnectorBase): + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ) -> None: + self.rank = rank + self.local_rank = local_rank + self.config = config + self._initialized: bool = False + self._need_recv_metadata: bool = True + self._tensor_metadata_list: dict[int, TensorMetadata] = {} + self._current_afd_connector_metadata: AFDConnectorMetadata | None = None + if getattr(self.config.model_config.hf_config, "text_config", None) is not None: + self.num_hidden_layers: int = ( + self.config.model_config.hf_config.text_config.num_hidden_layers + ) + else: + self.num_hidden_layers: int = ( + self.config.model_config.hf_config.num_hidden_layers + ) + + self.recv_attn_output_counter: int = 0 + self.recv_ffn_output_counter: int = 0 + + def close(self) -> None: + """Close the connector and release resources.""" + # TODO: Implement proper resource clean up if needed. + pass + + def init_afd_connector(self) -> None: + """Initialize the AFD connector.""" + afd_size = self.config.afd_config.afd_extra_config.get("afd_size") + role = self.config.afd_config.afd_role + attn_size, ffn_size = map(int, re.match(r"(\d+)\D+(\d+)", afd_size).groups()) + world_rank = self.rank if role == "attention" else self.rank + attn_size + afd_pg = init_afd_process_group( + backend="nccl", + init_method=( + f"tcp://{self.config.afd_config.afd_host}" + f":{self.config.afd_config.afd_port}" + ), + world_size=ffn_size + attn_size, + rank=world_rank, + group_name="afd", + timeout=timedelta(minutes=2), + ) + + # Construct rank lists for sub groups. + # Each group contains one attention and one ffn rank. 
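+        # NOTE: this one-to-one pairing assumes attn_size == ffn_size (the
+        # assert below enforces equal list lengths); attention world rank i
+        # is paired with FFN world rank ffn_size + i.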
+ ffn_ranks = [i for i in range(ffn_size, ffn_size + attn_size)] + attn_ranks = [i for i in range(attn_size)] + assert len(ffn_ranks) == len(attn_ranks), ( + "ffn_ranks and attn_ranks must have the same length" + ) + default_pg_switcher = DefaultProcessGroupSwitcher(_get_default_group(), afd_pg) + with default_pg_switcher: + sub_group_ranks = [] + for i in range(len(ffn_ranks)): + ranks = [attn_ranks[i], ffn_ranks[i]] + sub_group_ranks.append(ranks) + # Create two independent groups: + # a2e_group: for attention -> expert/ffn communication (send_attn, recv_attn) + # e2a_group: for expert/ffn -> attention communication (send_ffn, recv_ffn) + # The communication domain (rank range) is the same, but different group_name + # creates independent groups. + self.a2e_group = init_model_parallel_group( + sub_group_ranks, + self.local_rank, + backend="nccl", + group_name="a2e", + ) + self.e2a_group = init_model_parallel_group( + sub_group_ranks, + self.local_rank, + backend="nccl", + group_name="e2a", + ) + + self._initialized = True + + def is_initialized(self) -> bool: + """Check if the connector is initialized and ready to use. + + Returns: + bool: True if the connector is initialized, False otherwise. + """ + return self._initialized + + def _build_tensor_metadata_list( + self, + tensor_metadata: TensorMetadata, + connector_metadata: AFDConnectorMetadata, + ) -> dict[int, TensorMetadata]: + tensor_metadata_list = {} + num_of_stages = connector_metadata.num_of_stages + for idx in range(num_of_stages): + if idx == 0: + tensor_metadata_list[0] = tensor_metadata + else: + new_size = list(tensor_metadata.size) + new_size[0] = connector_metadata.afd_tokens_lens[idx] + tensor_metadata_list[idx] = TensorMetadata( + tensor_metadata.device, + tensor_metadata.dtype, + torch.Size(new_size), + ) + return tensor_metadata_list + + def _send_metadata( + self, + metadata: AFDConnectorMetadata, + hidden_states: torch.Tensor, + dst: int, + process_group: GroupCoordinator, + ) -> None: + if not torch.distributed.is_initialized() or process_group.world_size == 1: + return [] + assert dst < process_group.world_size, f"Invalid dst rank ({dst})" + + tensor_metadata = TensorMetadata( + hidden_states.device.type, hidden_states.dtype, hidden_states.size() + ) + metadata_tuple = (metadata, tensor_metadata) + process_group.send_object(metadata_tuple, dst=dst) + self._tensor_metadata_list = self._build_tensor_metadata_list( + tensor_metadata, metadata + ) + + def _recv_metadata( + self, + src: int, + process_group: GroupCoordinator, + ) -> None: + (self._current_afd_connector_metadata, tensor_metadata) = ( + process_group.recv_object(src=src) + ) + self._tensor_metadata_list = self._build_tensor_metadata_list( + tensor_metadata, self._current_afd_connector_metadata + ) + + def _send_hidden_states( + self, + hidden_states: torch.Tensor, + dst: int, + process_group: GroupCoordinator, + ) -> None: + if not torch.distributed.is_initialized() or process_group.world_size == 1: + return [] + assert dst < process_group.world_size, f"Invalid dst rank ({dst})" + assert not hidden_states.is_cpu, "Hidden states must be on GPU" + torch.distributed.send( + hidden_states, + dst=process_group.ranks[dst], + group=process_group.device_group, + ) + + def _recv_hidden_states( + self, + src: int, + process_group: GroupCoordinator, + tensor_metadata: TensorMetadata, + ) -> tuple[torch.Tensor, list]: + if not torch.distributed.is_initialized() or process_group.world_size == 1: + return {}, [] + assert src < process_group.world_size, f"Invalid src 
rank ({src})" + + hidden_states = torch.empty( + tensor_metadata.size, + dtype=tensor_metadata.dtype, + device=tensor_metadata.device, + ) + torch.distributed.recv( + hidden_states, + src=process_group.ranks[src], + group=process_group.device_group, + ) + return hidden_states, [] + + # ------------------------------------------------------------------------- + # attn -> ffn + # ------------------------------------------------------------------------- + + def send_attn_output( + self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata + ) -> None: + """ + Called by ATTN side to send intermediate tensors + generated by ATTN instances to FFN. + """ + try: + dst = (self.a2e_group.rank_in_group + 1) % self.a2e_group.world_size + if metadata.layer_idx == 0 and metadata.stage_idx == 0: + self._send_metadata(metadata, hidden_states, dst, self.a2e_group) + self._current_afd_connector_metadata = metadata + self._send_hidden_states(hidden_states, dst, self.a2e_group) + except Exception as e: + raise RuntimeError(f"Communication error: {e}") + + def recv_ffn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: + """ + Called by the ATTN side to receive MOE output intermediate tensors, + possibly dispatching from the receiver to other GPUs. + """ + src = (self.e2a_group.rank_in_group - 1) % self.e2a_group.world_size + stage_idx = ( + self.recv_ffn_output_counter + % self._current_afd_connector_metadata.num_of_stages + ) + hidden_states, work_list = self._recv_hidden_states( + src, + self.e2a_group, + self._tensor_metadata_list[stage_idx], + ) + self._current_afd_connector_metadata.recv_handle_list = work_list + self.recv_ffn_output_counter = ( + self.recv_ffn_output_counter + 1 + ) % self._current_afd_connector_metadata.num_of_stages + return hidden_states, self._current_afd_connector_metadata + + # ------------------------------------------------------------------------- + # ffn -> attn + # ------------------------------------------------------------------------- + + def send_ffn_output( + self, + hidden_states: torch.Tensor, + metadata: AFDConnectorMetadata, + ) -> None: + """ + Called by FFN side to send intermediate tensors generated by FFN + instances back to the sender (should be the same GPU as source). + """ + dst = (self.e2a_group.rank_in_group + 1) % self.e2a_group.world_size + self._send_hidden_states(hidden_states, dst, self.e2a_group) + self.recv_attn_output_counter += 1 + if ( + self.recv_attn_output_counter + % ( + self._current_afd_connector_metadata.num_of_stages + * self.num_hidden_layers + ) + == 0 + ): + self._need_recv_metadata = True + self.recv_attn_output_counter = 0 + + def recv_attn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: + """ + Called by the FFN side to receive intermediate tensors from ATTN. + Handles receiving and possibly dispatching tensors. 
+ """ + src = (self.a2e_group.rank_in_group - 1) % self.a2e_group.world_size + if self._need_recv_metadata: + self._recv_metadata(src, self.a2e_group) + self._need_recv_metadata = False + + stage_idx = ( + self.recv_attn_output_counter + % self._current_afd_connector_metadata.num_of_stages + ) + layer_idx = ( + self.recv_attn_output_counter + // self._current_afd_connector_metadata.num_of_stages + ) + hidden_states, work_list = self._recv_hidden_states( + src, + self.a2e_group, + self._tensor_metadata_list[stage_idx], + ) + self._current_afd_connector_metadata.recv_handle_list = work_list + self._current_afd_connector_metadata.layer_idx = layer_idx + return hidden_states, self._current_afd_connector_metadata +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import re +from datetime import timedelta + +import torch +from torch.distributed.distributed_c10d import _get_default_group, _update_default_pg + from vllm.config import VllmConfig from vllm.distributed.parallel_state import ( GroupCoordinator, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 7938cff98c354..6b7842b00f54a 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -1670,7 +1670,7 @@ class DeepseekV2ForCausalLM( return hidden_states def compute_ffn_output( - self, current_layer_idx, hidden_states + self, hidden_states, current_layer_idx ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx) return hidden_states diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 7077f1a22e8d7..c012c26f83afa 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -11,12 +11,16 @@ from torch import nn from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.config import AFDConfig, CacheConfig, ModelConfig, VllmConfig from vllm.distributed import ( get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) +from vllm.distributed.afd_transfer.afd_connector.metadata import ( + AFDConnectorMetadata, +) +from vllm.forward_context import AFDMetadata, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -37,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.step3_vl import Step3TextConfig +from vllm.v1.worker.ubatching import dbo_current_ubatch_id, dbo_enabled, dbo_yield from .interfaces import SupportsPP from .utils import ( @@ -228,54 +233,59 @@ class Step3TextDecoderLayer(nn.Module): config: Step3TextConfig, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, + afd_config: AFDConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size + self.afd_role = afd_config.afd_role if afd_config is not None else None - self.self_attn = Step3TextAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=1, - cache_config=cache_config, - 
quant_config=quant_config,
-            norm_eps=config.rms_norm_eps,
-            max_position_embedding=config.max_position_embedding,
-            head_dim=config.head_dim,
-            share_q_dim=config.share_q_dim,
-            rope_parameters=config.rope_parameters,
-            prefix=f"{prefix}.self_attn",
-        )
-
-        layer_idx = int(prefix.split("layers.")[1].split(".")[0])
-        moe_layers_enum = getattr(config, "moe_layers_enum", None)
-        if moe_layers_enum is not None:
-            moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
-        else:
-            # Default to 1dense.
-            moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
-
-        if layer_idx in moe_layers_idx:
-            self.moe = FusedMoEBlock(
-                config=config, quant_config=quant_config, prefix=f"{prefix}.moe"
-            )
-            self.share_expert = Step3TextMLP(
+        if self.afd_role is None or self.afd_role == "attention":
+            self.self_attn = Step3TextAttention(
                 hidden_size=self.hidden_size,
-                intermediate_size=config.share_expert_dim,
-                hidden_act="silu",
+                num_heads=config.num_attention_heads,
+                num_kv_heads=1,
+                cache_config=cache_config,
                 quant_config=quant_config,
-                prefix=f"{prefix}.share_expert",
+                norm_eps=config.rms_norm_eps,
+                max_position_embedding=config.max_position_embedding,
+                head_dim=config.head_dim,
+                share_q_dim=config.share_q_dim,
+                rope_parameters=config.rope_parameters,
+                prefix=f"{prefix}.self_attn",
             )
-            self.use_moe = True
-        else:
-            self.mlp = Step3TextMLP(
-                hidden_size=config.hidden_size,
-                intermediate_size=config.intermediate_size,
-                hidden_act="silu",
-                quant_config=quant_config,
-                prefix=f"{prefix}.mlp",
-            )
-            self.use_moe = False
+
+        self.layer_idx = int(prefix.split("layers.")[1].split(".")[0])
+
+        if self.afd_role is None or self.afd_role == "ffn":
+            moe_layers_enum = getattr(config, "moe_layers_enum", None)
+            if moe_layers_enum is not None:
+                moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
+            else:
+                # Default: only layer 0 is dense; all later layers are MoE.
+                moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
+
+            if self.layer_idx in moe_layers_idx:
+                self.moe = FusedMoEBlock(
+                    config=config, quant_config=quant_config, prefix=f"{prefix}.moe"
+                )
+                self.share_expert = Step3TextMLP(
+                    hidden_size=self.hidden_size,
+                    intermediate_size=config.share_expert_dim,
+                    hidden_act="silu",
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.share_expert",
+                )
+                self.use_moe = True
+            else:
+                self.mlp = Step3TextMLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=config.intermediate_size,
+                    hidden_act="silu",
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.mlp",
+                )
+                self.use_moe = False
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )
@@ -300,6 +310,9 @@ class Step3TextDecoderLayer(nn.Module):
 
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
 
+        if self.afd_role == "attention":
+            return hidden_states, residual
+
         if self.use_moe:
             share_output = self.share_expert(hidden_states)
             moe_output = self.moe(hidden_states)
@@ -309,6 +322,25 @@ class Step3TextDecoderLayer(nn.Module):
 
         return hidden_states, residual
 
+    def compute_attn_output(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+    ):
+        pass
+
+    def compute_ffn_output(self, hidden_states):
+        assert self.afd_role == "ffn"
+        if self.use_moe:
+            share_output = self.share_expert(hidden_states)
+            moe_output = self.moe(hidden_states)
+            hidden_states = share_output + moe_output
+        else:
+            hidden_states = self.mlp(hidden_states)
+        logger.info(f"{type(hidden_states)=}")
+        return hidden_states
+
 
 @support_torch_compile
 class Step3TextModel(nn.Module):
@@ -317,6 +349,8 @@ class Step3TextModel(nn.Module):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        logger.info(f"{quant_config=}")
+        afd_config = vllm_config.afd_config
 
         self.vocab_size = config.vocab_size
         self.config = config
@@ -336,6 +370,7 @@ class Step3TextModel(nn.Module):
                 config=config,
                 cache_config=cache_config,
                 quant_config=quant_config,
+                afd_config=afd_config,
                 prefix=prefix,
             ),
             prefix=f"{prefix}.layers",
@@ -352,6 +387,51 @@ class Step3TextModel(nn.Module):
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
+    def forward_with_afd(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        positions: torch.Tensor,
+        afd_metadata: AFDMetadata,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        recv_handle = None
+        logger.info(f"{__file__}: forward with afd called, may block here")
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
+            afd_connector = afd_metadata.afd_connector
+            afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
+
+            if layer.layer_idx > 0:
+                hidden_states, recv_metadata = afd_connector.recv_ffn_output()
+                if recv_metadata.recv_handle_list is not None:
+                    recv_handle = recv_metadata.recv_handle_list
+
+            if recv_handle is not None:
+                for work in recv_handle:
+                    work.wait()
+            current_hidden, residual = layer(positions, hidden_states, residual)
+            metadata = AFDConnectorMetadata.create_attention_metadata(
+                layer_idx=layer.layer_idx,
+                stage_idx=afd_metadata.afd_stage_idx,
+                seq_len=current_hidden.shape[0],
+                dtype=current_hidden.dtype,
+                device=current_hidden.device,
+                num_of_stages=afd_metadata.num_of_stages,
+                afd_tokens_lens=afd_metadata.afd_tokens_lens,
+            )
+            
afd_connector.send_attn_output(current_hidden, metadata) + + if dbo_enabled(): + dbo_yield() + + hidden_states, recv_metadata = afd_connector.recv_ffn_output() + if recv_metadata.recv_handle_list is not None: + recv_handle = recv_metadata.recv_handle_list + if recv_handle is not None: + for work in recv_handle: + work.wait() + + return hidden_states, residual + def forward( self, input_ids: torch.Tensor, @@ -370,8 +450,19 @@ class Step3TextModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual) + forward_ctx = get_forward_context() + afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None + + if afd_metadata is not None: + hidden_states, residual = self.forward_with_afd( + hidden_states, + residual, + positions, + afd_metadata, + ) + else: + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -384,6 +475,15 @@ class Step3TextModel(nn.Module): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def compute_ffn_output( + self, + hidden_states, + layer_idx, + ) -> torch.Tensor | IntermediateTensors: + logger.info(f"{type(self.layers)=}, {type(layer_idx)=}") + hidden_states = self.layers[layer_idx].compute_ffn_output(hidden_states) + return hidden_states + class Step3TextForCausalLM(nn.Module, SupportsPP): def __init__( @@ -398,6 +498,11 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): self.config = config self.vllm_config = vllm_config + self.afd_config = vllm_config.afd_config + self.afd_role = ( + self.afd_config.afd_role if self.afd_config is not None else None + ) + self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix) if get_pp_group().is_last_rank: @@ -429,11 +534,20 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): ) return hidden_states + def compute_ffn_output( + self, + hidden_states, + current_layer_idx, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx) + return hidden_states + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + logger.info(f"{__file__}: load_weights!") qkv_params_mapping = [ # (param_name, shard_name, relative_start_idx, relative_end_idx) ( @@ -466,6 +580,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) + logger.info(f"{params_dict.keys()=}") loaded_params: set[str] = set() expert_params_mapping = [ @@ -477,9 +592,17 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): disable_moe_stacked_params = [data[1] for data in expert_params_mapping] for name, loaded_weight in weights: + logger.info( + f"{self.afd_role=}, {name=}, is_moe: {self.is_moe_weight(name)}, " + f"is_common: {self.is_common_weight(name)}" + ) + if self.afd_role == "attention" and self.is_moe_weight(name): + continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue + if any( disable_moe_stacked_param in name for disable_moe_stacked_param in disable_moe_stacked_params @@ -498,6 +621,10 @@ class 
Step3TextForCausalLM(nn.Module, SupportsPP): param_name, weight_name, shard_id = mapping if weight_name not in name: continue + + if self.afd_role is not None and self.afd_role == "attention": + continue + name = name.replace(weight_name, param_name) # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -521,12 +648,19 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): loaded_params.add(name) break else: + if ( + self.afd_role == "ffn" + and not self.is_moe_weight(name) + and not self.is_common_weight(name) + ): + continue for ( param_name, weight_name, start_idx, end_idx, ) in qkv_params_mapping: + logger.info(f"{weight_name=}, {name=}") if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -552,3 +686,25 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params + + def is_moe_weight(self, name): + if ( + "shared_expert" in name + or "experts" in name + or "gate" in name + or "up" in name + or "down" in name + ): + return True + return False + + def is_common_weight(self, name): + if ( + "lm_head" in name + or "model.norm.weight" in name + or "embed" in name + or "input_layernorm" in name + or "post_attention_layernorm" in name + ): + return True + return False diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index e5038e56a2708..e16cb53e9f194 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -1126,6 +1126,16 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) return hidden_states + def compute_ffn_output( + self, + hidden_states, + current_layer_idx, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.language_model.compute_ffn_output( + hidden_states, current_layer_idx + ) + return hidden_states + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 82dd80d267182..fff775a9f8241 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -642,6 +642,25 @@ class GPUModelRunner( with_stack=False, ) + profile_dir = ( + "./profiler_logs/attn" + if self.afd_config is not None and self.afd_config.afd_role == "attention" + else "./profiler_logs/normal" + ) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=6000 + 4000, warmup=1, active=30, repeat=1 + ), + on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir), + record_shapes=True, + profile_memory=False, + with_stack=False, + ) + def reset_mm_cache(self) -> None: if self.mm_budget: self.mm_budget.reset_cache() @@ -2969,6 +2988,38 @@ class GPUModelRunner( ) return afd_metadata + def _build_afd_metadata( + self, ubatch_slices: UBatchSlices | None, num_tokens_unpadded: int + ): + afd_metadata = None + if self.afd_config: + # For prefill, compute tokens per stage based on actual token + # counts + afd_tokens_start_loc = [0] + afd_tokens_lens = [] + if ubatch_slices and len(ubatch_slices) > 1: + afd_tokens_start_loc = [ub.token_slice.start for ub in ubatch_slices] + afd_reqs_start_loc = [ub.request_slice.start for ub in ubatch_slices] + logger.info( + f"afd_tokens_start_loc: {afd_tokens_start_loc} " + f"afd_reqs_start_loc: {afd_reqs_start_loc} " + f"ubatch_slices: {ubatch_slices}" + ) + afd_tokens_lens = 
[ub.num_tokens for ub in ubatch_slices] + else: + afd_tokens_start_loc = [0] + afd_reqs_start_loc = [0] + afd_tokens_lens = [num_tokens_unpadded] + afd_metadata = AFDMetadata( + afd_tokens_start_loc=afd_tokens_start_loc, + afd_reqs_start_loc=afd_reqs_start_loc, + afd_stage_idx=0, + afd_connector=self.afd_connector, + afd_tokens_lens=afd_tokens_lens, + num_of_stages=len(ubatch_slices) if ubatch_slices else 1, + ) + return afd_metadata + @torch.inference_mode() def execute_model( self, @@ -5517,6 +5568,11 @@ class GPUModelRunner( if hasattr(self, "afd_connector") and self.afd_connector: self.afd_connector.init_afd_connector() + def initialize_afd_connector(self) -> None: + """Initialize AFD connector if available.""" + if hasattr(self, "afd_connector") and self.afd_connector: + self.afd_connector.init_afd_connector() + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 9e17c718c5513..8a1c9d90abbbf 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -127,6 +127,10 @@ class UBatchWrapper: comm_sms: int = envs.VLLM_DBO_COMM_SMS set_comm_sms = lambda sms: None + if ( + vllm_config.parallel_config.enable_expert_parallel + and not vllm_config.afd_config + ): if ( vllm_config.parallel_config.enable_expert_parallel and not vllm_config.afd_config From cd16bcff1ea7373d706c69abe8976609d9854aa8 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Thu, 18 Dec 2025 15:56:20 +0800 Subject: [PATCH 02/13] [Chore] resolve some bugs due to merge --- vllm/v1/worker/gpu_ffn_model_runner.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 5 +++-- vllm/v1/worker/gpu_ubatch_wrapper.py | 4 ---- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/vllm/v1/worker/gpu_ffn_model_runner.py b/vllm/v1/worker/gpu_ffn_model_runner.py index cb08c9c05ae58..cd9940ef5e7a7 100644 --- a/vllm/v1/worker/gpu_ffn_model_runner.py +++ b/vllm/v1/worker/gpu_ffn_model_runner.py @@ -220,7 +220,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): hidden_states, dim=0 ) ffn_output = self.model.compute_ffn_output( - current_layer_idx, gathered_hidden_states + gathered_hidden_states, current_layer_idx ) # Extract the output corresponding to current rank start_idx = hidden_states.shape[0] * get_tensor_model_parallel_rank() @@ -229,7 +229,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): else: # Single TP case rank_ffn_output = self.model.compute_ffn_output( - current_layer_idx, hidden_states + hidden_states, current_layer_idx ) return rank_ffn_output diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fff775a9f8241..ed6b9cc98f3a2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3211,8 +3211,9 @@ class GPUModelRunner( record_function_or_nullcontext("gpu_model_runner: forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, ): - logger.info(f"input_ids: {input_ids.shape}") - if inputs_embeds: + if input_ids is not None: + logger.info(f"input_ids: {input_ids.shape}") + if inputs_embeds is not None: logger.info(f"inputs_embeds: {inputs_embeds.shape}") model_output = self._model_forward( input_ids=input_ids, diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 8a1c9d90abbbf..9e17c718c5513 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -127,10 
+127,6 @@ class UBatchWrapper: comm_sms: int = envs.VLLM_DBO_COMM_SMS set_comm_sms = lambda sms: None - if ( - vllm_config.parallel_config.enable_expert_parallel - and not vllm_config.afd_config - ): if ( vllm_config.parallel_config.enable_expert_parallel and not vllm_config.afd_config From f74bb82909d5a12749fc9aeeb676b52bc7d767be Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Thu, 18 Dec 2025 15:56:43 +0800 Subject: [PATCH 03/13] [Chore] code lint --- vllm/v1/worker/gpu_ffn_model_runner.py | 6 ++++-- vllm/v1/worker/gpu_model_runner.py | 8 ++++++-- vllm/v1/worker/gpu_ubatch_wrapper.py | 9 ++++++--- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/vllm/v1/worker/gpu_ffn_model_runner.py b/vllm/v1/worker/gpu_ffn_model_runner.py index cd9940ef5e7a7..3f85267d8536a 100644 --- a/vllm/v1/worker/gpu_ffn_model_runner.py +++ b/vllm/v1/worker/gpu_ffn_model_runner.py @@ -130,8 +130,10 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): try: hidden_states, recv_metadata = self.connector.recv_attn_output() - if hasattr(self.connector, 'dp_metadata_list'): - dp_metadata = self.connector.dp_metadata_list.get(recv_metadata.stage_idx, None) + if hasattr(self.connector, "dp_metadata_list"): + dp_metadata = self.connector.dp_metadata_list.get( + recv_metadata.stage_idx, None + ) else: dp_metadata = None current_layer_idx = recv_metadata.layer_idx diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ed6b9cc98f3a2..a09f292d98e14 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3192,7 +3192,9 @@ class GPUModelRunner( # Mark KV scales as calculated after the first forward pass self.calculate_kv_scales = False - afd_metadata = self._build_afd_metadata(ubatch_slices_padded, num_tokens_unpadded) + afd_metadata = self._build_afd_metadata( + ubatch_slices_padded, num_tokens_unpadded + ) self.profiler.step() # Run the model. @@ -4326,7 +4328,9 @@ class GPUModelRunner( if num_tokens_across_dp is not None: num_tokens_across_dp[:] = num_tokens_padded - afd_metadata = self._build_afd_metadata(ubatch_slices_padded, num_tokens_unpadded) + afd_metadata = self._build_afd_metadata( + ubatch_slices_padded, num_tokens_unpadded + ) with ( self.maybe_randomize_inputs(input_ids, inputs_embeds), diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 9e17c718c5513..7a44b6cbd42b9 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -405,9 +405,12 @@ class UBatchWrapper: afd_metadata.input_ids_list.append(sliced_input_ids) afd_metadata.positions_list.append(sliced_positions) afd_metadata.inputs_embeds_list.append(sliced_inputs_embeds) - afd_metadata.intermediate_tensors_list.append(sliced_intermediate_tensors) + afd_metadata.intermediate_tensors_list.append( + sliced_intermediate_tensors + ) afd_metadata.attn_metadata_list.append( - attn_metadata[i] if attn_metadata is not None else None) + attn_metadata[i] if attn_metadata is not None else None + ) afd_metadata.dp_metadata_list.append(ubatch_dp_metadata) return afd_metadata @@ -481,7 +484,7 @@ class UBatchWrapper: # num_tokens, we don't have a non-ubatched one. Without this # check, the cudagraph wrapper will try to capture a cudagraph # for this shape during a normal run. 
- + if cudagraph_runtime_mode is CUDAGraphMode.FULL: assert batch_descriptor is not None if batch_descriptor.num_tokens in self.cudagraphs: From 26ddfa299cc2835ab5ac9d0e6f3ddba3b85e5db0 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Thu, 18 Dec 2025 17:02:39 +0800 Subject: [PATCH 04/13] [Chore] remove duplicate code --- vllm/v1/worker/gpu_model_runner.py | 56 ------------------------------ 1 file changed, 56 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a09f292d98e14..2409c9071f94b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -642,25 +642,6 @@ class GPUModelRunner( with_stack=False, ) - profile_dir = ( - "./profiler_logs/attn" - if self.afd_config is not None and self.afd_config.afd_role == "attention" - else "./profiler_logs/normal" - ) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - schedule=torch.profiler.schedule( - wait=6000 + 4000, warmup=1, active=30, repeat=1 - ), - on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir), - record_shapes=True, - profile_memory=False, - with_stack=False, - ) - def reset_mm_cache(self) -> None: if self.mm_budget: self.mm_budget.reset_cache() @@ -2988,38 +2969,6 @@ class GPUModelRunner( ) return afd_metadata - def _build_afd_metadata( - self, ubatch_slices: UBatchSlices | None, num_tokens_unpadded: int - ): - afd_metadata = None - if self.afd_config: - # For prefill, compute tokens per stage based on actual token - # counts - afd_tokens_start_loc = [0] - afd_tokens_lens = [] - if ubatch_slices and len(ubatch_slices) > 1: - afd_tokens_start_loc = [ub.token_slice.start for ub in ubatch_slices] - afd_reqs_start_loc = [ub.request_slice.start for ub in ubatch_slices] - logger.info( - f"afd_tokens_start_loc: {afd_tokens_start_loc} " - f"afd_reqs_start_loc: {afd_reqs_start_loc} " - f"ubatch_slices: {ubatch_slices}" - ) - afd_tokens_lens = [ub.num_tokens for ub in ubatch_slices] - else: - afd_tokens_start_loc = [0] - afd_reqs_start_loc = [0] - afd_tokens_lens = [num_tokens_unpadded] - afd_metadata = AFDMetadata( - afd_tokens_start_loc=afd_tokens_start_loc, - afd_reqs_start_loc=afd_reqs_start_loc, - afd_stage_idx=0, - afd_connector=self.afd_connector, - afd_tokens_lens=afd_tokens_lens, - num_of_stages=len(ubatch_slices) if ubatch_slices else 1, - ) - return afd_metadata - @torch.inference_mode() def execute_model( self, @@ -5573,11 +5522,6 @@ class GPUModelRunner( if hasattr(self, "afd_connector") and self.afd_connector: self.afd_connector.init_afd_connector() - def initialize_afd_connector(self) -> None: - """Initialize AFD connector if available.""" - if hasattr(self, "afd_connector") and self.afd_connector: - self.afd_connector.init_afd_connector() - def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. 
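(Aside on the per-layer AFD flow that patches 01-04 wire up: the FFN side runs a
serve loop driven purely by connector metadata, while the attention side pushes
one layer's output at a time. A minimal sketch of the FFN-side loop, assuming
`connector` is an initialized P2PAFDConnector and `model` an AFD-capable model
such as Step3TextForCausalLM -- illustrative only; the real loop lives in
gpu_ffn_model_runner.py:

    # FFN worker: receive attention output, run this layer's MoE/MLP, send back.
    while True:
        hidden_states, meta = connector.recv_attn_output()   # blocks until ATTN sends
        ffn_out = model.compute_ffn_output(hidden_states, meta.layer_idx)
        connector.send_ffn_output(ffn_out, meta)             # round-trips to the sender

Note the argument order of compute_ffn_output: hidden states first, layer index
second; the call-site fixes in the patches below align with this convention.)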
From 8276320a8a691adf8747d4d93b12e4c2de208e2d Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Thu, 18 Dec 2025 17:03:15 +0800 Subject: [PATCH 05/13] [Bugfix] compute ffn output param order --- vllm/v1/worker/gpu_ffn_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_ffn_model_runner.py b/vllm/v1/worker/gpu_ffn_model_runner.py index 3f85267d8536a..921824f408552 100644 --- a/vllm/v1/worker/gpu_ffn_model_runner.py +++ b/vllm/v1/worker/gpu_ffn_model_runner.py @@ -351,7 +351,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): hidden_states, dim=0 ) ffn_output = self.model.compute_ffn_output( - current_layer_idx, gathered_hidden_states + gathered_hidden_states, current_layer_idx ) # Extract the output corresponding to current rank @@ -361,7 +361,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): else: # Single TP case rank_ffn_output = self.model.compute_ffn_output( - current_layer_idx, hidden_states + hidden_states, current_layer_idx ) return rank_ffn_output From 6a8d35a9b6c36136c56c470e3bfbfa51a2c44c15 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Thu, 18 Dec 2025 17:32:42 +0800 Subject: [PATCH 06/13] [Chore] remove p2p connector duplicate code --- .../afd_connector/p2p_connector.py | 321 ------------------ 1 file changed, 321 deletions(-) diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py index 3c3359fae96f5..dcb781c55540e 100644 --- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -308,324 +308,3 @@ class P2PAFDConnector(AFDConnectorBase): self._current_afd_connector_metadata.recv_handle_list = work_list self._current_afd_connector_metadata.layer_idx = layer_idx return hidden_states, self._current_afd_connector_metadata -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import re -from datetime import timedelta - -import torch -from torch.distributed.distributed_c10d import _get_default_group, _update_default_pg - -from vllm.config import VllmConfig -from vllm.distributed.parallel_state import ( - GroupCoordinator, - TensorMetadata, - init_afd_process_group, - init_model_parallel_group, -) -from vllm.logger import init_logger -from vllm.forward_context import ( - DPMetadata, - get_forward_context, -) - -from .base import AFDConnectorBase -from .metadata import AFDConnectorMetadata - -logger = init_logger(__name__) - - -class DefaultProcessGroupSwitcher: - def __init__(self, default_group, new_default_group): - self.default_group = default_group - self.new_default_group = new_default_group - - def __enter__(self): - _update_default_pg(self.new_default_group) - - def __exit__(self, exc_type, exc_value, traceback): - _update_default_pg(self.default_group) - - -class P2PAFDConnector(AFDConnectorBase): - def __init__( - self, - rank: int, - local_rank: int, - config: "VllmConfig", - ) -> None: - self.rank = rank - self.local_rank = local_rank - self.config = config - self._initialized: bool = False - self._need_recv_metadata: bool = True - self._tensor_metadata_list: dict[int, TensorMetadata] = {} - self._current_afd_connector_metadata: AFDConnectorMetadata | None = None - self.num_hidden_layers: int = ( - self.config.model_config.hf_config.num_hidden_layers - ) - self.recv_attn_output_counter: int = 0 - self.recv_ffn_output_counter: int = 0 - self.dp_metadata_list: dict[int, DPMetadata] = {} - - def close(self) 
-> None: - """Close the connector and release resources.""" - # TODO: Implement proper resource clean up if needed. - pass - - def init_afd_connector(self) -> None: - """Initialize the AFD connector.""" - afd_size = self.config.afd_config.afd_extra_config.get("afd_size") - role = self.config.afd_config.afd_role - attn_size, ffn_size = map(int, re.match(r"(\d+)\D+(\d+)", afd_size).groups()) - world_rank = self.rank if role == "attention" else self.rank + attn_size - afd_pg = init_afd_process_group( - backend="nccl", - init_method=( - f"tcp://{self.config.afd_config.afd_host}" - f":{self.config.afd_config.afd_port}" - ), - world_size=ffn_size + attn_size, - rank=world_rank, - group_name="afd", - timeout=timedelta(minutes=2), - ) - - # Construct rank lists for sub groups. - # Each group contains one attention and one ffn rank. - ffn_ranks = [i for i in range(ffn_size, ffn_size + attn_size)] - attn_ranks = [i for i in range(attn_size)] - assert len(ffn_ranks) == len(attn_ranks), ( - "ffn_ranks and attn_ranks must have the same length" - ) - default_pg_switcher = DefaultProcessGroupSwitcher(_get_default_group(), afd_pg) - with default_pg_switcher: - sub_group_ranks = [] - for i in range(len(ffn_ranks)): - ranks = [attn_ranks[i], ffn_ranks[i]] - sub_group_ranks.append(ranks) - # Create two independent groups: - # a2e_group: for attention -> expert/ffn communication (send_attn, recv_attn) - # e2a_group: for expert/ffn -> attention communication (send_ffn, recv_ffn) - # The communication domain (rank range) is the same, but different group_name - # creates independent groups. - self.a2e_group = init_model_parallel_group( - sub_group_ranks, - self.local_rank, - backend="nccl", - group_name="a2e", - ) - self.e2a_group = init_model_parallel_group( - sub_group_ranks, - self.local_rank, - backend="nccl", - group_name="e2a", - ) - - self._initialized = True - - def is_initialized(self) -> bool: - """Check if the connector is initialized and ready to use. - - Returns: - bool: True if the connector is initialized, False otherwise. 
- """ - return self._initialized - - def _build_tensor_metadata_list( - self, - tensor_metadata: TensorMetadata, - connector_metadata: AFDConnectorMetadata, - ) -> dict[int, TensorMetadata]: - tensor_metadata_list = {} - num_of_stages = connector_metadata.num_of_stages - for idx in range(num_of_stages): - if idx == 0: - tensor_metadata_list[0] = tensor_metadata - else: - new_size = list(tensor_metadata.size) - new_size[0] = connector_metadata.afd_tokens_lens[idx] - tensor_metadata_list[idx] = TensorMetadata( - tensor_metadata.device, - tensor_metadata.dtype, - torch.Size(new_size), - ) - return tensor_metadata_list - - def _send_metadata( - self, - metadata: AFDConnectorMetadata, - hidden_states: torch.Tensor, - dst: int, - process_group: GroupCoordinator, - ) -> None: - if not torch.distributed.is_initialized() or process_group.world_size == 1: - return [] - assert dst < process_group.world_size, f"Invalid dst rank ({dst})" - - tensor_metadata = TensorMetadata( - hidden_states.device.type, hidden_states.dtype, hidden_states.size() - ) - metadata_tuple = (metadata, tensor_metadata) - process_group.send_object(metadata_tuple, dst=dst) - self._tensor_metadata_list = self._build_tensor_metadata_list( - tensor_metadata, metadata - ) - - def _recv_metadata( - self, - src: int, - process_group: GroupCoordinator, - ) -> None: - (self._current_afd_connector_metadata, tensor_metadata) = ( - process_group.recv_object(src=src) - ) - self._tensor_metadata_list = self._build_tensor_metadata_list( - tensor_metadata, self._current_afd_connector_metadata - ) - if self.config.parallel_config.data_parallel_size > 1: - logger.info("jcz recv_metadata num_of_stages:{}".format(self._current_afd_connector_metadata.num_of_stages)) - for stage_idx in range(self._current_afd_connector_metadata.num_of_stages): - num_tokens_per_ubatch = self._tensor_metadata_list[stage_idx].size[0] - self.dp_metadata_list[stage_idx] = DPMetadata.make( - self.config.parallel_config, - num_tokens_per_ubatch, - torch.tensor([num_tokens_per_ubatch] * self.config.parallel_config.data_parallel_size, - device="cpu", dtype=torch.int32), - ) - logger.info("jcz recv_metadata self.dp_metadata_list:{}".format(self.dp_metadata_list)) - - def _send_hidden_states( - self, - hidden_states: torch.Tensor, - dst: int, - process_group: GroupCoordinator, - ) -> None: - if not torch.distributed.is_initialized() or process_group.world_size == 1: - return [] - assert dst < process_group.world_size, f"Invalid dst rank ({dst})" - assert not hidden_states.is_cpu, "Hidden states must be on GPU" - torch.distributed.send( - hidden_states, - dst=process_group.ranks[dst], - group=process_group.device_group, - ) - - def _recv_hidden_states( - self, - src: int, - process_group: GroupCoordinator, - tensor_metadata: TensorMetadata, - ) -> tuple[torch.Tensor, list]: - if not torch.distributed.is_initialized() or process_group.world_size == 1: - return {}, [] - assert src < process_group.world_size, f"Invalid src rank ({src})" - - hidden_states = torch.empty( - tensor_metadata.size, - dtype=tensor_metadata.dtype, - device=tensor_metadata.device, - ) - torch.distributed.recv( - hidden_states, - src=process_group.ranks[src], - group=process_group.device_group, - ) - return hidden_states, [] - - # ------------------------------------------------------------------------- - # attn -> ffn - # ------------------------------------------------------------------------- - - def send_attn_output( - self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata - ) -> None: - """ 
- Called by ATTN side to send intermediate tensors - generated by ATTN instances to FFN. - """ - try: - dst = (self.a2e_group.rank_in_group + 1) % self.a2e_group.world_size - if metadata.layer_idx == 0 and metadata.stage_idx == 0: - self._send_metadata(metadata, hidden_states, dst, self.a2e_group) - self._current_afd_connector_metadata = metadata - self._send_hidden_states(hidden_states, dst, self.a2e_group) - except Exception as e: - raise RuntimeError(f"Communication error: {e}") - - def recv_ffn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: - """ - Called by the ATTN side to receive MOE output intermediate tensors, - possibly dispatching from the receiver to other GPUs. - """ - src = (self.e2a_group.rank_in_group - 1) % self.e2a_group.world_size - stage_idx = ( - self.recv_ffn_output_counter - % self._current_afd_connector_metadata.num_of_stages - ) - hidden_states, work_list = self._recv_hidden_states( - src, - self.e2a_group, - self._tensor_metadata_list[stage_idx], - ) - self._current_afd_connector_metadata.recv_handle_list = work_list - self.recv_ffn_output_counter = ( - self.recv_ffn_output_counter + 1 - ) % self._current_afd_connector_metadata.num_of_stages - return hidden_states, self._current_afd_connector_metadata - - # ------------------------------------------------------------------------- - # ffn -> attn - # ------------------------------------------------------------------------- - - def send_ffn_output( - self, - hidden_states: torch.Tensor, - metadata: AFDConnectorMetadata, - ) -> None: - """ - Called by FFN side to send intermediate tensors generated by FFN - instances back to the sender (should be the same GPU as source). - """ - dst = (self.e2a_group.rank_in_group + 1) % self.e2a_group.world_size - self._send_hidden_states(hidden_states, dst, self.e2a_group) - self.recv_attn_output_counter += 1 - if ( - self.recv_attn_output_counter - % ( - self._current_afd_connector_metadata.num_of_stages - * self.num_hidden_layers - ) - == 0 - ): - self._need_recv_metadata = True - self.recv_attn_output_counter = 0 - - def recv_attn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]: - """ - Called by the FFN side to receive intermediate tensors from ATTN. - Handles receiving and possibly dispatching tensors. 
- """ - src = (self.a2e_group.rank_in_group - 1) % self.a2e_group.world_size - if self._need_recv_metadata: - self._recv_metadata(src, self.a2e_group) - self._need_recv_metadata = False - - stage_idx = ( - self.recv_attn_output_counter - % self._current_afd_connector_metadata.num_of_stages - ) - layer_idx = ( - self.recv_attn_output_counter - // self._current_afd_connector_metadata.num_of_stages - ) - hidden_states, work_list = self._recv_hidden_states( - src, - self.a2e_group, - self._tensor_metadata_list[stage_idx], - ) - self._current_afd_connector_metadata.recv_handle_list = work_list - self._current_afd_connector_metadata.layer_idx = layer_idx - self._current_afd_connector_metadata.stage_idx = stage_idx - return hidden_states, self._current_afd_connector_metadata From 11d7d5bf594c3ef743fef4b496705d5e39290034 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Fri, 19 Dec 2025 16:02:07 +0800 Subject: [PATCH 07/13] [Chore] some log info --- vllm/model_executor/models/step3_text.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index c012c26f83afa..6355573a68774 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -303,6 +303,10 @@ class Step3TextDecoderLayer(nn.Module): else: hidden_states, residual = self.input_layernorm(hidden_states, residual) + # query, key and positions must have the same number of tokens + # /model_executor/layers/rotary_embedding/base.py + # positions.shape=torch.Size([8192]), hidden_states.shape=torch.Size([4096, 3712]) + logger.info(f"{positions.shape=}, {hidden_states.shape=}") hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -408,7 +412,9 @@ class Step3TextModel(nn.Module): if recv_handle is not None: for work in recv_handle: work.wait() + logger.info(f"Step3TextModel {layer.layer_idx=}: {hidden_states.shape=}, {positions.shape=}") current_hidden, residual = layer(positions, hidden_states, residual) + logger.info(f"create attn metadata: {current_hidden.shape=}") metadata = AFDConnectorMetadata.create_attention_metadata( layer_idx=layer.layer_idx, stage_idx=afd_metadata.afd_stage_idx, @@ -580,7 +586,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) - logger.info(f"{params_dict.keys()=}") + # logger.info(f"{params_dict.keys()=}") loaded_params: set[str] = set() expert_params_mapping = [ @@ -592,10 +598,10 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): disable_moe_stacked_params = [data[1] for data in expert_params_mapping] for name, loaded_weight in weights: - logger.info( - f"{self.afd_role=}, {name=}, is_moe: {self.is_moe_weight(name)}, " - f"is_common: {self.is_common_weight(name)}" - ) + # logger.info( + # f"{self.afd_role=}, {name=}, is_moe: {self.is_moe_weight(name)}, " + # f"is_common: {self.is_common_weight(name)}" + # ) if self.afd_role == "attention" and self.is_moe_weight(name): continue @@ -660,7 +666,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): start_idx, end_idx, ) in qkv_params_mapping: - logger.info(f"{weight_name=}, {name=}") + # logger.info(f"{weight_name=}, {name=}") if weight_name not in name: continue name = name.replace(weight_name, param_name) From 65ea10c8f453e5294370b5f98d08280f49936602 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Fri, 19 Dec 2025 16:03:47 +0800 Subject: [PATCH 08/13] [Chore] bring back deleted code --- 
.../afd_connector/p2p_connector.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py index dcb781c55540e..58be99bf117ea 100644 --- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -15,6 +15,10 @@ from vllm.distributed.parallel_state import ( init_model_parallel_group, ) from vllm.logger import init_logger +from vllm.forward_context import ( + DPMetadata, + get_forward_context, +) from .base import AFDConnectorBase from .metadata import AFDConnectorMetadata @@ -59,6 +63,7 @@ class P2PAFDConnector(AFDConnectorBase): self.recv_attn_output_counter: int = 0 self.recv_ffn_output_counter: int = 0 + self.dp_metadata_list: dict[int, DPMetadata] = {} def close(self) -> None: """Close the connector and release resources.""" @@ -175,6 +180,19 @@ class P2PAFDConnector(AFDConnectorBase): self._tensor_metadata_list = self._build_tensor_metadata_list( tensor_metadata, self._current_afd_connector_metadata ) + logger.info(f"{self.config.parallel_config.data_parallel_size=}") + if self.config.parallel_config.data_parallel_size > 1: + logger.info("jcz recv_metadata num_of_stages:{}".format(self._current_afd_connector_metadata.num_of_stages)) + for stage_idx in range(self._current_afd_connector_metadata.num_of_stages): + num_tokens_per_ubatch = self._tensor_metadata_list[stage_idx].size[0] + logger.info(f"{stage_idx=}, {num_tokens_per_ubatch=}") + self.dp_metadata_list[stage_idx] = DPMetadata.make( + self.config.parallel_config, + num_tokens_per_ubatch, + torch.tensor([num_tokens_per_ubatch] * self.config.parallel_config.data_parallel_size, + device="cpu", dtype=torch.int32), + ) + logger.info("jcz recv_metadata self.dp_metadata_list:{}".format(self.dp_metadata_list)) def _send_hidden_states( self, @@ -307,4 +325,5 @@ class P2PAFDConnector(AFDConnectorBase): ) self._current_afd_connector_metadata.recv_handle_list = work_list self._current_afd_connector_metadata.layer_idx = layer_idx + self._current_afd_connector_metadata.stage_idx = stage_idx return hidden_states, self._current_afd_connector_metadata From bde36017fa2324f26e11f65602125df282899fd5 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Fri, 19 Dec 2025 16:04:47 +0800 Subject: [PATCH 09/13] [Chore] adjust log info --- vllm/v1/worker/gpu_ffn_model_runner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/worker/gpu_ffn_model_runner.py b/vllm/v1/worker/gpu_ffn_model_runner.py index 921824f408552..79ba7e6eb5c2e 100644 --- a/vllm/v1/worker/gpu_ffn_model_runner.py +++ b/vllm/v1/worker/gpu_ffn_model_runner.py @@ -137,10 +137,10 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin): else: dp_metadata = None current_layer_idx = recv_metadata.layer_idx - logger.info( - f"layer {current_layer_idx} moe recv hidden states type:{type(hidden_states)}, shape:{hidden_states.shape}" - f" dp_metadata: {dp_metadata}" - ) + # logger.info( + # f"layer {current_layer_idx} moe recv hidden states type:{type(hidden_states)}, shape:{hidden_states.shape}" + # f" dp_metadata: {dp_metadata}" + # ) num_tokens = hidden_states.shape[0] if recv_metadata is not None and recv_metadata.recv_handle_list is not None: for work in recv_metadata.recv_handle_list: From 6d305dda383ee6107125a7605fcefbd7d80e84c7 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Fri, 19 Dec 2025 16:11:47 +0800 Subject: [PATCH 10/13] [Chore] add p2p connector 
debug log info
---
 .../distributed/afd_transfer/afd_connector/p2p_connector.py  | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
index 58be99bf117ea..0605facbfaffb 100644
--- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
+++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
@@ -139,9 +139,11 @@ class P2PAFDConnector(AFDConnectorBase):
         for idx in range(num_of_stages):
             if idx == 0:
                 tensor_metadata_list[0] = tensor_metadata
+                logger.info(f"build tensor metadata: stage_{idx=}, size={tensor_metadata.size}")
             else:
                 new_size = list(tensor_metadata.size)
                 new_size[0] = connector_metadata.afd_tokens_lens[idx]
+                logger.info(f"build tensor metadata: stage_{idx=}, {new_size=}, {connector_metadata.afd_tokens_lens=}")
                 tensor_metadata_list[idx] = TensorMetadata(
                     tensor_metadata.device,
                     tensor_metadata.dtype,
@@ -165,6 +167,7 @@ class P2PAFDConnector(AFDConnectorBase):
         )
         metadata_tuple = (metadata, tensor_metadata)
         process_group.send_object(metadata_tuple, dst=dst)
+        logger.info(f"_send_metadata called build tensor metadata")
         self._tensor_metadata_list = self._build_tensor_metadata_list(
             tensor_metadata, metadata
         )
@@ -177,6 +180,7 @@ class P2PAFDConnector(AFDConnectorBase):
         (self._current_afd_connector_metadata, tensor_metadata) = (
             process_group.recv_object(src=src)
         )
+        logger.info(f"_recv_metadata called build tensor metadata")
         self._tensor_metadata_list = self._build_tensor_metadata_list(
             tensor_metadata, self._current_afd_connector_metadata
         )
@@ -225,6 +229,7 @@ class P2PAFDConnector(AFDConnectorBase):
             dtype=tensor_metadata.dtype,
             device=tensor_metadata.device,
         )
+        # logger.info(f"{__file__}: p2p recv hidden states: {hidden_states.shape=}, {tensor_metadata.size=}")
         torch.distributed.recv(
             hidden_states,
             src=process_group.ranks[src],
@@ -262,6 +267,7 @@ class P2PAFDConnector(AFDConnectorBase):
             self.recv_ffn_output_counter
             % self._current_afd_connector_metadata.num_of_stages
         )
+        logger.info(f"{stage_idx=}")
         hidden_states, work_list = self._recv_hidden_states(
             src,
             self.e2a_group,

From 2a98ab3c8ea831c6a202225644edd671051e43ec Mon Sep 17 00:00:00 2001
From: i-yuanyukun
Date: Mon, 22 Dec 2025 14:29:03 +0800
Subject: [PATCH 11/13] [Chore]: step3 forward_with_afd

---
 vllm/model_executor/models/step3_text.py | 91 ++++++++++++++++--------
 1 file changed, 60 insertions(+), 31 deletions(-)

diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py
index 6355573a68774..75a38805ef5aa 100644
--- a/vllm/model_executor/models/step3_text.py
+++ b/vllm/model_executor/models/step3_text.py
@@ -398,43 +398,72 @@ class Step3TextModel(nn.Module):
         positions: torch.Tensor,
         afd_metadata: AFDMetadata,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        forward_context = get_forward_context()
         recv_handle = None
-        logger.info(f"{__file__}: forward with afd called, may block here")
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
-            afd_connector = afd_metadata.afd_connector
-            afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
 
-            if layer.layer_idx > 0:
-                hidden_states, recv_metadata = afd_connector.recv_ffn_output()
-                if recv_metadata.recv_handle_list is not None:
-                    recv_handle = recv_metadata.recv_handle_list
+        ubatch_hidden_states = []
+        ubatch_residual = []
 
-            if recv_handle is not None:
-                for work in recv_handle:
-                    work.wait()
+        start_idx = 0
+        for pos in afd_metadata.positions_list:
+            num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0]
+            end_idx = start_idx + num_tokens
+            ubatch_hidden_states.append(hidden_states[start_idx:end_idx])
+            ubatch_residual.append(
+                residual[start_idx:end_idx] if residual is not None else None
+            )
+            start_idx = end_idx
 
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
+            for stage_i in range(forward_context.afd_metadata.num_of_stages):
+                afd_connector = afd_metadata.afd_connector
+                forward_context.attn_metadata = afd_metadata.attn_metadata_list[stage_i]
+                forward_context.dp_metadata = afd_metadata.dp_metadata_list[stage_i]
+
+                residual = ubatch_residual[stage_i]
+
+                if layer.layer_idx > 0:
+                    hidden_states, recv_metadata = afd_connector.recv_ffn_output()
+                    if recv_metadata.recv_handle_list is not None:
+                        recv_handle = recv_metadata.recv_handle_list
+                else:
+                    hidden_states = ubatch_hidden_states[stage_i]
+
+                if recv_handle is not None:
+                    for work in recv_handle:
+                        work.wait()
-            logger.info(f"Step3TextModel {layer.layer_idx=}: {hidden_states.shape=}, {positions.shape=}")
From 2a98ab3c8ea831c6a202225644edd671051e43ec Mon Sep 17 00:00:00 2001
From: i-yuanyukun
Date: Mon, 22 Dec 2025 14:29:03 +0800
Subject: [PATCH 11/13] [Chore]: step3 forward_with_afd

---
 vllm/model_executor/models/step3_text.py | 91 ++++++++++++++++--------
 1 file changed, 60 insertions(+), 31 deletions(-)

diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py
index 6355573a68774..75a38805ef5aa 100644
--- a/vllm/model_executor/models/step3_text.py
+++ b/vllm/model_executor/models/step3_text.py
@@ -398,43 +398,72 @@ class Step3TextModel(nn.Module):
         positions: torch.Tensor,
         afd_metadata: AFDMetadata,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        forward_context = get_forward_context()
         recv_handle = None
-        logger.info(f"{__file__}: forward with afd called, may blocked here")
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
-            afd_connector = afd_metadata.afd_connector
-            afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
-            if layer.layer_idx > 0:
-                hidden_states, recv_metadata = afd_connector.recv_ffn_output()
-                if recv_metadata.recv_handle_list is not None:
-                    recv_handle = recv_metadata.recv_handle_list
+        ubatch_hidden_states = []
+        ubatch_residual = []
 
-            if recv_handle is not None:
-                for work in recv_handle:
-                    work.wait()
-            logger.info(f"Step3TextModel {layer.layer_idx=}: {hidden_states.shape=}, {positions.shape=}")
-            current_hidden, residual = layer(positions, hidden_states, residual)
-            logger.info(f"create attn metadata: {current_hidden.shape=}")
-            metadata = AFDConnectorMetadata.create_attention_metadata(
-                layer_idx=layer.layer_idx,
-                stage_idx=afd_metadata.afd_stage_idx,
-                seq_len=current_hidden.shape[0],
-                dtype=current_hidden.dtype,
-                device=current_hidden.device,
-                num_of_stages=afd_metadata.num_of_stages,
-                afd_tokens_lens=afd_metadata.afd_tokens_lens,
+        start_idx = 0
+        for pos in afd_metadata.positions_list:
+            num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0]
+            end_idx = start_idx + num_tokens
+            ubatch_hidden_states.append(hidden_states[start_idx:end_idx])
+            ubatch_residual.append(
+                residual[start_idx:end_idx] if residual is not None else None
             )
-            afd_connector.send_attn_output(current_hidden, metadata)
+            start_idx = end_idx
 
-            if dbo_enabled():
-                dbo_yield()
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
+            for stage_i in range(forward_context.afd_metadata.num_of_stages):
+                afd_connector = afd_metadata.afd_connector
+                forward_context.attn_metadata = afd_metadata.attn_metadata_list[stage_i]
+                forward_context.dp_metadata = afd_metadata.dp_metadata_list[stage_i]
 
-        hidden_states, recv_metadata = afd_connector.recv_ffn_output()
-        if recv_metadata.recv_handle_list is not None:
-            recv_handle = recv_metadata.recv_handle_list
-        if recv_handle is not None:
-            for work in recv_handle:
-                work.wait()
+                residual = ubatch_residual[stage_i]
+
+                if layer.layer_idx > 0:
+                    hidden_states, recv_metadata = afd_connector.recv_ffn_output()
+                    if recv_metadata.recv_handle_list is not None:
+                        recv_handle = recv_metadata.recv_handle_list
+                else:
+                    hidden_states = ubatch_hidden_states[stage_i]
+
+                if recv_handle is not None:
+                    for work in recv_handle:
+                        work.wait()
+
+                current_positions = afd_metadata.positions_list[stage_i]
+                hidden_states, residual = layer(
+                    current_positions, hidden_states, residual
+                )
+
+                ubatch_hidden_states[stage_i] = hidden_states
+                ubatch_residual[stage_i] = residual
+                logger.info(f"create attn metadata: {afd_metadata.afd_tokens_lens=}")
+                metadata = AFDConnectorMetadata.create_attention_metadata(
+                    layer_idx=layer.layer_idx,
+                    stage_idx=stage_i,
+                    seq_len=hidden_states.shape[0],
+                    dtype=hidden_states.dtype,
+                    device=hidden_states.device,
+                    num_of_stages=afd_metadata.num_of_stages,
+                    afd_tokens_lens=afd_metadata.afd_tokens_lens,
+                )
+                afd_connector.send_attn_output(hidden_states, metadata)
+
+        # Recv last layer FFN output.
+        for stage_i in range(afd_metadata.num_of_stages):
+            ubatch_hidden_states[stage_i], recv_metadata = (
+                afd_connector.recv_ffn_output()
+            )
+
+        # Re-assemble the batch.
+        hidden_states = torch.cat(ubatch_hidden_states, dim=0)
+        if any(r is not None for r in ubatch_residual):
+            residual = torch.cat(ubatch_residual, dim=0)
+        else:
+            residual = None
 
         return hidden_states, residual
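The rewritten forward_with_afd drives a split/compute/re-assemble cycle: the flat batch is sliced into per-stage micro-batches using the token counts taken from positions_list, each slice runs attention and is exchanged with the FFN side per layer, and torch.cat restores the original batch order at the end. The slicing arithmetic in isolation (a sketch; the function and variable names are illustrative, not part of vLLM):

import torch


def split_by_stage(hidden_states: torch.Tensor, stage_token_counts: list[int]) -> list[torch.Tensor]:
    # Slice dim 0 into consecutive per-stage chunks; the counts must sum
    # to the total number of tokens in the batch.
    chunks, start = [], 0
    for num_tokens in stage_token_counts:
        chunks.append(hidden_states[start : start + num_tokens])
        start += num_tokens
    assert start == hidden_states.shape[0]
    return chunks


hs = torch.randn(10, 8)
parts = split_by_stage(hs, [6, 4])
assert torch.equal(torch.cat(parts, dim=0), hs)  # re-assembly restores the batch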
From 27ae2e761c17e692a294ae1ce4a37c1092f270d7 Mon Sep 17 00:00:00 2001
From: i-yuanyukun
Date: Mon, 22 Dec 2025 15:25:28 +0800
Subject: [PATCH 12/13] [Chore] clean up debug info

---
 .../afd_transfer/afd_connector/p2p_connector.py | 30 +++++++++++--------
 vllm/model_executor/models/step3_text.py        | 15 ----------
 2 files changed, 17 insertions(+), 28 deletions(-)

diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
index 0605facbfaffb..f85679400d1c6 100644
--- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
+++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
@@ -60,7 +60,7 @@ class P2PAFDConnector(AFDConnectorBase):
             self.num_hidden_layers: int = (
                 self.config.model_config.hf_config.num_hidden_layers
             )
-
+
         self.recv_attn_output_counter: int = 0
         self.recv_ffn_output_counter: int = 0
         self.dp_metadata_list: dict[int, DPMetadata] = {}
@@ -139,11 +139,9 @@ class P2PAFDConnector(AFDConnectorBase):
         for idx in range(num_of_stages):
             if idx == 0:
                 tensor_metadata_list[0] = tensor_metadata
-                logger.info(f"build tensor metadata: stage_{idx=}, size={tensor_metadata.size}")
             else:
                 new_size = list(tensor_metadata.size)
                 new_size[0] = connector_metadata.afd_tokens_lens[idx]
-                logger.info(f"build tensor metadata: stage_{idx=}, {new_size=}, {connector_metadata.afd_tokens_lens=}")
                 tensor_metadata_list[idx] = TensorMetadata(
                     tensor_metadata.device,
                     tensor_metadata.dtype,
@@ -167,7 +165,6 @@ class P2PAFDConnector(AFDConnectorBase):
         )
         metadata_tuple = (metadata, tensor_metadata)
         process_group.send_object(metadata_tuple, dst=dst)
-        logger.info(f"_send_metadata called build tensor metadata")
         self._tensor_metadata_list = self._build_tensor_metadata_list(
             tensor_metadata, metadata
         )
@@ -180,23 +177,32 @@ class P2PAFDConnector(AFDConnectorBase):
         (self._current_afd_connector_metadata, tensor_metadata) = (
             process_group.recv_object(src=src)
         )
-        logger.info(f"_recv_metadata called build tensor metadata")
         self._tensor_metadata_list = self._build_tensor_metadata_list(
             tensor_metadata, self._current_afd_connector_metadata
         )
-        logger.info(f"{self.config.parallel_config.data_parallel_size=}")
         if self.config.parallel_config.data_parallel_size > 1:
-            logger.info("jcz recv_metadata num_of_stages:{}".format(self._current_afd_connector_metadata.num_of_stages))
+            logger.info(
+                "jcz recv_metadata num_of_stages:{}".format(
+                    self._current_afd_connector_metadata.num_of_stages
+                )
+            )
             for stage_idx in range(self._current_afd_connector_metadata.num_of_stages):
                 num_tokens_per_ubatch = self._tensor_metadata_list[stage_idx].size[0]
-                logger.info(f"{stage_idx=}, {num_tokens_per_ubatch=}")
                 self.dp_metadata_list[stage_idx] = DPMetadata.make(
                     self.config.parallel_config,
                     num_tokens_per_ubatch,
-                    torch.tensor([num_tokens_per_ubatch] * self.config.parallel_config.data_parallel_size,
-                                 device="cpu", dtype=torch.int32),
+                    torch.tensor(
+                        [num_tokens_per_ubatch]
+                        * self.config.parallel_config.data_parallel_size,
+                        device="cpu",
+                        dtype=torch.int32,
+                    ),
                 )
-            logger.info("jcz recv_metadata self.dp_metadata_list:{}".format(self.dp_metadata_list))
+            logger.info(
+                "jcz recv_metadata self.dp_metadata_list:{}".format(
+                    self.dp_metadata_list
+                )
+            )
 
     def _send_hidden_states(
         self,
@@ -229,7 +235,6 @@ class P2PAFDConnector(AFDConnectorBase):
             dtype=tensor_metadata.dtype,
             device=tensor_metadata.device,
         )
-        # logger.info(f"{__file__}: p2p recv hidden states: {hidden_states.shape=}, {tensor_metadata.size=}")
         torch.distributed.recv(
             hidden_states,
             src=process_group.ranks[src],
@@ -267,7 +272,6 @@ class P2PAFDConnector(AFDConnectorBase):
             self.recv_ffn_output_counter
             % self._current_afd_connector_metadata.num_of_stages
         )
-        logger.info(f"{stage_idx=}")
         hidden_states, work_list = self._recv_hidden_states(
             src,
             self.e2a_group,
diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py
index 75a38805ef5aa..a09bf0c8bd37f 100644
--- a/vllm/model_executor/models/step3_text.py
+++ b/vllm/model_executor/models/step3_text.py
@@ -303,10 +303,6 @@ class Step3TextDecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
-        # query, key and positions must have the same number of tokens
-        # /model_executor/layers/rotary_embedding/base.py
-        # positions.shape=torch.Size([8192]), hidden_states.shape=torch.Size([4096, 3712])
-        logger.info(f"{positions.shape=}, {hidden_states.shape=}")
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -342,7 +338,6 @@ class Step3TextDecoderLayer(nn.Module):
                 hidden_states = share_output + moe_output
             else:
                 hidden_states = self.mlp(hidden_states)
-        logger.info(f"{type(hidden_states)=}")
 
         return hidden_states
 
@@ -353,7 +348,6 @@ class Step3TextModel(nn.Module):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-        logger.info(f"{quant_config=}")
         afd_config = vllm_config.afd_config
         self.vocab_size = config.vocab_size
         self.config = config
@@ -440,7 +434,6 @@ class Step3TextModel(nn.Module):
 
                 ubatch_hidden_states[stage_i] = hidden_states
                 ubatch_residual[stage_i] = residual
-                logger.info(f"create attn metadata: {afd_metadata.afd_tokens_lens=}")
                 metadata = AFDConnectorMetadata.create_attention_metadata(
                     layer_idx=layer.layer_idx,
                     stage_idx=stage_i,
@@ -515,7 +508,6 @@ class Step3TextModel(nn.Module):
         hidden_states,
         layer_idx,
     ) -> torch.Tensor | IntermediateTensors:
-        logger.info(f"{type(self.layers)=}, {type(layer_idx)=}")
         hidden_states = self.layers[layer_idx].compute_ffn_output(hidden_states)
         return hidden_states
 
@@ -582,7 +574,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        logger.info(f"{__file__}: load_weights!")
         qkv_params_mapping = [
             # (param_name, shard_name, relative_start_idx, relative_end_idx)
             (
@@ -615,7 +606,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        # logger.info(f"{params_dict.keys()=}")
         loaded_params: set[str] = set()
 
         expert_params_mapping = [
@@ -627,10 +617,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
         disable_moe_stacked_params = [data[1] for data in expert_params_mapping]
 
         for name, loaded_weight in weights:
-            # logger.info(
-            #     f"{self.afd_role=}, {name=}, is_moe: {self.is_moe_weight(name)}, "
-            #     f"is_common: {self.is_common_weight(name)}"
-            # )
             if self.afd_role == "attention" and self.is_moe_weight(name):
                 continue
 
@@ -695,7 +681,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
                 start_idx,
                 end_idx,
             ) in qkv_params_mapping:
-                # logger.info(f"{weight_name=}, {name=}")
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
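The reflowed logging above sits around the per-stage DPMetadata construction, which bakes in the assumption that every data-parallel rank sees the same token count per micro-batch: the tensor handed to DPMetadata.make simply repeats num_tokens_per_ubatch once per DP rank. That construction, together with the round-robin stage index recv_ffn_output derives from its running counter, in a self-contained sketch (helper names are illustrative, not the vLLM API):

import torch


def make_dp_num_tokens(num_tokens_per_ubatch: int, dp_size: int) -> torch.Tensor:
    # Uniform-token assumption: each DP rank processes the same number of tokens
    # per ubatch, so the per-rank tensor is the count repeated dp_size times.
    return torch.tensor([num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32)


def stage_for_recv(counter: int, num_of_stages: int) -> int:
    # The i-th recv is attributed to a stage round-robin, as in recv_ffn_output.
    return counter % num_of_stages


assert make_dp_num_tokens(96, 4).tolist() == [96, 96, 96, 96]
assert [stage_for_recv(i, 2) for i in range(4)] == [0, 1, 0, 1]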
{name=}") if weight_name not in name: continue name = name.replace(weight_name, param_name) From 60d65cdf5c7a2c56af5a8d9f2f2193b85b4dc520 Mon Sep 17 00:00:00 2001 From: i-yuanyukun Date: Mon, 22 Dec 2025 15:37:07 +0800 Subject: [PATCH 13/13] [Chore] remove unused method --- vllm/model_executor/models/step3_text.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index a09bf0c8bd37f..eb723582a0ccb 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -322,14 +322,6 @@ class Step3TextDecoderLayer(nn.Module): return hidden_states, residual - def compute_attn_output( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: torch.Tensor | None, - ): - pass - def compute_ffn_output(self, hidden_states): assert self.afd_role == "ffn" if self.use_moe: