mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 16:22:19 +08:00
[CI][Bugfix] Fix failing Blackwell test (#24993)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
parent
dbebb7f812
commit
d119fc8614
@ -506,12 +506,9 @@ class SharedResizableBuffer:
|
|||||||
def get(self, shape: tuple[int, ...], device: torch.device,
|
def get(self, shape: tuple[int, ...], device: torch.device,
|
||||||
dtype: torch.dtype):
|
dtype: torch.dtype):
|
||||||
shape_numel = prod(shape)
|
shape_numel = prod(shape)
|
||||||
if self.buffer is None or self.buffer.numel() < shape_numel:
|
if (self.buffer is None or self.buffer.numel() < shape_numel
|
||||||
|
or self.buffer.device != device or self.buffer.dtype != dtype):
|
||||||
self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
|
self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
|
||||||
assert self.buffer.device == device, \
|
|
||||||
f"Buffer device mismatch: {self.buffer.device} != {device}"
|
|
||||||
assert self.buffer.dtype == dtype, \
|
|
||||||
f"Buffer dtype mismatch: {self.buffer.dtype} != {dtype}"
|
|
||||||
return self.buffer[:shape_numel].view(*shape)
|
return self.buffer[:shape_numel].view(*shape)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user