Fix CompilationConfig repr (#19091)
Signed-off-by: rzou <zou3519@gmail.com>
parent 65c69444b1
commit da511d54d8
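Background, as inferred from the diff below: `VllmBackend.configure_post_pass()` stores a compiler pass object in `compilation_config.inductor_compile_config["post_grad_custom_post_pass"]`. That object is not JSON-serializable, so `repr(config)`, which serializes the config via `TypeAdapter(CompilationConfig).dump_json(...)`, raised. The fix switches `exclude` from a set of field names to Pydantic's dict form, which can also skip a single nested key; a standalone sketch of that mechanism follows the diff.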
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -6,6 +6,7 @@ from typing import Literal, Union

 import pytest

+from vllm.compilation.backends import VllmBackend
 from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
                          config, get_field)
 from vllm.model_executor.layers.pooler import PoolingType
@@ -44,6 +45,18 @@ def test_config(test_config, expected_error):
         config(test_config)


+def test_compile_config_repr_succeeds():
+    # setup: VllmBackend mutates the config object
+    config = VllmConfig()
+    backend = VllmBackend(config)
+    backend.configure_post_pass()
+
+    # test that repr(config) succeeds
+    val = repr(config)
+    assert 'VllmConfig' in val
+    assert 'inductor_passes' in val
+
+
 def test_get_field():

     @dataclass
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -648,7 +648,7 @@ class ModelConfig:
     def maybe_pull_model_tokenizer_for_s3(self, model: str,
                                           tokenizer: str) -> None:
         """Pull model/tokenizer from S3 to temporary directory when needed.

         Args:
             model: Model name or path
             tokenizer: Tokenizer name or path
@@ -1370,9 +1370,9 @@ class ModelConfig:
     def is_encoder_decoder(self) -> bool:
         """Extract the HF encoder/decoder model flag."""
         """
         For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to
         True to enable cross-attention
         Neuron needs all multimodal data to be in the decoder and does not
         need to explicitly enable cross-attention
         """
         if (current_platform.is_neuron()
@@ -1794,7 +1794,7 @@ class ParallelConfig:
     """Global rank in distributed setup."""

     enable_multimodal_encoder_data_parallel: bool = False
     """ Use data parallelism instead of tensor parallelism for vision encoder.
     Only support LLama4 for now"""

     @property
@@ -2272,9 +2272,9 @@ class DeviceConfig:

     device: SkipValidation[Union[Device, torch.device]] = "auto"
     """Device type for vLLM execution.
     This parameter is deprecated and will be
     removed in a future release.
     It will now be set automatically based
     on the current platform."""
     device_type: str = field(init=False)
     """Device type from the current platform. This is set in
@@ -4007,19 +4007,24 @@ class CompilationConfig:

     def __repr__(self) -> str:
         exclude = {
-            "static_forward_context",
-            "enabled_custom_ops",
-            "disabled_custom_ops",
-            "compilation_time",
-            "bs_to_padded_graph_size",
-            "pass_config",
-            "traced_files",
+            "static_forward_context": True,
+            "enabled_custom_ops": True,
+            "disabled_custom_ops": True,
+            "compilation_time": True,
+            "bs_to_padded_graph_size": True,
+            "pass_config": True,
+            "traced_files": True,
+            "inductor_compile_config": {
+                "post_grad_custom_post_pass": True,
+            },
         }
         # The cast to string is necessary because Pydantic is mocked in docs
         # builds and sphinx-argparse doesn't know the return type of decode()
         return str(
             TypeAdapter(CompilationConfig).dump_json(
-                self, exclude=exclude, exclude_unset=True).decode())
+                self,
+                exclude=exclude,  # type: ignore[arg-type]
+                exclude_unset=True).decode())

     __str__ = __repr__
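A minimal standalone sketch of the dict-form `exclude` used above (the `Cfg` model and its field values here are illustrative, not vLLM code): Pydantic can exclude a single nested dictionary key during serialization, so the non-serializable pass object is skipped while the rest of `inductor_compile_config` still appears in the repr.

from typing import Any

from pydantic import BaseModel, TypeAdapter


class Cfg(BaseModel):
    # Hypothetical stand-in for CompilationConfig.inductor_compile_config.
    inductor_compile_config: dict[str, Any] = {}


cfg = Cfg()
# Simulate VllmBackend.configure_post_pass(): stash a non-JSON-serializable
# object (here a plain function) under the offending key.
cfg.inductor_compile_config["post_grad_custom_post_pass"] = lambda graph: graph
cfg.inductor_compile_config["max_autotune"] = True

# A set-form exclude could only drop inductor_compile_config wholesale; the
# dict form reaches the single bad key and keeps everything else.
exclude = {"inductor_compile_config": {"post_grad_custom_post_pass": True}}
print(TypeAdapter(Cfg).dump_json(cfg, exclude=exclude).decode())
# Expected output (roughly): {"inductor_compile_config":{"max_autotune":true}}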