mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-05 13:23:14 +08:00
[BugFix] Fix incorrect preallocated sampled_token_ids tensor size (#28025)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
938772af03
commit
5a0a6dfd55
@ -524,7 +524,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
|
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
|
||||||
self.transfer_event = torch.cuda.Event()
|
self.transfer_event = torch.cuda.Event()
|
||||||
self.sampled_token_ids_pinned_cpu = torch.empty(
|
self.sampled_token_ids_pinned_cpu = torch.empty(
|
||||||
(self.max_model_len, 1),
|
(self.max_num_reqs, 1),
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user