diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py index 0605facbfaffb..f85679400d1c6 100644 --- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py +++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -60,7 +60,7 @@ class P2PAFDConnector(AFDConnectorBase): self.num_hidden_layers: int = ( self.config.model_config.hf_config.num_hidden_layers ) - + self.recv_attn_output_counter: int = 0 self.recv_ffn_output_counter: int = 0 self.dp_metadata_list: dict[int, DPMetadata] = {} @@ -139,11 +139,9 @@ class P2PAFDConnector(AFDConnectorBase): for idx in range(num_of_stages): if idx == 0: tensor_metadata_list[0] = tensor_metadata - logger.info(f"build tensor metadata: stage_{idx=}, size={tensor_metadata.size}") else: new_size = list(tensor_metadata.size) new_size[0] = connector_metadata.afd_tokens_lens[idx] - logger.info(f"build tensor metadata: stage_{idx=}, {new_size=}, {connector_metadata.afd_tokens_lens=}") tensor_metadata_list[idx] = TensorMetadata( tensor_metadata.device, tensor_metadata.dtype, @@ -167,7 +165,6 @@ class P2PAFDConnector(AFDConnectorBase): ) metadata_tuple = (metadata, tensor_metadata) process_group.send_object(metadata_tuple, dst=dst) - logger.info(f"_send_metadata called build tensor metadata") self._tensor_metadata_list = self._build_tensor_metadata_list( tensor_metadata, metadata ) @@ -180,23 +177,32 @@ class P2PAFDConnector(AFDConnectorBase): (self._current_afd_connector_metadata, tensor_metadata) = ( process_group.recv_object(src=src) ) - logger.info(f"_recv_metadata called build tensor metadata") self._tensor_metadata_list = self._build_tensor_metadata_list( tensor_metadata, self._current_afd_connector_metadata ) - logger.info(f"{self.config.parallel_config.data_parallel_size=}") if self.config.parallel_config.data_parallel_size > 1: - logger.info("jcz recv_metadata num_of_stages:{}".format(self._current_afd_connector_metadata.num_of_stages)) + logger.info( + "jcz recv_metadata num_of_stages:{}".format( + self._current_afd_connector_metadata.num_of_stages + ) + ) for stage_idx in range(self._current_afd_connector_metadata.num_of_stages): num_tokens_per_ubatch = self._tensor_metadata_list[stage_idx].size[0] - logger.info(f"{stage_idx=}, {num_tokens_per_ubatch=}") self.dp_metadata_list[stage_idx] = DPMetadata.make( self.config.parallel_config, num_tokens_per_ubatch, - torch.tensor([num_tokens_per_ubatch] * self.config.parallel_config.data_parallel_size, - device="cpu", dtype=torch.int32), + torch.tensor( + [num_tokens_per_ubatch] + * self.config.parallel_config.data_parallel_size, + device="cpu", + dtype=torch.int32, + ), ) - logger.info("jcz recv_metadata self.dp_metadata_list:{}".format(self.dp_metadata_list)) + logger.info( + "jcz recv_metadata self.dp_metadata_list:{}".format( + self.dp_metadata_list + ) + ) def _send_hidden_states( self, @@ -229,7 +235,6 @@ class P2PAFDConnector(AFDConnectorBase): dtype=tensor_metadata.dtype, device=tensor_metadata.device, ) - # logger.info(f"{__file__}: p2p recv hidden states: {hidden_states.shape=}, {tensor_metadata.size=}") torch.distributed.recv( hidden_states, src=process_group.ranks[src], @@ -267,7 +272,6 @@ class P2PAFDConnector(AFDConnectorBase): self.recv_ffn_output_counter % self._current_afd_connector_metadata.num_of_stages ) - logger.info(f"{stage_idx=}") hidden_states, work_list = self._recv_hidden_states( src, self.e2a_group, diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 75a38805ef5aa..a09bf0c8bd37f 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -303,10 +303,6 @@ class Step3TextDecoderLayer(nn.Module): else: hidden_states, residual = self.input_layernorm(hidden_states, residual) - # query, key and positions must have the same number of tokens - # /model_executor/layers/rotary_embedding/base.py - # positions.shape=torch.Size([8192]), hidden_states.shape=torch.Size([4096, 3712]) - logger.info(f"{positions.shape=}, {hidden_states.shape=}") hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -342,7 +338,6 @@ class Step3TextDecoderLayer(nn.Module): hidden_states = share_output + moe_output else: hidden_states = self.mlp(hidden_states) - logger.info(f"{type(hidden_states)=}") return hidden_states @@ -353,7 +348,6 @@ class Step3TextModel(nn.Module): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - logger.info(f"{quant_config=}") afd_config = vllm_config.afd_config self.vocab_size = config.vocab_size self.config = config @@ -440,7 +434,6 @@ class Step3TextModel(nn.Module): ubatch_hidden_states[stage_i] = hidden_states ubatch_residual[stage_i] = residual - logger.info(f"create attn metadata:, {afd_metadata.afd_tokens_lens=}") metadata = AFDConnectorMetadata.create_attention_metadata( layer_idx=layer.layer_idx, stage_idx=stage_i, @@ -515,7 +508,6 @@ class Step3TextModel(nn.Module): hidden_states, layer_idx, ) -> torch.Tensor | IntermediateTensors: - logger.info(f"{type(self.layers)=}, {type(layer_idx)=}") hidden_states = self.layers[layer_idx].compute_ffn_output(hidden_states) return hidden_states @@ -582,7 +574,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - logger.info(f"{__file__}: load_weights!") qkv_params_mapping = [ # (param_name, shard_name, relative_start_idx, relative_end_idx) ( @@ -615,7 +606,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) - # logger.info(f"{params_dict.keys()=}") loaded_params: set[str] = set() expert_params_mapping = [ @@ -627,10 +617,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): disable_moe_stacked_params = [data[1] for data in expert_params_mapping] for name, loaded_weight in weights: - # logger.info( - # f"{self.afd_role=}, {name=}, is_moe: {self.is_moe_weight(name)}, " - # f"is_common: {self.is_common_weight(name)}" - # ) if self.afd_role == "attention" and self.is_moe_weight(name): continue @@ -695,7 +681,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): start_idx, end_idx, ) in qkv_params_mapping: - # logger.info(f"{weight_name=}, {name=}") if weight_name not in name: continue name = name.replace(weight_name, param_name)