[CI][Bugfix] Fix failing Blackwell test (#24993)

Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni 2025-09-16 18:55:02 -04:00 committed by GitHub
parent dbebb7f812
commit d119fc8614
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -506,12 +506,9 @@ class SharedResizableBuffer:
def get(self, shape: tuple[int, ...], device: torch.device,
dtype: torch.dtype):
shape_numel = prod(shape)
if self.buffer is None or self.buffer.numel() < shape_numel:
if (self.buffer is None or self.buffer.numel() < shape_numel
or self.buffer.device != device or self.buffer.dtype != dtype):
self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
assert self.buffer.device == device, \
f"Buffer device mismatch: {self.buffer.device} != {device}"
assert self.buffer.dtype == dtype, \
f"Buffer dtype mismatch: {self.buffer.dtype} != {dtype}"
return self.buffer[:shape_numel].view(*shape)