Remove default values from InitVars so that they're not stored (#29859)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Parent: d8c6210eea
Commit: 951445a52d
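What the change does, in miniature: `SchedulerConfig`'s two `InitVar`s lose their defaults, so every constructor call must now pass `max_model_len` and `is_encoder_decoder` explicitly (normally taken from `ModelConfig`), and a new `SchedulerConfig.default_factory(**kwargs)` helper supplies defaults where no model is at hand. The sketch below uses plain `dataclasses` rather than vLLM's pydantic-based config; the `__post_init__` body is illustrative only, and every name other than the two `InitVar`s is an assumption, not vLLM's actual implementation.

from dataclasses import InitVar, dataclass


@dataclass
class SchedulerConfigSketch:
    # Init-only inputs: passed to __post_init__, never stored on the instance.
    max_model_len: InitVar[int]
    is_encoder_decoder: InitVar[bool]

    max_num_batched_tokens: int = 2048
    max_num_seqs: int = 128
    enable_chunked_prefill: bool = True

    def __post_init__(self, max_model_len: int, is_encoder_decoder: bool) -> None:
        # Illustrative: use the init-only values to adjust stored attributes.
        if is_encoder_decoder:
            self.enable_chunked_prefill = False
        self.max_num_batched_tokens = max(self.max_num_batched_tokens, max_model_len)

    @staticmethod
    def default_factory(**kwargs):
        """Fill in defaults for the init-only parameters (handy in tests)."""
        kwargs.setdefault("max_model_len", 8192)
        kwargs.setdefault("is_encoder_decoder", False)
        return SchedulerConfigSketch(**kwargs)


# Callers now pass the model-derived values explicitly...
cfg = SchedulerConfigSketch(max_model_len=4096, is_encoder_decoder=False)
# ...and the init-only inputs do not become attributes:
assert not hasattr(cfg, "max_model_len")

# Tests that do not need a real model use the factory instead:
test_cfg = SchedulerConfigSketch.default_factory(max_num_seqs=32)

Keeping the InitVars default-free means they are consumed during construction and never duplicated on the instance, so ModelConfig remains the single source of truth for these values — which is what the new docstrings further down in this diff state.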
@@ -108,7 +108,10 @@ def benchmark_batched_propose(args):
         device_config=DeviceConfig(device=current_platform.device_type),
         parallel_config=ParallelConfig(),
         load_config=LoadConfig(),
-        scheduler_config=SchedulerConfig(),
+        scheduler_config=SchedulerConfig(
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
+        ),
     )
 
     # monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group
@@ -318,13 +318,18 @@ def test_attention_quant_pattern(
     torch.set_default_dtype(dtype)
     torch.manual_seed(42)
 
+    model_config = ModelConfig(
+        model=model_name,
+        max_model_len=2048,
+        dtype=dtype,
+    )
     vllm_config = VllmConfig(
-        model_config=ModelConfig(
-            model=model_name,
-            max_model_len=2048,
-            dtype=dtype,
+        model_config=model_config,
+        scheduler_config=SchedulerConfig(
+            max_num_seqs=1024,
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
         ),
-        scheduler_config=SchedulerConfig(max_num_seqs=1024),
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             custom_ops=custom_ops_list,
@@ -33,14 +33,16 @@ def test_worker_apply_lora(qwen3_lora_files):
         lora_requests, lora_mapping
     )
 
+    model_config = ModelConfig(
+        MODEL_PATH,
+        seed=0,
+        dtype="float16",
+        max_model_len=127,
+        enforce_eager=True,
+    )
+
     vllm_config = VllmConfig(
-        model_config=ModelConfig(
-            MODEL_PATH,
-            seed=0,
-            dtype="float16",
-            max_model_len=127,
-            enforce_eager=True,
-        ),
+        model_config=model_config,
         load_config=LoadConfig(
             download_dir=None,
             load_format="dummy",
@@ -50,7 +52,14 @@ def test_worker_apply_lora(qwen3_lora_files):
             tensor_parallel_size=1,
             data_parallel_size=1,
         ),
-        scheduler_config=SchedulerConfig("generate", 32, 32, 32),
+        scheduler_config=SchedulerConfig(
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
+            runner_type="generate",
+            max_num_batched_tokens=32,
+            max_num_seqs=32,
+            max_num_partial_prefills=32,
+        ),
         device_config=DeviceConfig("cuda"),
         cache_config=CacheConfig(
             block_size=16,
@@ -6,12 +6,14 @@ from dataclasses import MISSING, Field, asdict, dataclass, field
 from unittest.mock import patch
 
 import pytest
+from pydantic import ValidationError
 
 from vllm.compilation.backends import VllmBackend
 from vllm.config import (
     CompilationConfig,
     ModelConfig,
     PoolerConfig,
+    SchedulerConfig,
     VllmConfig,
     update_config,
 )
@@ -1095,3 +1097,14 @@ def test_vllm_config_explicit_overrides():
     # Other fields should still use defaults
     assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
+
+
+def test_scheduler_config_init():
+    with pytest.raises(ValidationError):
+        # Positional InitVars missing
+        # (InitVars cannot have defaults otherwise they will become attributes)
+        SchedulerConfig()
+
+    with pytest.raises(AttributeError):
+        # InitVar does not become an attribute
+        print(SchedulerConfig.default_factory().max_model_len)
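The test comment above ("InitVars cannot have defaults otherwise they will become attributes") is the motivation for the whole diff. A minimal demonstration with plain `dataclasses` — vLLM's configs are pydantic dataclasses, so this is only an analogy for the effect the test guards against — assuming nothing beyond the standard library:

from dataclasses import InitVar, dataclass


@dataclass
class WithDefault:
    limit: InitVar[int] = 8192  # the default survives as a class attribute


@dataclass
class WithoutDefault:
    limit: InitVar[int]  # annotation only; nothing is stored anywhere


print(WithDefault().limit)  # 8192 -- resolved through the class attribute

try:
    print(WithoutDefault(4).limit)
except AttributeError:
    print("limit was consumed at init time and not stored")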
@@ -185,6 +185,8 @@ def create_vllm_config(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
         enable_chunked_prefill=enable_chunked_prefill,
+        max_model_len=model_config.max_model_len,
+        is_encoder_decoder=model_config.is_encoder_decoder,
     )
 
     device_config = DeviceConfig()
@@ -1128,7 +1128,11 @@ def test_estimate_max_model_len(model_id, max_model_len, want_estimated_max_len)
         dtype="float16",
         max_model_len=max_model_len,
     )
-    scheduler_config = SchedulerConfig(max_num_batched_tokens=32768)
+    scheduler_config = SchedulerConfig(
+        max_num_batched_tokens=32768,
+        max_model_len=model_config.max_model_len,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
 
     vllm_config = VllmConfig(
         model_config=model_config,
@@ -1163,7 +1167,10 @@ def test_get_max_concurrency_for_kv_cache_config():
         max_model_len=max_model_len,
     )
     scheduler_config = SchedulerConfig(
-        max_num_batched_tokens=1024, enable_chunked_prefill=True
+        max_num_batched_tokens=1024,
+        enable_chunked_prefill=True,
+        max_model_len=model_config.max_model_len,
+        is_encoder_decoder=model_config.is_encoder_decoder,
     )
 
     vllm_config = VllmConfig(
@@ -1508,6 +1508,12 @@ def create_scheduler_with_priority(
     Returns:
         {class}`Scheduler` instance with priority scheduling
     """
+    model_config = ModelConfig(
+        model=model,
+        trust_remote_code=True,
+        dtype="float16",
+        seed=42,
+    )
     if max_model_len is None:
         max_model_len = max_num_batched_tokens
     scheduler_config = SchedulerConfig(
@@ -1517,14 +1523,9 @@ def create_scheduler_with_priority(
         long_prefill_token_threshold=long_prefill_token_threshold,
         disable_chunked_mm_input=disable_chunked_mm_input,
         enable_chunked_prefill=True,
+        is_encoder_decoder=model_config.is_encoder_decoder,
         policy="priority",  # Enable priority scheduling
     )
-    model_config = ModelConfig(
-        model=model,
-        trust_remote_code=True,
-        dtype="float16",
-        seed=42,
-    )
     # Cache config, optionally force APC
     cache_config = CacheConfig(
         block_size=block_size,
@@ -69,6 +69,13 @@ def create_scheduler(
     Returns:
         {class}`Scheduler` instance
     """
+    model_config = ModelConfig(
+        model=model,
+        trust_remote_code=True,
+        dtype="float16",
+        seed=42,
+        skip_tokenizer_init=skip_tokenizer_init,
+    )
     if max_model_len is None:
         max_model_len = max_num_batched_tokens
     scheduler_config = SchedulerConfig(
@@ -79,13 +86,7 @@ def create_scheduler(
         disable_chunked_mm_input=disable_chunked_mm_input,
         enable_chunked_prefill=enable_chunked_prefill,
         async_scheduling=async_scheduling,
-    )
-    model_config = ModelConfig(
-        model=model,
-        trust_remote_code=True,
-        dtype="float16",
-        seed=42,
-        skip_tokenizer_init=skip_tokenizer_init,
+        is_encoder_decoder=model_config.is_encoder_decoder,
     )
     # Cache config, optionally force APC
     cache_config = CacheConfig(
@@ -40,7 +40,9 @@ def _create_vllm_config(
 ) -> MagicMock:
     mock_config = MagicMock(spec=VllmConfig)
     mock_config.compilation_config = compilation_config
-    mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
+    mock_config.scheduler_config = SchedulerConfig.default_factory(
+        max_num_seqs=max_num_seqs,
+    )
     mock_config.parallel_config = ParallelConfig()
     mock_config.speculative_config = None  # No speculative decoding
     if not lora_config:
@@ -484,12 +484,6 @@ def test_encoder_instance_zero_kv_cache(
     vision encoder, so they don't need KV cache for text generation.
     """
     # Form vllm config
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=10,
-        max_num_batched_tokens=512,
-        max_model_len=512,
-        disable_hybrid_kv_cache_manager=True,
-    )
     model_config = ModelConfig(
         model="llava-hf/llava-1.5-7b-hf",  # Multimodal model
         enforce_eager=True,
@@ -497,6 +491,13 @@ def test_encoder_instance_zero_kv_cache(
         dtype="float16",
         seed=42,
     )
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=10,
+        max_num_batched_tokens=512,
+        max_model_len=512,
+        disable_hybrid_kv_cache_manager=True,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
     cache_config = CacheConfig(
         block_size=16,
         gpu_memory_utilization=gpu_memory_utilization,
@@ -92,18 +92,19 @@ def create_vllm_config(
     enable_permute_local_kv: bool = False,
 ) -> VllmConfig:
     """Initialize VllmConfig For Testing."""
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=max_num_seqs,
-        max_num_batched_tokens=max_num_batched_tokens,
-        max_model_len=max_model_len,
-        enable_chunked_prefill=enable_chunked_prefill,
-    )
     model_config = ModelConfig(
         model=model,
         trust_remote_code=True,
         dtype="float16",
         seed=42,
     )
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=max_num_seqs,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_model_len=max_model_len,
+        enable_chunked_prefill=enable_chunked_prefill,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
     # Cache config, optionally force APC
     cache_config = CacheConfig(
         block_size=block_size,
@@ -66,7 +66,10 @@ def _create_proposer(
         device_config=DeviceConfig(device=current_platform.device_type),
         parallel_config=ParallelConfig(),
         load_config=LoadConfig(),
-        scheduler_config=SchedulerConfig(),
+        scheduler_config=SchedulerConfig(
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
+        ),
     )
 
     return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
@@ -51,7 +51,10 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
         device_config=DeviceConfig(device=current_platform.device_type),
         parallel_config=ParallelConfig(),
         load_config=LoadConfig(),
-        scheduler_config=SchedulerConfig(),
+        scheduler_config=SchedulerConfig(
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
+        ),
     )
 
     return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
@@ -26,16 +26,17 @@ from vllm.v1.worker.tpu_model_runner import (
 
 
 def get_vllm_config():
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=10,
-        max_num_batched_tokens=512,
-        max_model_len=512,
-    )
     model_config = ModelConfig(
         model="facebook/opt-125m",
         dtype="bfloat16",  # TPUs typically use bfloat16
         seed=42,
     )
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=10,
+        max_num_batched_tokens=512,
+        max_model_len=512,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
     cache_config = CacheConfig(
         block_size=16,
         gpu_memory_utilization=0.9,
@@ -79,16 +79,17 @@ def initialize_kv_cache(runner: GPUModelRunner):
 
 
 def get_vllm_config():
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=10,
-        max_num_batched_tokens=512,
-        max_model_len=512,
-    )
     model_config = ModelConfig(
         model="facebook/opt-125m",
         dtype="float16",
         seed=42,
     )
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=10,
+        max_num_batched_tokens=512,
+        max_model_len=512,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
     cache_config = CacheConfig(
         block_size=BLOCK_SIZE,
         gpu_memory_utilization=0.9,
@@ -784,14 +785,15 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
     initialize_model_parallel(tensor_model_parallel_size=1)
     torch.set_default_dtype(torch.float16)
 
+    model_config = ModelConfig(
+        model="ibm-granite/granite-4.0-tiny-preview",
+        dtype="float16",
+    )
     scheduler_config = SchedulerConfig(
         max_num_seqs=10,
         max_num_batched_tokens=512,
         max_model_len=512,
-    )
-    model_config = ModelConfig(
-        model="ibm-granite/granite-4.0-tiny-preview",
-        dtype="float16",
+        is_encoder_decoder=model_config.is_encoder_decoder,
     )
     cache_config = CacheConfig(
         block_size=BLOCK_SIZE,
@@ -28,6 +28,19 @@ SchedulerPolicy = Literal["fcfs", "priority"]
 class SchedulerConfig:
     """Scheduler configuration."""
 
+    max_model_len: InitVar[int]
+    """Maximum length of a sequence (including prompt and generated text).
+
+    Note: This is stored in the ModelConfig, and is used only here to
+    provide fallbacks and validate other attributes."""
+
+    is_encoder_decoder: InitVar[bool]
+    """True if the model is an encoder-decoder model.
+
+    Note: This is stored in the ModelConfig, and is used only here to
+    disable chunked prefill and prefix caching for encoder-decoder models.
+    """
+
     DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
     DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128
 
@@ -73,19 +86,6 @@ class SchedulerConfig:
     is_multimodal_model: bool = False
     """True if the model is multimodal."""
 
-    max_model_len: InitVar[int] = 8192
-    """Maximum length of a sequence (including prompt and generated text).
-
-    Note: This is stored in the ModelConfig, and is used only here to
-    provide fallbacks and validate other attributes."""
-
-    is_encoder_decoder: InitVar[bool] = False
-    """True if the model is an encoder-decoder model.
-
-    Note: This is stored in the ModelConfig, and is used only here to
-    disable chunked prefill and prefix caching for encoder-decoder models.
-    """
-
     # TODO (ywang96): Make this configurable.
     max_num_encoder_input_tokens: int = Field(init=False)
     """Multimodal encoder compute budget, only used in V1.
@@ -141,6 +141,17 @@ class SchedulerConfig:
     while a larger value (e.g., 10) reduces host overhead and may increase throughput
     by batching multiple tokens before sending."""
 
+    @staticmethod
+    def default_factory(**kwargs):
+        """
+        Factory method to create `SchedulerConfig` with default values for `InitVar`s.
+        """
+        if "max_model_len" not in kwargs:
+            kwargs["max_model_len"] = 8192
+        if "is_encoder_decoder" not in kwargs:
+            kwargs["is_encoder_decoder"] = False
+        return SchedulerConfig(**kwargs)
+
     def get_scheduler_cls(self) -> type["SchedulerInterface"]:
         if self.scheduler_cls is None:
             if self.async_scheduling:
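With `default_factory` in place, construction follows two paths, both visible in the hunks above; a brief recap (model name and sizes are illustrative, taken from the test hunks in this diff):

from vllm.config import ModelConfig, SchedulerConfig

# Callers that have a ModelConfig thread the init-only values through explicitly:
model_config = ModelConfig(model="facebook/opt-125m", dtype="float16", seed=42)
scheduler_config = SchedulerConfig(
    max_num_seqs=10,
    max_num_batched_tokens=512,
    max_model_len=model_config.max_model_len,
    is_encoder_decoder=model_config.is_encoder_decoder,
)

# Callers with no model at hand (mocks, VllmConfig's field default) use the
# factory, which fills in max_model_len=8192 and is_encoder_decoder=False:
mock_scheduler_config = SchedulerConfig.default_factory(max_num_seqs=128)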
@@ -284,8 +295,3 @@ class SchedulerConfig:
             )
 
         return self
-
-    def __getattribute__(self, name: str) -> Any:
-        if name == "max_model_len" or name == "is_encoder_decoder":
-            raise AttributeError(f"{name} is an init-only parameter. ")
-        return object.__getattribute__(self, name)
@@ -170,7 +170,9 @@ class VllmConfig:
     """Cache configuration."""
     parallel_config: ParallelConfig = Field(default_factory=ParallelConfig)
     """Parallel configuration."""
-    scheduler_config: SchedulerConfig = Field(default_factory=SchedulerConfig)
+    scheduler_config: SchedulerConfig = Field(
+        default_factory=SchedulerConfig.default_factory,
+    )
     """Scheduler configuration."""
     device_config: DeviceConfig = Field(default_factory=DeviceConfig)
     """Device configuration."""