[Fix] Support passing args to logger (#17425)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Aaron Pham 2025-04-30 11:06:58 -04:00 committed by GitHub
parent 39317cf42b
commit da4e7687b5
GPG Key ID: B5690EEEBB952194
13 changed files with 75 additions and 79 deletions

View File

@@ -278,7 +278,7 @@ class ModelConfig:
max_model_len: int = None # type: ignore
"""Model context length (prompt and output). If unspecified, will be
automatically derived from the model config.
When passing via `--max-model-len`, supports k/m/g/K/M/G in human-readable
format. Examples:\n
- 1k -> 1000\n
@@ -518,11 +518,11 @@ class ModelConfig:
self.hf_text_config.sliding_window)
logger.warning_once(
f"{self.hf_text_config.model_type} has interleaved "
"attention, which is currently not supported by the "
f"{backend} backend. Disabling sliding window and capping "
"the max length to the sliding window size "
f"({sliding_window_len_min}).")
"%s has interleaved attention, which is currently not supported by the %s backend. Disabling sliding window and capping the max length to the sliding window size (%d).", # noqa: E501
self.hf_text_config.model_type,
backend,
sliding_window_len_min,
)
self.disable_sliding_window = True
else:
# for a model with interleaved attention,
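The hunk above is the template for the rest of this commit: instead of pre-rendering the message with f-strings, the call site now hands a %-style format string plus its arguments to warning_once, and the logging framework interpolates them when the record is emitted. A minimal stdlib-only sketch of the two styles (the backend name, value, and message text here are illustrative, not vLLM's exact output):

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("example")

backend = "FLASH_ATTN"            # hypothetical values, for illustration only
sliding_window_len_min = 4096

# Eager f-string: Python builds the full message before logging ever sees it.
logger.warning(f"Disabling sliding window for the {backend} backend ({sliding_window_len_min}).")

# Lazy %-style: logging stores (msg, args) and formats only when the record is emitted,
# which is what lets the vLLM *_once helpers below forward the args unchanged.
logger.warning("Disabling sliding window for the %s backend (%d).", backend, sliding_window_len_min)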

View File

@@ -5,6 +5,7 @@ import json
import logging
import os
import sys
+ from collections.abc import Hashable
from functools import lru_cache, partial
from logging import Logger
from logging.config import dictConfig
@@ -52,15 +53,15 @@ DEFAULT_LOGGING_CONFIG = {
@lru_cache
- def _print_info_once(logger: Logger, msg: str) -> None:
+ def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None:
# Set the stacklevel to 2 to print the original caller's line info
- logger.info(msg, stacklevel=2)
+ logger.info(msg, *args, stacklevel=2)
@lru_cache
- def _print_warning_once(logger: Logger, msg: str) -> None:
+ def _print_warning_once(logger: Logger, msg: str, *args: Hashable) -> None:
# Set the stacklevel to 2 to print the original caller's line info
- logger.warning(msg, stacklevel=2)
+ logger.warning(msg, *args, stacklevel=2)
class _VllmLogger(Logger):
@@ -72,19 +73,19 @@ class _VllmLogger(Logger):
`intel_extension_for_pytorch.utils._logger`.
"""
- def info_once(self, msg: str) -> None:
+ def info_once(self, msg: str, *args: Hashable) -> None:
"""
As :meth:`info`, but subsequent calls with the same message
are silently dropped.
"""
- _print_info_once(self, msg)
+ _print_info_once(self, msg, *args)
- def warning_once(self, msg: str) -> None:
+ def warning_once(self, msg: str, *args: Hashable) -> None:
"""
As :meth:`warning`, but subsequent calls with the same message
are silently dropped.
"""
- _print_warning_once(self, msg)
+ _print_warning_once(self, msg, *args)
def _configure_vllm_root_logger() -> None:
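With these signatures, info_once and warning_once accept positional format arguments and forward them to the underlying Logger.info/Logger.warning call. Because _print_info_once and _print_warning_once are memoized with @lru_cache, every argument becomes part of the cache key, hence the Hashable annotation. A hedged usage sketch (assuming the usual vllm.logger.init_logger entry point; the argument values are made up):

from vllm.logger import init_logger

logger = init_logger(__name__)

# Formatting is deferred to the logging framework, and repeated calls with the
# same (message, args) pair are silently dropped by the lru_cache-backed helper.
logger.info_once("Using %s.", "PunicaWrapperGPU")
logger.warning_once("Layer '%s' is not supported by %s.",
                    "model.layers.0.qkv_proj", "AWQMarlin")

# Unhashable values (lists, dicts, ...) must be converted first, e.g. with str(),
# otherwise lru_cache raises TypeError.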

View File

@@ -15,6 +15,5 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
punica_wrapper = punica_wrapper_cls(*args, **kwargs)
assert punica_wrapper is not None, \
"the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong."
logger.info_once("Using " + punica_wrapper_qualname.rsplit(".", 1)[1] +
".")
logger.info_once("Using %s.", punica_wrapper_qualname.rsplit(".", 1)[1])
return punica_wrapper

View File

@@ -107,9 +107,9 @@ class CustomOp(nn.Module):
custom_ops = compilation_config.custom_ops
if not hasattr(cls, "name"):
logger.warning_once(
f"Custom op {cls.__name__} was not registered, "
f"which means it won't appear in the op registry. "
f"It will be enabled/disabled based on the global settings.")
"Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.", # noqa: E501
cls.__name__,
)
return CustomOp.default_on()
enabled = f"+{cls.name}" in custom_ops

View File

@@ -191,9 +191,9 @@ class GrammarConfig:
if model_with_warn is not None and any_whitespace:
logger.info_once(
f"{model_with_warn} model detected, consider setting "
"`disable_any_whitespace` to prevent runaway generation "
"of whitespaces.")
"%s model detected, consider setting `disable_any_whitespace` to prevent runaway generation of whitespaces.", # noqa: E501
model_with_warn,
)
# Validate the schema and raise ValueError here if it is invalid.
# This is to avoid exceptions in model execution, which will crash
# the engine worker process.

View File

@@ -130,8 +130,9 @@ class AWQMarlinConfig(QuantizationConfig):
# Check if the layer is supported by AWQMarlin.
if not check_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by AWQMarlin. "
"Falling back to unoptimized AWQ kernels.")
"Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.", # noqa: E501
prefix,
)
return AWQConfig.from_config(
self.full_config).get_quant_method(layer, prefix)
return AWQMarlinLinearMethod(self)

View File

@@ -464,7 +464,7 @@ def fastsafetensors_weights_iterator(
hf_weights_files: List[str],
use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files
"""Iterate over the weights in the model safetensor files
using fastsafetensor library."""
if torch.distributed.is_initialized():
pg = torch.distributed.group.WORLD
@@ -716,10 +716,10 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
remapped_name = name.replace(".kv_scale", ".attn.k_scale")
if remapped_name not in params_dict:
logger.warning_once(
f"Found kv_scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_name}). kv_scale is "
"not loaded.")
"Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv_scale is not loaded.", # noqa: E501
name,
remapped_name,
)
return None
return remapped_name
@@ -738,10 +738,12 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
remapped_name = name.replace(scale_name, f".attn{scale_name}")
if remapped_name not in params_dict:
logger.warning_once(
f"Found {scale_name} in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_name}). {scale_name} is "
"not loaded.")
"Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.", # noqa: E501
scale_name,
name,
remapped_name,
scale_name,
)
return None
return remapped_name

View File

@@ -1111,10 +1111,10 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
logger.warning_once(
"Found kv scale in the checkpoint (e.g. "
f"{name}), but not found the expected name in "
f"the model (e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
"Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501
name,
remapped_kv_scale_name,
)
continue
else:
name = remapped_kv_scale_name

View File

@@ -385,11 +385,10 @@ class OlmoeModel(nn.Module):
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
logger.warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
"Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501
name,
remapped_kv_scale_name,
)
continue
else:
name = remapped_kv_scale_name

View File

@@ -462,11 +462,10 @@ class Qwen2MoeModel(nn.Module):
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
logger.warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
"Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv_scale is not loaded.", # noqa: E501
name,
remapped_kv_scale_name,
)
continue
else:
name = remapped_kv_scale_name

View File

@@ -459,11 +459,10 @@ class Qwen3MoeModel(nn.Module):
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
logger.warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
"Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501
name,
remapped_kv_scale_name,
)
continue
else:
name = remapped_kv_scale_name

View File

@@ -215,17 +215,14 @@ class MultiModalProfiler(Generic[_I]):
elif total_len > seq_len and not envs.VLLM_USE_V1:
# `max_num_batched_tokens` is defined by `SchedulerConfig`
logger.warning_once(
"The encoder sequence length used for profiling ("
f"max_num_batched_tokens / max_num_seqs = {seq_len}) "
" is too short "
"to hold the multi-modal embeddings in the worst case "
f"({total_len} tokens in total, out of which "
f"{self._get_mm_num_tokens(mm_inputs)} are reserved for "
"multi-modal embeddings). This may cause certain "
"multi-modal inputs to fail during inference, even when "
"the input text is short. To avoid this, you should "
"increase `max_model_len`, reduce `max_num_seqs`, "
"and/or reduce `mm_counts`.")
"The encoder sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) " # noqa: E501
"is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). " # noqa: E501
"This may cause certain multi-modal inputs to fail during inference, even when the input text is short. " # noqa: E501
"To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.", # noqa: E501
seq_len,
total_len,
str(self._get_mm_num_tokens(mm_inputs)),
)
return DummyEncoderData(encoder_prompt_token_ids)
@@ -243,17 +240,14 @@ class MultiModalProfiler(Generic[_I]):
if total_len > seq_len and not envs.VLLM_USE_V1:
# `max_num_batched_tokens` is defined by `SchedulerConfig`
logger.warning_once(
"The sequence length used for profiling ("
f"max_num_batched_tokens / max_num_seqs = {seq_len}) "
"is too short "
"to hold the multi-modal embeddings in the worst case "
f"({total_len} tokens in total, out of which "
f"{self._get_mm_num_tokens(mm_inputs)} are reserved for "
"multi-modal embeddings). This may cause certain "
"multi-modal inputs to fail during inference, even when "
"the input text is short. To avoid this, you should "
"increase `max_model_len`, reduce `max_num_seqs`, "
"and/or reduce `mm_counts`.")
"The sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) " # noqa: E501
"is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). " # noqa: E501
"This may cause certain multi-modal inputs to fail during inference, even when the input text is short. " # noqa: E501
"To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.", # noqa: E501
seq_len,
total_len,
str(self._get_mm_num_tokens(mm_inputs)),
)
if total_len < seq_len:
prompt_token_ids.extend([0] * (seq_len - total_len))
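Note the str(...) wrapping of self._get_mm_num_tokens(mm_inputs) above: because the memoized *_once helpers hash their arguments for deduplication, a non-hashable value (here presumably a per-modality mapping of token counts) has to be stringified before it can be passed, and the placeholder is %s accordingly. A small stdlib-only sketch of that constraint (names are illustrative, not vLLM's):

from functools import lru_cache

@lru_cache
def warn_once(msg: str, *args) -> None:
    # lru_cache hashes (msg, *args); unhashable args raise TypeError before logging happens.
    print(msg % args)

warn_once("%s tokens are reserved for multi-modal embeddings", str({"image": 576}))  # OK: str is hashable
# warn_once("%s tokens are reserved for multi-modal embeddings", {"image": 576})     # TypeError: unhashable type: 'dict'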

View File

@@ -100,7 +100,7 @@ class MultiModalRegistry:
model_config: "ModelConfig",
) -> Mapping[str, int]:
"""
- Get the maximum number of tokens per data item from each modality based
+ Get the maximum number of tokens per data item from each modality based
on underlying model configuration.
"""
if not model_config.is_multimodal_model:
@@ -126,11 +126,11 @@
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
- on underlying model configuration, excluding modalities that user
+ on underlying model configuration, excluding modalities that user
explicitly disabled via `limit_mm_per_prompt`.
Note:
- This is currently directly used only in V1 for profiling the memory
+ This is currently directly used only in V1 for profiling the memory
usage of a model.
"""
mm_limits = self.get_mm_limits_per_prompt(model_config)
@@ -316,7 +316,9 @@ class MultiModalRegistry:
token_ids = dummy_data.prompt_token_ids
if len(token_ids) < seq_len:
logger.warning_once(
f"Expected at least {seq_len} dummy encoder tokens for "
f"profiling, but found {len(token_ids)} tokens instead.")
"Expected at least %d dummy encoder tokens for profiling, but found %d tokens instead.", # noqa: E501
seq_len,
len(token_ids),
)
return dummy_data