[Chore] some log info

This commit is contained in:
i-yuanyukun 2025-12-19 16:02:07 +08:00
parent 6a8d35a9b6
commit 11d7d5bf59

View File

@ -303,6 +303,10 @@ class Step3TextDecoderLayer(nn.Module):
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
# query, key and positions must have the same number of tokens
# /model_executor/layers/rotary_embedding/base.py
# positions.shape=torch.Size([8192]), hidden_states.shape=torch.Size([4096, 3712])
logger.info(f"{positions.shape=}, {hidden_states.shape=}")
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
@ -408,7 +412,9 @@ class Step3TextModel(nn.Module):
if recv_handle is not None:
for work in recv_handle:
work.wait()
logger.info(f"Step3TextModel {layer.layer_idx=}: {hidden_states.shape=}, {positions.shape=}")
current_hidden, residual = layer(positions, hidden_states, residual)
logger.info(f"create attn metadata: {current_hidden.shape=}")
metadata = AFDConnectorMetadata.create_attention_metadata(
layer_idx=layer.layer_idx,
stage_idx=afd_metadata.afd_stage_idx,
@ -580,7 +586,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
logger.info(f"{params_dict.keys()=}")
# logger.info(f"{params_dict.keys()=}")
loaded_params: set[str] = set()
expert_params_mapping = [
@ -592,10 +598,10 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
disable_moe_stacked_params = [data[1] for data in expert_params_mapping]
for name, loaded_weight in weights:
logger.info(
f"{self.afd_role=}, {name=}, is_moe: {self.is_moe_weight(name)}, "
f"is_common: {self.is_common_weight(name)}"
)
# logger.info(
# f"{self.afd_role=}, {name=}, is_moe: {self.is_moe_weight(name)}, "
# f"is_common: {self.is_common_weight(name)}"
# )
if self.afd_role == "attention" and self.is_moe_weight(name):
continue
@ -660,7 +666,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
start_idx,
end_idx,
) in qkv_params_mapping:
logger.info(f"{weight_name=}, {name=}")
# logger.info(f"{weight_name=}, {name=}")
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)