[Bugfix] Use random hidden states in dummy sampler run (#18543)

Signed-off-by: Bowen Wang <abmfy@icloud.com>
This commit is contained in:
Bowen Wang 2025-05-22 06:48:56 -07:00 committed by GitHub
parent 71075029f2
commit 4e04eceb58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1721,6 +1721,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
# The dummy hidden states may contain special values,
# like `inf` or `nan`.
# To avoid breaking the sampler, we use a random tensor here instead.
hidden_states = torch.rand_like(hidden_states)
logits = self.model.compute_logits(hidden_states, None)
num_reqs = logits.size(0)