Revert back to torch.equal over torch.allclose from #28819 (#29086)

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
This commit is contained in:
Eldar Kurtić 2025-11-25 15:23:38 +01:00 committed by GitHub
parent 516c3f7847
commit 0231ce836a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1055,11 +1055,11 @@ class EagleProposer:
         elif (
             isinstance(target_embed_tokens.weight, torch.Tensor)
             and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
-            and torch.allclose(
+            # TODO: Offload to CPU for comparison to avoid extra GPU memory
+            # usage in CI testing environments with limited GPU memory
+            and torch.equal(
                 target_embed_tokens.weight.cpu(),
                 self.model.model.embed_tokens.weight.cpu(),
-                rtol=1e-5,
-                atol=1e-7,
             )
         ):
             share_embeddings = True
@@ -1105,8 +1105,11 @@ class EagleProposer:
             hasattr(target_language_model, "lm_head")
             and isinstance(target_language_model.lm_head.weight, torch.Tensor)
             and isinstance(self.model.lm_head.weight, torch.Tensor)
+            # TODO: Offload to CPU for comparison to avoid extra GPU memory
+            # usage in CI testing environments with limited GPU memory
             and torch.equal(
-                target_language_model.lm_head.weight, self.model.lm_head.weight
+                target_language_model.lm_head.weight.cpu(),
+                self.model.lm_head.weight.cpu(),
             )
         ):
             share_lm_head = True