Revert "Update sampling_metadata.py (#21937)" (#22088)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-08-01 13:24:46 +01:00 committed by GitHub
parent 28b18cc741
commit 87c94bc879
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -539,37 +539,37 @@ class SamplingTensors:
temperatures_t = torch.tensor(
temperatures,
device="cpu",
dtype=torch.float32,
dtype=dtype,
pin_memory=pin_memory,
)
top_ps_t = torch.tensor(
top_ps,
device="cpu",
dtype=torch.float32,
dtype=dtype,
pin_memory=pin_memory,
)
min_ps_t = torch.tensor(
min_ps,
device="cpu",
dtype=torch.float32,
dtype=dtype,
pin_memory=pin_memory,
)
presence_penalties_t = torch.tensor(
presence_penalties,
device="cpu",
dtype=torch.float32,
dtype=dtype,
pin_memory=pin_memory,
)
frequency_penalties_t = torch.tensor(
frequency_penalties,
device="cpu",
dtype=torch.float32,
dtype=dtype,
pin_memory=pin_memory,
)
repetition_penalties_t = torch.tensor(
repetition_penalties,
device="cpu",
dtype=torch.float32,
dtype=dtype,
pin_memory=pin_memory,
)
top_ks_t = torch.tensor(