[Deprecation] Remove deprecated task, seed and MM settings (#30397)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 5a87d8b9b1
commit 7e24e5d4d6
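The commit removes three separate deprecations: the `--task` option (superseded by `--runner`/`--convert`), `seed=None` (the seed is now a plain integer defaulting to 0), and the old multimodal cache/encoder flags. A minimal migration sketch, assuming `LLM` forwards `runner`/`convert` to the engine arguments as it does for other options; the model name is a placeholder taken from the tests touched below:

    from vllm import LLM

    # Before: LLM(model=..., task="embed", seed=None)
    # After:  select runner/convert explicitly and pass an integer seed (or omit it).
    llm = LLM(
        model="intfloat/multilingual-e5-small",
        runner="pooling",
        convert="embed",
        seed=0,
    )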
@@ -37,7 +37,7 @@ def benchmark_propose(args):
         tokenizer="facebook/opt-125m",
         tokenizer_mode="auto",
         dtype="auto",
-        seed=None,
+        seed=0,
         trust_remote_code=False,
     )
     proposer = NgramProposer(
@@ -422,7 +422,7 @@ def parse_args():
     parser.add_argument(
         "--seed",
         type=int,
-        default=None,
+        default=0,
         help="Set the seed when initializing `vllm.LLM`.",
     )
     parser.add_argument(
@@ -77,7 +77,7 @@ def parse_args():
     parser.add_argument(
         "--seed",
         type=int,
-        default=None,
+        default=0,
         help="Set the seed when initializing `vllm.LLM`.",
     )
     return parser.parse_args()
@@ -158,7 +158,7 @@ def parse_args():
     parser.add_argument(
         "--seed",
         type=int,
-        default=None,
+        default=0,
         help="Set the seed when initializing `vllm.LLM`.",
     )
@@ -158,7 +158,7 @@ def parse_args():
     parser.add_argument(
         "--seed",
         type=int,
-        default=None,
+        default=0,
         help="Set the seed when initializing `vllm.LLM`.",
     )
@@ -2031,7 +2031,7 @@ def parse_args():
     parser.add_argument(
         "--seed",
         type=int,
-        default=None,
+        default=0,
         help="Set the seed when initializing `vllm.LLM`.",
     )
@@ -1382,7 +1382,7 @@ def run_generate(
     model,
     question: str,
     image_urls: list[str],
-    seed: int | None,
+    seed: int,
     tensor_parallel_size: int | None,
 ):
     req_data = model_example_map[model](question, image_urls)
@@ -1416,7 +1416,7 @@ def run_chat(
     model: str,
     question: str,
     image_urls: list[str],
-    seed: int | None,
+    seed: int,
     tensor_parallel_size: int | None,
 ):
     req_data = model_example_map[model](question, image_urls)
@@ -1494,7 +1494,7 @@ def parse_args():
     parser.add_argument(
         "--seed",
         type=int,
-        default=None,
+        default=0,
         help="Set the seed when initializing `vllm.LLM`.",
     )
     parser.add_argument(
@@ -16,7 +16,7 @@ import requests
 # - start vllm in serving mode with the below args
 # --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
 # --model-impl terratorch
-# --task embed --trust-remote-code
+# --trust-remote-code
 # --skip-tokenizer-init --enforce-eager
 # --io-processor-plugin terratorch_segmentation
 # --enable-mm-embeds
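With `--task embed` dropped from the comment above, the serving flags for this example reduce to the remaining options. As a rough illustration only (the exact invocation depends on your deployment), the same flag set expressed as a Python argument list:

    # Flags copied from the updated comment; --task embed is no longer needed
    # because pooling support is selected automatically.
    serve_args = [
        "--model=christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
        "--model-impl", "terratorch",
        "--trust-remote-code",
        "--skip-tokenizer-init",
        "--enforce-eager",
        "--io-processor-plugin", "terratorch_segmentation",
        "--enable-mm-embeds",
    ]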
@@ -305,7 +305,7 @@ def get_query(modality: QueryModality):
     raise ValueError(msg)


-def run_encode(model: str, modality: QueryModality, seed: int | None):
+def run_encode(model: str, modality: QueryModality, seed: int):
     query = get_query(modality)
     req_data = model_example_map[model](query)
@@ -335,7 +335,7 @@ def run_encode(model: str, modality: QueryModality, seed: int | None):
     print("-" * 50)


-def run_score(model: str, modality: QueryModality, seed: int | None):
+def run_score(model: str, modality: QueryModality, seed: int):
     query = get_query(modality)
     req_data = model_example_map[model](query)
@@ -390,7 +390,7 @@ def parse_args():
     parser.add_argument(
         "--seed",
         type=int,
-        default=None,
+        default=0,
         help="Set the seed when initializing `vllm.LLM`.",
     )
     return parser.parse_args()
@@ -741,7 +741,7 @@ class VllmRunner:
         tokenizer_name: str | None = None,
         tokenizer_mode: str = "auto",
         trust_remote_code: bool = True,
-        seed: int | None = 0,
+        seed: int = 0,
         max_model_len: int | None = 1024,
         dtype: str = "auto",
         disable_log_stats: bool = True,
@@ -89,64 +89,6 @@ def test_update_config():
     new_config3 = update_config(config3, {"a": "new_value"})


-# Can remove once --task option is fully deprecated
-@pytest.mark.parametrize(
-    ("model_id", "expected_runner_type", "expected_convert_type", "expected_task"),
-    [
-        ("distilbert/distilgpt2", "generate", "none", "generate"),
-        ("intfloat/multilingual-e5-small", "pooling", "none", "embed"),
-        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
-        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", "classify"),
-        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "embed"),
-        ("openai/whisper-small", "generate", "none", "transcription"),
-    ],
-)
-def test_auto_task(
-    model_id, expected_runner_type, expected_convert_type, expected_task
-):
-    config = ModelConfig(model_id, task="auto")
-
-    assert config.runner_type == expected_runner_type
-    assert config.convert_type == expected_convert_type
-
-
-# Can remove once --task option is fully deprecated
-@pytest.mark.parametrize(
-    ("model_id", "expected_runner_type", "expected_convert_type", "expected_task"),
-    [
-        ("distilbert/distilgpt2", "pooling", "embed", "embed"),
-        ("intfloat/multilingual-e5-small", "pooling", "embed", "embed"),
-        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
-        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify", "classify"),
-        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed", "embed"),
-        ("openai/whisper-small", "pooling", "embed", "embed"),
-    ],
-)
-def test_score_task(
-    model_id, expected_runner_type, expected_convert_type, expected_task
-):
-    config = ModelConfig(model_id, task="score")
-
-    assert config.runner_type == expected_runner_type
-    assert config.convert_type == expected_convert_type
-
-
-# Can remove once --task option is fully deprecated
-@pytest.mark.parametrize(
-    ("model_id", "expected_runner_type", "expected_convert_type", "expected_task"),
-    [
-        ("openai/whisper-small", "generate", "none", "transcription"),
-    ],
-)
-def test_transcription_task(
-    model_id, expected_runner_type, expected_convert_type, expected_task
-):
-    config = ModelConfig(model_id, task="transcription")
-
-    assert config.runner_type == expected_runner_type
-    assert config.convert_type == expected_convert_type
-
-
 @pytest.mark.parametrize(
     ("model_id", "expected_runner_type", "expected_convert_type"),
     [
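The deleted tests exercised `task="auto"`, `task="score"`, and `task="transcription"`; the surviving parametrization keys off runner/convert only. A minimal sketch of the equivalent check, assuming `ModelConfig` accepts `runner=` and `convert=` keyword arguments mirroring the fields shown later in this diff:

    # Hypothetical stand-in for the removed task="score" case on an embedding model.
    config = ModelConfig("intfloat/multilingual-e5-small", runner="pooling", convert="embed")
    assert config.runner_type == "pooling"
    assert config.convert_type == "embed"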
@@ -119,7 +119,7 @@ class RemoteOpenAIServer:
         vllm_serve_args: list[str],
         *,
         env_dict: dict[str, str] | None = None,
-        seed: int | None = 0,
+        seed: int = 0,
         auto_port: bool = True,
         max_wait_seconds: float | None = None,
         override_hf_configs: dict[str, Any] | None = None,
@@ -283,7 +283,7 @@ class RemoteOpenAIServerCustom(RemoteOpenAIServer):
         child_process_fxn: Callable[[dict[str, str] | None, str, list[str]], None],
         *,
         env_dict: dict[str, str] | None = None,
-        seed: int | None = 0,
+        seed: int = 0,
         auto_port: bool = True,
         max_wait_seconds: float | None = None,
     ) -> None:
@@ -73,17 +73,6 @@ logger = init_logger(__name__)
 RunnerOption = Literal["auto", RunnerType]
 ConvertType = Literal["none", "embed", "classify", "reward"]
 ConvertOption = Literal["auto", ConvertType]
-TaskOption = Literal[
-    "auto",
-    "generate",
-    "embedding",
-    "embed",
-    "classify",
-    "score",
-    "reward",
-    "transcription",
-    "draft",
-]
 TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
 ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
 LogprobsMode = Literal[
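Code that imported `TaskOption` from this module needs to move to the remaining literals; a minimal sketch, assuming `ConvertOption` is exported from the same module as `RunnerOption`:

    from vllm.config.model import ConvertOption, RunnerOption

    def select(runner: RunnerOption = "auto", convert: ConvertOption = "auto") -> tuple[str, str]:
        # Illustrative only: intent is now expressed via runner/convert
        # instead of the removed TaskOption literal.
        return runner, convert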
@@ -93,12 +82,6 @@ HfOverrides = dict[str, Any] | Callable[[PretrainedConfig], PretrainedConfig]
 ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"]
 LayerBlockType = Literal["attention", "linear_attention", "mamba"]

-_RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = {
-    "generate": ["generate", "transcription"],
-    "pooling": ["embedding", "embed", "classify", "score", "reward"],
-    "draft": ["draft"],
-}
-
 _RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
     "generate": [],
     "pooling": ["embed", "classify", "reward"],
@@ -126,12 +109,6 @@ class ModelConfig:
     """Convert the model using adapters defined in
     [vllm.model_executor.models.adapters][]. The most common use case is to
     adapt a text generation model to be used for pooling tasks."""
-    task: TaskOption | None = None
-    """[DEPRECATED] The task to use the model for. If the model supports more
-    than one model runner, this is used to select which model runner to run.
-
-    Note that the model may support other tasks using the same model runner.
-    """
     tokenizer: SkipValidation[str] = None  # type: ignore
     """Name or path of the Hugging Face tokenizer to use. If unspecified, model
     name or path will be used."""
@@ -335,7 +312,6 @@ class ModelConfig:
         ignored_factors = {
             "runner",
             "convert",
-            "task",
             "tokenizer",
             "tokenizer_mode",
             "seed",
@@ -510,97 +486,6 @@ class ModelConfig:
         is_generative_model = registry.is_text_generation_model(architectures, self)
         is_pooling_model = registry.is_pooling_model(architectures, self)

-        def _task_to_convert(task: TaskOption) -> ConvertType:
-            if task == "embedding" or task == "embed":
-                return "embed"
-            if task == "classify":
-                return "classify"
-            if task == "reward":
-                logger.warning(
-                    "Pooling models now default support all pooling; "
-                    "you can use it without any settings."
-                )
-                return "embed"
-            if task == "score":
-                new_task = self._get_default_pooling_task(architectures)
-                return "classify" if new_task == "classify" else "embed"
-
-            return "none"
-
-        if self.task is not None:
-            runner: RunnerOption = "auto"
-            convert: ConvertOption = "auto"
-            msg_prefix = (
-                "The 'task' option has been deprecated and will be "
-                "removed in v0.13.0 or v1.0, whichever comes first."
-            )
-            msg_hint = "Please remove this option."
-
-            is_generative_task = self.task in _RUNNER_TASKS["generate"]
-            is_pooling_task = self.task in _RUNNER_TASKS["pooling"]
-
-            if is_generative_model and is_pooling_model:
-                if is_generative_task:
-                    runner = "generate"
-                    convert = "auto"
-                    msg_hint = (
-                        "Please replace this option with `--runner "
-                        "generate` to continue using this model "
-                        "as a generative model."
-                    )
-                elif is_pooling_task:
-                    runner = "pooling"
-                    convert = "auto"
-                    msg_hint = (
-                        "Please replace this option with `--runner "
-                        "pooling` to continue using this model "
-                        "as a pooling model."
-                    )
-                else:  # task == "auto"
-                    pass
-            elif is_generative_model or is_pooling_model:
-                if is_generative_task:
-                    runner = "generate"
-                    convert = "auto"
-                    msg_hint = "Please remove this option"
-                elif is_pooling_task:
-                    runner = "pooling"
-                    convert = _task_to_convert(self.task)
-                    msg_hint = (
-                        "Please replace this option with `--convert "
-                        f"{convert}` to continue using this model "
-                        "as a pooling model."
-                    )
-                else:  # task == "auto"
-                    pass
-            else:
-                # Neither generative nor pooling model - try to convert if possible
-                if is_pooling_task:
-                    runner = "pooling"
-                    convert = _task_to_convert(self.task)
-                    msg_hint = (
-                        "Please replace this option with `--runner pooling "
-                        f"--convert {convert}` to continue using this model "
-                        "as a pooling model."
-                    )
-                else:
-                    debug_info = {
-                        "architectures": architectures,
-                        "is_generative_model": is_generative_model,
-                        "is_pooling_model": is_pooling_model,
-                    }
-                    raise AssertionError(
-                        "The model should be a generative or "
-                        "pooling model when task is set to "
-                        f"{self.task!r}. Found: {debug_info}"
-                    )
-
-            self.runner = runner
-            self.convert = convert
-
-            msg = f"{msg_prefix} {msg_hint}"
-            warnings.warn(msg, DeprecationWarning, stacklevel=2)
-
         self.runner_type = self._get_runner_type(architectures, self.runner)
         self.convert_type = self._get_convert_type(
             architectures, self.runner_type, self.convert
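The removed fallback encoded the task-to-runner/convert migration; collected here as a reference table derived from the deleted `_task_to_convert` helper and deprecation hints (not a public API, values are illustrative):

    # Old --task value -> replacement suggested by the removed deprecation hints.
    TASK_MIGRATION = {
        "generate": "--runner generate (or simply drop --task)",
        "transcription": "--runner generate (or simply drop --task)",
        "embedding": "--runner pooling --convert embed",
        "embed": "--runner pooling --convert embed",
        "classify": "--runner pooling --convert classify",
        "reward": "--runner pooling",  # pooling models support all pooling by default
        "score": "--runner pooling --convert classify (cross-encoders) or embed",
    }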
@@ -918,22 +803,6 @@ class ModelConfig:

         return convert_type

-    def _get_default_pooling_task(
-        self,
-        architectures: list[str],
-    ) -> Literal["embed", "classify", "reward"]:
-        if self.registry.is_cross_encoder_model(architectures, self):
-            return "classify"
-
-        for arch in architectures:
-            match = try_match_architecture_defaults(arch, runner_type="pooling")
-            if match:
-                _, (_, convert_type) = match
-                assert convert_type != "none"
-                return convert_type
-
-        return "embed"
-
     def _parse_quant_hf_config(self, hf_config: PretrainedConfig):
         quant_cfg = getattr(hf_config, "quantization_config", None)
         if quant_cfg is None:
@@ -71,7 +71,6 @@ from vllm.config.model import (
     LogprobsMode,
     ModelDType,
     RunnerOption,
-    TaskOption,
     TokenizerMode,
 )
 from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
@@ -360,7 +359,6 @@ class EngineArgs:
     hf_config_path: str | None = ModelConfig.hf_config_path
     runner: RunnerOption = ModelConfig.runner
     convert: ConvertOption = ModelConfig.convert
-    task: TaskOption | None = ModelConfig.task
     skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
     enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
     tokenizer_mode: TokenizerMode | str = ModelConfig.tokenizer_mode
@@ -373,7 +371,7 @@ class EngineArgs:
     config_format: str = ModelConfig.config_format
     dtype: ModelDType = ModelConfig.dtype
     kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
-    seed: int | None = 0
+    seed: int = ModelConfig.seed
     max_model_len: int | None = ModelConfig.max_model_len
     cudagraph_capture_sizes: list[int] | None = (
         CompilationConfig.cudagraph_capture_sizes
@@ -462,7 +460,6 @@ class EngineArgs:
         MultiModalConfig, "media_io_kwargs"
     )
     mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
-    disable_mm_preprocessor_cache: bool = False  # DEPRECATED
     mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
     mm_processor_cache_type: MMCacheType | None = (
         MultiModalConfig.mm_processor_cache_type
@@ -558,9 +555,6 @@ class EngineArgs:
     use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
     pt_load_map_location: str = LoadConfig.pt_load_map_location

-    # DEPRECATED
-    enable_multimodal_encoder_data_parallel: bool = False
-
     logits_processors: list[str | type[LogitsProcessor]] | None = (
         ModelConfig.logits_processors
     )
@@ -628,7 +622,6 @@ class EngineArgs:
         model_group.add_argument("--model", **model_kwargs["model"])
         model_group.add_argument("--runner", **model_kwargs["runner"])
         model_group.add_argument("--convert", **model_kwargs["convert"])
-        model_group.add_argument("--task", **model_kwargs["task"], deprecated=True)
         model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
         model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
         model_group.add_argument(
@@ -882,11 +875,6 @@ class EngineArgs:
         parallel_group.add_argument(
             "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
         )
-        parallel_group.add_argument(
-            "--enable-multimodal-encoder-data-parallel",
-            action="store_true",
-            deprecated=True,
-        )

         # KV cache arguments
         cache_kwargs = get_kwargs(CacheConfig)
@@ -960,9 +948,6 @@ class EngineArgs:
         multimodal_group.add_argument(
             "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
         )
-        multimodal_group.add_argument(
-            "--disable-mm-preprocessor-cache", action="store_true", deprecated=True
-        )
         multimodal_group.add_argument(
             "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
         )
@@ -1192,62 +1177,20 @@ class EngineArgs:
         if is_gguf(self.model):
             self.quantization = self.load_format = "gguf"

-        # NOTE(woosuk): In V1, we use separate processes for workers (unless
-        # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here
-        # doesn't affect the user process.
-        if self.seed is None:
-            logger.warning_once(
-                "`seed=None` is equivalent to `seed=0` in V1 Engine. "
-                "You will no longer be allowed to pass `None` in v0.13.",
-                scope="local",
-            )
-            self.seed = 0
-            if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
-                logger.warning(
-                    "The global random seed is set to %d. Since "
-                    "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
-                    "affect the random state of the Python process that "
-                    "launched vLLM.",
-                    self.seed,
-                )
+        if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
+            logger.warning(
+                "The global random seed is set to %d. Since "
+                "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
+                "affect the random state of the Python process that "
+                "launched vLLM.",
+                self.seed,
+            )

-        if self.disable_mm_preprocessor_cache:
-            logger.warning_once(
-                "`--disable-mm-preprocessor-cache` is deprecated "
-                "and will be removed in v0.13. "
-                "Please use `--mm-processor-cache-gb 0` instead.",
-                scope="local",
-            )
-
-            self.mm_processor_cache_gb = 0
-        elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
-            logger.warning_once(
-                "VLLM_MM_INPUT_CACHE_GIB` is deprecated "
-                "and will be removed in v0.13. "
-                "Please use `--mm-processor-cache-gb %d` instead.",
-                envs.VLLM_MM_INPUT_CACHE_GIB,
-                scope="local",
-            )
-
-            self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB
-
-        if self.enable_multimodal_encoder_data_parallel:
-            logger.warning_once(
-                "--enable-multimodal-encoder-data-parallel` is deprecated "
-                "and will be removed in v0.13. "
-                "Please use `--mm-encoder-tp-mode data` instead.",
-                scope="local",
-            )
-
-            self.mm_encoder_tp_mode = "data"
-
         return ModelConfig(
             model=self.model,
             hf_config_path=self.hf_config_path,
             runner=self.runner,
             convert=self.convert,
-            task=self.task,
             tokenizer=self.tokenizer,
             tokenizer_mode=self.tokenizer_mode,
             trust_remote_code=self.trust_remote_code,
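The removed fallbacks name their replacements directly; a minimal sketch of the post-removal settings, assuming `EngineArgs` is constructed directly (the model name is a placeholder and the values are illustrative):

    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(
        model="Qwen/Qwen2-VL-2B-Instruct",  # placeholder multimodal model
        mm_processor_cache_gb=0,            # replaces --disable-mm-preprocessor-cache
        mm_encoder_tp_mode="data",          # replaces --enable-multimodal-encoder-data-parallel
    )
    # The VLLM_MM_INPUT_CACHE_GIB environment variable is likewise replaced by an
    # explicit --mm-processor-cache-gb / mm_processor_cache_gb setting.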
@@ -198,7 +198,7 @@ class LLM:
         quantization: QuantizationMethods | None = None,
         revision: str | None = None,
         tokenizer_revision: str | None = None,
-        seed: int | None = None,
+        seed: int = 0,
         gpu_memory_utilization: float = 0.9,
         swap_space: float = 4,
         cpu_offload_gb: float = 0,
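Since `seed` on `LLM` is now a plain integer defaulting to 0, callers that previously passed `seed=None` should drop the argument or pass an explicit integer; a minimal sketch (the model name is a placeholder reused from elsewhere in this diff):

    from vllm import LLM

    # seed=None is no longer accepted; omit seed for the default of 0,
    # or pass any integer for a different reproducible configuration.
    llm = LLM(model="facebook/opt-125m", seed=1234)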
@@ -72,7 +72,6 @@ if TYPE_CHECKING:
     VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
     VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
     VLLM_MEDIA_CONNECTOR: str = "http"
-    VLLM_MM_INPUT_CACHE_GIB: int = 4
     VLLM_TARGET_DEVICE: str = "cuda"
     VLLM_MAIN_CUDA_VERSION: str = "12.9"
     VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
@@ -786,9 +785,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # imported at runtime.
     # If a non-existing backend is used, an AssertionError will be thrown.
     "VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"),
-    # [DEPRECATED] Cache size (in GiB per process) for multimodal input cache
-    # Default is 4 GiB per API process + 4 GiB per engine core process
-    "VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),
     # Path to the XLA persistent cache directory.
     # Only used for XLA devices such as TPUs.
     "VLLM_XLA_CACHE_PATH": lambda: os.path.expanduser(
@@ -1681,7 +1677,6 @@ def compile_factors() -> dict[str, object]:
         "VLLM_MEDIA_CONNECTOR",
         "VLLM_ASSETS_CACHE",
         "VLLM_ASSETS_CACHE_MODEL_CLEAN",
-        "VLLM_MM_INPUT_CACHE_GIB",
         "VLLM_WORKER_MULTIPROC_METHOD",
         "VLLM_ENABLE_V1_MULTIPROCESSING",
         "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",