Fix CompilationConfig repr (#19091)

Signed-off-by: rzou <zou3519@gmail.com>
This commit is contained in:
Richard Zou 2025-06-06 04:23:35 -04:00 committed by GitHub
parent 65c69444b1
commit da511d54d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 15 deletions

View File

@ -6,6 +6,7 @@ from typing import Literal, Union
import pytest import pytest
from vllm.compilation.backends import VllmBackend
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig, from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
config, get_field) config, get_field)
from vllm.model_executor.layers.pooler import PoolingType from vllm.model_executor.layers.pooler import PoolingType
@ -44,6 +45,18 @@ def test_config(test_config, expected_error):
config(test_config) config(test_config)
def test_compile_config_repr_succeeds():
# setup: VllmBackend mutates the config object
config = VllmConfig()
backend = VllmBackend(config)
backend.configure_post_pass()
# test that repr(config) succeeds
val = repr(config)
assert 'VllmConfig' in val
assert 'inductor_passes' in val
def test_get_field(): def test_get_field():
@dataclass @dataclass

View File

@ -4007,19 +4007,24 @@ class CompilationConfig:
def __repr__(self) -> str: def __repr__(self) -> str:
exclude = { exclude = {
"static_forward_context", "static_forward_context": True,
"enabled_custom_ops", "enabled_custom_ops": True,
"disabled_custom_ops", "disabled_custom_ops": True,
"compilation_time", "compilation_time": True,
"bs_to_padded_graph_size", "bs_to_padded_graph_size": True,
"pass_config", "pass_config": True,
"traced_files", "traced_files": True,
"inductor_compile_config": {
"post_grad_custom_post_pass": True,
},
} }
# The cast to string is necessary because Pydantic is mocked in docs # The cast to string is necessary because Pydantic is mocked in docs
# builds and sphinx-argparse doesn't know the return type of decode() # builds and sphinx-argparse doesn't know the return type of decode()
return str( return str(
TypeAdapter(CompilationConfig).dump_json( TypeAdapter(CompilationConfig).dump_json(
self, exclude=exclude, exclude_unset=True).decode()) self,
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True).decode())
__str__ = __repr__ __str__ = __repr__