diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 784ccbc04932f..7b9037c03d4f0 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1055,11 +1055,11 @@ class EagleProposer: elif ( isinstance(target_embed_tokens.weight, torch.Tensor) and isinstance(self.model.model.embed_tokens.weight, torch.Tensor) - and torch.allclose( + # TODO: Offload to CPU for comparison to avoid extra GPU memory + # usage in CI testing environments with limited GPU memory + and torch.equal( target_embed_tokens.weight.cpu(), self.model.model.embed_tokens.weight.cpu(), - rtol=1e-5, - atol=1e-7, ) ): share_embeddings = True @@ -1105,8 +1105,11 @@ class EagleProposer: hasattr(target_language_model, "lm_head") and isinstance(target_language_model.lm_head.weight, torch.Tensor) and isinstance(self.model.lm_head.weight, torch.Tensor) + # TODO: Offload to CPU for comparison to avoid extra GPU memory + # usage in CI testing environments with limited GPU memory and torch.equal( - target_language_model.lm_head.weight, self.model.lm_head.weight + target_language_model.lm_head.weight.cpu(), + self.model.lm_head.weight.cpu(), ) ): share_lm_head = True