[misc][distributed] fix pp missing layer condition (#6446)

youkaichao 2024-07-15 10:32:35 -07:00 committed by GitHub
parent 64fdc08c72
commit 4cf256ae7f

@@ -83,7 +83,10 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
     missing_layer_names = []
     for name, module in model.named_modules():
         if isinstance(module, PPMissingLayer):
-            missing_layer_names.append(name)
+            # NOTE: the trailing dot is used to match the prefix of the layer.
+            # without the dot, we could match a layer that is not missing,
+            # e.g., 'encoder.layer.1' would match 'encoder.layer.11'
+            missing_layer_names.append(name + '.')
     _model_to_pp_missing_layer_names[model_id] = missing_layer_names
     return missing_layer_names
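
Why the trailing dot matters: per the commit comment, the recorded names are used as prefixes when deciding whether a parameter belongs to a missing layer. Below is a minimal, self-contained sketch of such a prefix check (the helper name is_missing is hypothetical, not vLLM's API) and the false positive the dot prevents:

    from typing import List

    def is_missing(param_name: str, missing_layer_names: List[str]) -> bool:
        # A parameter belongs to a missing layer if its name starts with
        # one of the recorded layer-name prefixes.
        return any(
            param_name.startswith(prefix) for prefix in missing_layer_names)

    # Without the trailing dot, the prefix 'encoder.layer.1' also matches
    # parameters of 'encoder.layer.11', a layer that is actually present:
    assert is_missing("encoder.layer.11.weight", ["encoder.layer.1"])

    # With the trailing dot appended by this commit, only parameters that
    # genuinely live under 'encoder.layer.1' match:
    assert not is_missing("encoder.layer.11.weight", ["encoder.layer.1."])
    assert is_missing("encoder.layer.1.weight", ["encoder.layer.1."])

Appending the dot at record time keeps the lookup a plain prefix check while making layer-index prefixes unambiguous.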