mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-29 04:00:54 +08:00
[Bugfix] Add logs for all model dtype casting (#4717)
This commit is contained in:
parent
cea64430f6
commit
be0c5180ac
@ -1063,6 +1063,7 @@ def _get_and_verify_dtype(
|
||||
if config_dtype == torch.float32:
|
||||
# Following the common practice, we use float16 for float32
|
||||
# models.
|
||||
logger.info("Casting torch.float32 to torch.float16.")
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
@ -1087,9 +1088,11 @@ def _get_and_verify_dtype(
|
||||
if torch_dtype != config_dtype:
|
||||
if torch_dtype == torch.float32:
|
||||
# Upcasting to float32 is allowed.
|
||||
logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
|
||||
pass
|
||||
elif config_dtype == torch.float32:
|
||||
# Downcasting from float32 to float16 or bfloat16 is allowed.
|
||||
logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
|
||||
pass
|
||||
else:
|
||||
# Casting between float16 and bfloat16 is allowed with a warning.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user