diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 33799b58d1998..efaa9cc058e41 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -506,12 +506,9 @@ class SharedResizableBuffer: def get(self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype): shape_numel = prod(shape) - if self.buffer is None or self.buffer.numel() < shape_numel: + if (self.buffer is None or self.buffer.numel() < shape_numel + or self.buffer.device != device or self.buffer.dtype != dtype): self.buffer = torch.empty(shape_numel, device=device, dtype=dtype) - assert self.buffer.device == device, \ - f"Buffer device mismatch: {self.buffer.device} != {device}" - assert self.buffer.dtype == dtype, \ - f"Buffer dtype mismatch: {self.buffer.dtype} != {dtype}" return self.buffer[:shape_numel].view(*shape)