fix float16 support for kimi-vl (#17156)

Co-authored-by: zhouzaida <zhouzaida@msh.team>
This commit is contained in the repository history.
Authored by Zaida Zhou on 2025-04-25 11:16:32 +08:00; committed by GitHub.
parent 41ca7eb491
commit 69bff9bc89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -340,8 +340,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
else:
pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
patch_size)
# fp32 -> bf16
pixel_values = pixel_values.to(torch.bfloat16)
pixel_values = pixel_values.to(self.vision_tower.dtype)
# image_grid_hws.shape = (N, 2)
assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}"