diff --git a/tests/test_config.py b/tests/test_config.py
index dffea9138222..ce383e1b420a 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -6,6 +6,7 @@ from typing import Literal, Union
 
 import pytest
 
+from vllm.compilation.backends import VllmBackend
 from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
                          config, get_field)
 from vllm.model_executor.layers.pooler import PoolingType
@@ -44,6 +45,18 @@ def test_config(test_config, expected_error):
             config(test_config)
 
 
+def test_compile_config_repr_succeeds():
+    # setup: VllmBackend mutates the config object
+    config = VllmConfig()
+    backend = VllmBackend(config)
+    backend.configure_post_pass()
+
+    # test that repr(config) succeeds
+    val = repr(config)
+    assert 'VllmConfig' in val
+    assert 'inductor_passes' in val
+
+
 def test_get_field():
 
     @dataclass
diff --git a/vllm/config.py b/vllm/config.py
index cd6ac4f89890..31a1d208eaa7 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -648,7 +648,7 @@ class ModelConfig:
     def maybe_pull_model_tokenizer_for_s3(self, model: str,
                                           tokenizer: str) -> None:
         """Pull model/tokenizer from S3 to temporary directory when needed.
-        
+
         Args:
             model: Model name or path
             tokenizer: Tokenizer name or path
@@ -1370,9 +1370,9 @@ class ModelConfig:
     def is_encoder_decoder(self) -> bool:
         """Extract the HF encoder/decoder model flag."""
         """
-        For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to 
+        For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to
         True to enable cross-attention
-        Neuron needs all multimodal data to be in the decoder and does not 
+        Neuron needs all multimodal data to be in the decoder and does not
         need to explicitly enable cross-attention
         """
         if (current_platform.is_neuron()
@@ -1794,7 +1794,7 @@ class ParallelConfig:
     """Global rank in distributed setup."""
 
     enable_multimodal_encoder_data_parallel: bool = False
-    """ Use data parallelism instead of tensor parallelism for vision encoder. 
+    """ Use data parallelism instead of tensor parallelism for vision encoder.
     Only support LLama4 for now"""
 
     @property
@@ -2272,9 +2272,9 @@ class DeviceConfig:
 
     device: SkipValidation[Union[Device, torch.device]] = "auto"
     """Device type for vLLM execution.
-        This parameter is deprecated and will be 
-        removed in a future release. 
-        It will now be set automatically based 
+        This parameter is deprecated and will be
+        removed in a future release.
+        It will now be set automatically based
         on the current platform."""
     device_type: str = field(init=False)
     """Device type from the current platform. This is set in
@@ -4007,19 +4007,24 @@ class CompilationConfig:
 
     def __repr__(self) -> str:
         exclude = {
-            "static_forward_context",
-            "enabled_custom_ops",
-            "disabled_custom_ops",
-            "compilation_time",
-            "bs_to_padded_graph_size",
-            "pass_config",
-            "traced_files",
+            "static_forward_context": True,
+            "enabled_custom_ops": True,
+            "disabled_custom_ops": True,
+            "compilation_time": True,
+            "bs_to_padded_graph_size": True,
+            "pass_config": True,
+            "traced_files": True,
+            "inductor_compile_config": {
+                "post_grad_custom_post_pass": True,
+            },
         }
         # The cast to string is necessary because Pydantic is mocked in docs
        # builds and sphinx-argparse doesn't know the return type of decode()
         return str(
             TypeAdapter(CompilationConfig).dump_json(
-                self, exclude=exclude, exclude_unset=True).decode())
+                self,
+                exclude=exclude,  # type: ignore[arg-type]
+                exclude_unset=True).decode())
 
     __str__ = __repr__
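
For context on why `exclude` changes from a set to a dict: the new test exercises `VllmBackend.configure_post_pass()`, which (as the new exclusion suggests) leaves a pass object that JSON cannot represent under `inductor_compile_config["post_grad_custom_post_pass"]`, so `repr(config)` previously raised during serialization. Pydantic's `exclude` also accepts a nested mapping, which lets the fix drop just that one key while keeping the rest of `inductor_compile_config` in the output. Below is a minimal, self-contained sketch of that Pydantic mechanism; `Cfg` and its fields are illustrative stand-ins that merely mirror the names in the diff, not vLLM's real classes:

```python
# Sketch of Pydantic's nested `exclude` (assumed/illustrative names throughout).
from dataclasses import dataclass, field
from typing import Any

from pydantic import TypeAdapter


@dataclass
class Cfg:
    level: int = 0
    # Mirrors CompilationConfig.inductor_compile_config: arbitrary values,
    # including callables that have no JSON representation.
    inductor_compile_config: dict[str, Any] = field(default_factory=dict)


cfg = Cfg(inductor_compile_config={"post_grad_custom_post_pass": print,
                                   "max_autotune": True})
adapter = TypeAdapter(Cfg)

# Without an exclusion this raises PydanticSerializationError, because
# `print` (a callable) cannot be serialized to JSON:
# adapter.dump_json(cfg)

# A dict-shaped exclude can reach inside the field and drop only the
# offending key; a set-shaped exclude could only drop the whole field.
exclude = {"inductor_compile_config": {"post_grad_custom_post_pass": True}}
print(adapter.dump_json(cfg, exclude=exclude).decode())
# -> {"level":0,"inductor_compile_config":{"max_autotune":true}}
```

The `# type: ignore[arg-type]` in the diff is presumably needed because the declared type of `exclude` in Pydantic's stubs is narrower than what the runtime accepts for this nested form.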