[Misc] Move config fields to MultiModalConfig (#17343)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Parent: cde384cd92
Commit: ebb3930d28
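This change moves `mm_processor_kwargs` and `disable_mm_preprocessor_cache` from `ModelConfig` to `MultiModalConfig`. `ModelConfig.__init__` still accepts both arguments and forwards them to `MultiModalConfig`, while call sites now read them through the multimodal config. A minimal sketch of the new access pattern, assuming a populated `model_config` (names other than the config attributes are illustrative):

    # Old (removed below):
    #   base_kwargs = model_config.mm_processor_kwargs
    #   cache_disabled = model_config.disable_mm_preprocessor_cache
    # New pattern used throughout this diff:
    mm_config = model_config.get_multimodal_config()
    base_kwargs = mm_config.mm_processor_kwargs or {}
    cache_disabled = mm_config.disable_mm_preprocessor_cache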
@@ -263,6 +263,10 @@ class ModelConfig:
             the model name will be the same as `model`.
         limit_mm_per_prompt: Maximum number of data items per modality
             per prompt. Only applicable for multimodal models.
+        mm_processor_kwargs: Overrides for the multi-modal processor obtained
+            from `AutoProcessor.from_pretrained`.
+        disable_mm_preprocessor_cache: If True, disable caching of the
+            processed multi-modal inputs.
         use_async_output_proc: Whether to use async output processor.
             Defaults to True.
         config_format: The config format which shall be loaded.
@@ -273,10 +277,6 @@ class ModelConfig:
         hf_overrides: If a dictionary, contains arguments to be forwarded to the
             HuggingFace config. If a callable, it is called to update the
             HuggingFace config.
-        mm_processor_kwargs: Arguments to be forwarded to the model's processor
-            for multi-modal data, e.g., image processor.
-        disable_mm_preprocessor_cache: If true, then disables caching of the
-            multi-modal preprocessor/mapper. (not recommended)
         override_neuron_config: Initialize non default neuron config or
             override default neuron config that are specific to Neuron devices,
             this argument will be used to configure the neuron config that
@@ -320,7 +320,6 @@ class ModelConfig:
         factors.append(self.max_logprobs)
         factors.append(self.disable_sliding_window)
         factors.append(self.trust_remote_code)
-        factors.append(self.mm_processor_kwargs)
         factors.append(self.generation_config)
         factors.append(self.model_impl)
         factors.append(self.override_generation_config)
@@ -359,12 +358,12 @@ class ModelConfig:
         skip_tokenizer_init: bool = False,
         served_model_name: Optional[Union[str, list[str]]] = None,
         limit_mm_per_prompt: Optional[dict[str, int]] = None,
+        mm_processor_kwargs: Optional[dict[str, Any]] = None,
+        disable_mm_preprocessor_cache: bool = False,
         use_async_output_proc: bool = True,
         config_format: ConfigFormat = ConfigFormat.AUTO,
         hf_token: Optional[Union[bool, str]] = None,
         hf_overrides: Optional[HfOverrides] = None,
-        mm_processor_kwargs: Optional[dict[str, Any]] = None,
-        disable_mm_preprocessor_cache: bool = False,
         override_neuron_config: Optional[dict[str, Any]] = None,
         override_pooler_config: Optional["PoolerConfig"] = None,
         logits_processor_pattern: Optional[str] = None,
@@ -469,8 +468,6 @@ class ModelConfig:
             self.model, hf_token=hf_token, revision=revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self.use_async_output_proc = use_async_output_proc
-        self.mm_processor_kwargs = mm_processor_kwargs
-        self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache

         # Set enforce_eager to False if the value is unset.
         if self.enforce_eager is None:
@@ -515,7 +512,10 @@ class ModelConfig:
         self.served_model_name = get_served_model_name(model,
                                                        served_model_name)
         self.multimodal_config = self._init_multimodal_config(
-            limit_mm_per_prompt)
+            limit_mm_per_prompt=limit_mm_per_prompt,
+            mm_processor_kwargs=mm_processor_kwargs,
+            disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
+        )
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()

@@ -581,14 +581,27 @@ class ModelConfig:
             self.tokenizer = s3_tokenizer.dir

     def _init_multimodal_config(
-        self, limit_mm_per_prompt: Optional[dict[str, int]]
+        self,
+        limit_mm_per_prompt: Optional[dict[str, int]],
+        mm_processor_kwargs: Optional[dict[str, Any]],
+        disable_mm_preprocessor_cache: bool,
     ) -> Optional["MultiModalConfig"]:
         if self.registry.is_multimodal_model(self.architectures):
-            return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
+            return MultiModalConfig(
+                limit_per_prompt=limit_mm_per_prompt or {},
+                mm_processor_kwargs=mm_processor_kwargs or {},
+                disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
+            )

         if limit_mm_per_prompt:
             raise ValueError("`limit_mm_per_prompt` is only supported for "
                              "multimodal models.")
+        if mm_processor_kwargs:
+            raise ValueError("`mm_processor_kwargs` is only supported for "
+                             "multimodal models.")
+        if disable_mm_preprocessor_cache:
+            raise ValueError("`disable_mm_preprocessor_cache` is only "
+                             "supported for multimodal models.")

         return None

@@ -2776,7 +2789,23 @@ class MultiModalConfig:
     Defaults to 1 (V0) or 999 (V1) for each modality.

     For example, to allow up to 16 images and 2 videos per prompt:
-    ``{"images": 16, "videos": 2}``
+    :code:`{"images": 16, "videos": 2}`
+    """
+
+    mm_processor_kwargs: Optional[dict[str, object]] = None
+    """
+    Overrides for the multi-modal processor obtained from
+    :meth:`transformers.AutoProcessor.from_pretrained`.
+
+    The available overrides depend on the model that is being run.
+
+    For example, for Phi-3-Vision:
+    :code:`{"num_crops": 4}`.
+    """
+
+    disable_mm_preprocessor_cache: bool = False
+    """
+    If :code:`True`, disable caching of the processed multi-modal inputs.
     """

     def compute_hash(self) -> str:
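The two new `MultiModalConfig` fields mirror the keyword arguments that `ModelConfig` still accepts and forwards via `_init_multimodal_config`. A minimal construction sketch reusing the examples from the docstrings above (values are illustrative, and the import path is assumed to be `vllm.config`):

    from vllm.config import MultiModalConfig

    mm_config = MultiModalConfig(
        limit_per_prompt={"images": 16, "videos": 2},
        mm_processor_kwargs={"num_crops": 4},  # e.g. a Phi-3-Vision override
        disable_mm_preprocessor_cache=False,
    )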
@@ -4080,8 +4109,6 @@ class VllmConfig:
             f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
             f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, "  # noqa
             f"use_async_output_proc={self.model_config.use_async_output_proc}, "
-            f"disable_mm_preprocessor_cache={self.model_config.disable_mm_preprocessor_cache!r}, "  # noqa
-            f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
             f"pooler_config={self.model_config.pooler_config!r}, "
             f"compilation_config={self.compilation_config!r}")

@@ -672,20 +672,12 @@ class EngineArgs:
         )
         multimodal_group.add_argument('--limit-mm-per-prompt',
                                       **multimodal_kwargs["limit_per_prompt"])
-        parser.add_argument(
+        multimodal_group.add_argument(
             '--mm-processor-kwargs',
-            default=None,
-            type=json.loads,
-            help=('Overrides for the multi-modal processor obtained from '
-                  '``AutoProcessor.from_pretrained``. The available overrides '
-                  'depend on the model that is being run.'
-                  'For example, for Phi-3-Vision: ``{"num_crops": 4}``.'))
-        parser.add_argument(
+            **multimodal_kwargs["mm_processor_kwargs"])
+        multimodal_group.add_argument(
             '--disable-mm-preprocessor-cache',
-            action='store_true',
-            help='If True, disable caching of the processed multi-modal '
-            'inputs.')
+            **multimodal_kwargs["disable_mm_preprocessor_cache"])

         # LoRA related configs
         lora_kwargs = get_kwargs(LoRAConfig)

@@ -101,7 +101,8 @@ class InputContext:
         Initialize a HuggingFace-like processor class, merging the
         keyword arguments with those in the model's configuration.
         """
-        base_kwargs = self.model_config.mm_processor_kwargs
+        mm_config = self.model_config.get_multimodal_config()
+        base_kwargs = mm_config.mm_processor_kwargs
         if base_kwargs is None:
             base_kwargs = {}

@@ -139,7 +140,8 @@ class InputProcessingContext(InputContext):
         """
         assert callable(hf_processor)

-        base_kwargs = self.model_config.mm_processor_kwargs
+        mm_config = self.model_config.get_multimodal_config()
+        base_kwargs = mm_config.mm_processor_kwargs
         if base_kwargs is None:
             base_kwargs = {}

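Both processor-context call sites keep their existing merge behavior; only where `base_kwargs` comes from changes. A rough sketch of the lookup step, with the subsequent merge left to the surrounding methods as before:

    mm_config = self.model_config.get_multimodal_config()
    base_kwargs = mm_config.mm_processor_kwargs
    if base_kwargs is None:
        base_kwargs = {}
    # base_kwargs is then merged with the per-call keyword arguments
    # before the HuggingFace processor is constructed.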
@@ -774,8 +774,9 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         size: Optional[dict[str, int]] = None,
         **kwargs: object,
     ):
-        if self.ctx.model_config.mm_processor_kwargs:
-            kwargs.update(self.ctx.model_config.mm_processor_kwargs)
+        mm_config = self.ctx.model_config.get_multimodal_config()
+        if mm_config.mm_processor_kwargs:
+            kwargs.update(mm_config.mm_processor_kwargs)

         if min_pixels is not None:
             kwargs["min_pixels"] = min_pixels

@@ -262,7 +262,8 @@ class MultiModalRegistry:
         if tokenizer is None:
             tokenizer = cached_tokenizer_from_config(model_config)
         if disable_cache is None:
-            disable_cache = model_config.disable_mm_preprocessor_cache
+            mm_config = model_config.get_multimodal_config()
+            disable_cache = mm_config.disable_mm_preprocessor_cache

         model_cls = self._get_model_cls(model_config)
         factories = self._processor_factories[model_cls]

@@ -33,7 +33,8 @@ class HashableList(list):


 def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs):
-    base_kwargs = model_config.mm_processor_kwargs
+    mm_config = model_config.get_multimodal_config()
+    base_kwargs = mm_config.mm_processor_kwargs
     if base_kwargs is None:
         base_kwargs = {}

@@ -33,7 +33,10 @@ from vllm.utils import is_list_of
 class MirroredProcessingCache:

     def __init__(self, model_config):
-        self.use_cache = not model_config.disable_mm_preprocessor_cache
+        mm_config = model_config.multimodal_config
+        disable_mm_preprocessor_cache = mm_config is not None and \
+            not mm_config.disable_mm_preprocessor_cache
+        self.use_cache = not disable_mm_preprocessor_cache
         self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
                                                       MultiModalKwargs)

@@ -51,8 +51,7 @@ class Processor:
         self.mm_input_cache_client = MirroredProcessingCache(self.model_config)

         # Multi-modal hasher (for images)
-        self.use_hash = (
-            not self.model_config.disable_mm_preprocessor_cache) or \
+        self.use_hash = self.mm_input_cache_client.use_cache or \
             self.cache_config.enable_prefix_caching

     def _validate_logprobs(