Fix dtype check in get_fp8_min_max helper

Address review feedback: Only apply the 224.0 override when both:
1. Platform supports fnuz (is_fp8_fnuz())
2. The dtype is actually torch.float8_e4m3fnuz

This prevents incorrect min/max values when a non-fnuz dtype is
explicitly passed on a platform that supports fnuz.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: c0de128 <kevin.mckay@outlook.com>
This commit is contained in:
c0de128 2025-12-21 21:06:01 -06:00
parent 961a5ab423
commit 523dd610cb

View File

@ -36,7 +36,8 @@ def get_fp8_min_max(dtype: torch.dtype | None = None) -> tuple[float, float]:
if dtype is None:
dtype = FP8_DTYPE
finfo = torch.finfo(dtype)
if current_platform.is_fp8_fnuz():
# Only apply the 224.0 override for the actual fnuz dtype on fnuz platform
if current_platform.is_fp8_fnuz() and dtype == torch.float8_e4m3fnuz:
return -224.0, 224.0
return finfo.min, finfo.max