[1/N] pass the complete config from engine to executor (#9933)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao, 2024-11-01 13:51:57 -07:00 (committed by GitHub)
parent 598b6d7b07
commit 18bd7587b7
6 changed files with 64 additions and 136 deletions
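
In short, every call site that previously flattened the engine config into individual keyword arguments now passes the complete EngineConfig object, and the engine and executor constructors unpack it themselves. A minimal sketch of the new calling convention, mirroring the from_engine_args hunks below (the EngineArgs setup and model name are illustrative; any valid arguments would do):

    from vllm.engine.arg_utils import EngineArgs
    from vllm.engine.llm_engine import LLMEngine

    engine_args = EngineArgs(model="facebook/opt-125m")
    engine_config = engine_args.create_engine_config()
    executor_class = LLMEngine._get_executor_cls(engine_config)

    # Before this commit the config was flattened into keyword arguments:
    #     engine = LLMEngine(**engine_config.to_dict(), ...)
    # Now the complete config travels as a single object and is unpacked
    # inside LLMEngine.__init__ (and, further down, ExecutorBase.__init__).
    engine = LLMEngine(
        vllm_config=engine_config,
        executor_class=executor_class,
        log_stats=not engine_args.disable_log_stats,
    )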

vllm/engine/async_llm_engine.py

@@ -680,7 +680,7 @@ class AsyncLLMEngine(EngineClient):
         # Create the async LLM engine.
         engine = cls(
-            **engine_config.to_dict(),
+            vllm_config=engine_config,
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
             log_stats=not engine_args.disable_log_stats,

vllm/engine/llm_engine.py

@@ -13,11 +13,8 @@ import torch
 from typing_extensions import TypeIs, TypeVar
 import vllm.envs as envs
-from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
-                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
-                         ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
-                         SpeculativeConfig)
+from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
+                         ObservabilityConfig, ParallelConfig, SchedulerConfig)
 from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                  SchedulerOutputs)
 from vllm.engine.arg_utils import EngineArgs
@@ -222,17 +219,7 @@ class LLMEngine:
     def __init__(
         self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        load_config: LoadConfig,
-        lora_config: Optional[LoRAConfig],
-        speculative_config: Optional[SpeculativeConfig],
-        decoding_config: Optional[DecodingConfig],
-        observability_config: Optional[ObservabilityConfig],
-        prompt_adapter_config: Optional[PromptAdapterConfig],
+        vllm_config: EngineConfig,
         executor_class: Type[ExecutorBase],
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@@ -240,6 +227,22 @@ class LLMEngine:
         input_registry: InputRegistry = INPUT_REGISTRY,
         use_cached_outputs: bool = False,
     ) -> None:
+        # TODO: remove the local variables and use self.* throughout the class.
+        model_config = self.model_config = vllm_config.model_config
+        cache_config = self.cache_config = vllm_config.cache_config
+        lora_config = self.lora_config = vllm_config.lora_config
+        parallel_config = self.parallel_config = vllm_config.parallel_config
+        scheduler_config = self.scheduler_config = vllm_config.scheduler_config
+        device_config = self.device_config = vllm_config.device_config
+        speculative_config = self.speculative_config = vllm_config.speculative_config  # noqa
+        load_config = self.load_config = vllm_config.load_config
+        decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
+        )
+        prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
+        observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
+        )
         logger.info(
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
@@ -340,18 +343,7 @@
         self.input_processor = input_registry.create_input_processor(
             model_config)
-        self.model_executor = executor_class(
-            model_config=model_config,
-            cache_config=cache_config,
-            parallel_config=parallel_config,
-            scheduler_config=scheduler_config,
-            device_config=device_config,
-            lora_config=lora_config,
-            speculative_config=speculative_config,
-            load_config=load_config,
-            prompt_adapter_config=prompt_adapter_config,
-            observability_config=self.observability_config,
-        )
+        self.model_executor = executor_class(vllm_config=vllm_config, )
         if self.model_config.task != "embedding":
             self._initialize_kv_caches()
@@ -582,7 +574,7 @@ class LLMEngine:
         executor_class = cls._get_executor_cls(engine_config)
         # Create the LLM engine.
         engine = cls(
-            **engine_config.to_dict(),
+            vllm_config=engine_config,
             executor_class=executor_class,
             log_stats=not engine_args.disable_log_stats,
             usage_context=usage_context,

vllm/engine/multiprocessing/engine.py

@@ -7,8 +7,6 @@ import cloudpickle
 import zmq
 from vllm import AsyncEngineArgs, SamplingParams
-from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
@@ -30,9 +28,6 @@ if VLLM_USE_V1:
 else:
     from vllm.engine.llm_engine import LLMEngine
-CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
-                    SchedulerConfig, LoRAConfig]
 logger = init_logger(__name__)
 POLLING_TIMEOUT_MS = 10000
@@ -130,7 +125,7 @@ class MQLLMEngine:
         return cls(ipc_path=ipc_path,
                    use_async_sockets=use_async_sockets,
-                   **engine_config.to_dict(),
+                   vllm_config=engine_config,
                    executor_class=executor_class,
                    log_requests=not engine_args.disable_log_requests,
                    log_stats=not engine_args.disable_log_stats,

vllm/executor/executor_base.py

@@ -1,10 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import List, Optional, Set, Tuple
-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
-                         SpeculativeConfig)
+from vllm.config import EngineConfig
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -23,27 +20,19 @@ class ExecutorBase(ABC):
     def __init__(
         self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        load_config: LoadConfig,
-        lora_config: Optional[LoRAConfig],
-        speculative_config: Optional[SpeculativeConfig],
-        prompt_adapter_config: Optional[PromptAdapterConfig],
-        observability_config: Optional[ObservabilityConfig],
+        vllm_config: EngineConfig,
     ) -> None:
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.load_config = load_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.speculative_config = speculative_config
-        self.prompt_adapter_config = prompt_adapter_config
-        self.observability_config = observability_config
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.cache_config = vllm_config.cache_config
+        self.lora_config = vllm_config.lora_config
+        self.load_config = vllm_config.load_config
+        self.parallel_config = vllm_config.parallel_config
+        self.scheduler_config = vllm_config.scheduler_config
+        self.device_config = vllm_config.device_config
+        self.speculative_config = vllm_config.speculative_config
+        self.prompt_adapter_config = vllm_config.prompt_adapter_config
+        self.observability_config = vllm_config.observability_config
         self._init_executor()
     @abstractmethod
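
With ExecutorBase unpacking vllm_config in one place, a backend-specific executor no longer declares any config parameters of its own: it receives the single vllm_config argument and reads the sub-configs from self inside _init_executor. A rough sketch of that shape with a hypothetical subclass (the other abstract methods of ExecutorBase are omitted, so this is illustrative rather than instantiable):

    from vllm.executor.executor_base import ExecutorBase

    class LoggingExecutorSketch(ExecutorBase):
        """Hypothetical executor, shown only to illustrate the new wiring."""

        def _init_executor(self) -> None:
            # ExecutorBase.__init__ has already copied the sub-configs from
            # self.vllm_config onto self, so no extra plumbing is needed here.
            assert self.model_config is self.vllm_config.model_config
            print("would start workers for model:", self.model_config.model)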

vllm/executor/xpu_executor.py

@@ -2,10 +2,7 @@ from typing import Callable, List, Optional, Tuple, Type, Union
 import torch
-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
-                         SpeculativeConfig)
+from vllm.config import ModelConfig, ParallelConfig
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
@@ -21,38 +18,13 @@ class XPUExecutor(GPUExecutor):
     uses_ray: bool = False
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        load_config: LoadConfig,
-        lora_config: Optional[LoRAConfig],
-        prompt_adapter_config: Optional[PromptAdapterConfig],
-        speculative_config: Optional[SpeculativeConfig],
-        observability_config: Optional[ObservabilityConfig],
-    ) -> None:
-        assert device_config.device_type == "xpu"
-        assert (not speculative_config
-                ), "Speculative decoding not yet supported for XPU backend"
+    def _init_executor(self) -> None:
+        assert self.device_config.device_type == "xpu"
+        assert self.speculative_config is None, (
+            "Speculative decoding not yet supported for XPU backend")
-        model_config = _verify_and_get_model_config(model_config)
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.load_config = load_config
-        self.lora_config = lora_config
-        self.parallel_config = _verify_and_get_parallel_config(parallel_config)
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.prompt_adapter_config = prompt_adapter_config
-        self.speculative_config = None
-        self.observability_config = observability_config
-        # Instantiate the worker and load the model to GPU.
-        self._init_executor()
+        self.model_config = _verify_and_get_model_config(self.model_config)
+        GPUExecutor._init_executor(self)
     def _get_worker_module_and_class(
             self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
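
For executors that derive from GPUExecutor, the XPUExecutor hunk above shows the resulting pattern: drop the long __init__ entirely, do the backend-specific validation in _init_executor against the attributes that ExecutorBase already populated, then reuse the parent's worker setup. A hedged sketch of another backend following the same shape (the class name and the extra assertion are hypothetical):

    from vllm.executor.gpu_executor import GPUExecutor

    class MyGpuLikeExecutorSketch(GPUExecutor):
        """Hypothetical backend mirroring the XPUExecutor pattern above."""

        def _init_executor(self) -> None:
            # Validate the already-unpacked configs, then fall back to the
            # stock GPU worker initialization.
            assert self.speculative_config is None, (
                "speculative decoding is not handled in this sketch")
            super()._init_executor()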

vllm/v1/engine/llm_engine.py

@@ -2,11 +2,8 @@ import time
 from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type,
                     Union)
-from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
-                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
-                         ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
-                         SpeculativeConfig)
+from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
+                         ObservabilityConfig, ParallelConfig, SchedulerConfig)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics_types import StatLoggerBase
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
@@ -35,17 +32,7 @@ class LLMEngine:
     def __init__(
         self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        load_config: LoadConfig,
-        lora_config: Optional[LoRAConfig],
-        speculative_config: Optional[SpeculativeConfig],
-        decoding_config: Optional[DecodingConfig],
-        observability_config: Optional[ObservabilityConfig],
-        prompt_adapter_config: Optional[PromptAdapterConfig],
+        vllm_config: EngineConfig,
         executor_class: Type[GPUExecutor],
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@@ -53,6 +40,22 @@ class LLMEngine:
         input_registry: InputRegistry = INPUT_REGISTRY,
         use_cached_outputs: bool = False,
     ) -> None:
+        # TODO: remove the local variables and use self.* throughout the class.
+        model_config = self.model_config = vllm_config.model_config
+        cache_config = self.cache_config = vllm_config.cache_config
+        lora_config = self.lora_config = vllm_config.lora_config
+        parallel_config = self.parallel_config = vllm_config.parallel_config
+        scheduler_config = self.scheduler_config = vllm_config.scheduler_config
+        device_config = self.device_config = vllm_config.device_config
+        speculative_config = self.speculative_config = vllm_config.speculative_config  # noqa
+        load_config = self.load_config = vllm_config.load_config
+        decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
+        )
+        prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
+        observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
+        )
         # Override the configs for V1.
         # FIXME
         if usage_context == UsageContext.LLM_CLASS:
@@ -112,18 +115,6 @@
             model_config.mm_processor_kwargs,
         )
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.speculative_config = speculative_config
-        self.load_config = load_config
-        self.decoding_config = decoding_config or DecodingConfig()
-        self.prompt_adapter_config = prompt_adapter_config
-        self.observability_config = observability_config or ObservabilityConfig(
-        )
         self.log_stats = log_stats
         assert not self.model_config.skip_tokenizer_init
@@ -154,18 +145,7 @@
         # Request id -> RequestOutput
         self.request_outputs: Dict[str, RequestOutput] = {}
-        self.model_executor = executor_class(
-            model_config=model_config,
-            cache_config=cache_config,
-            parallel_config=parallel_config,
-            scheduler_config=scheduler_config,
-            device_config=device_config,
-            lora_config=lora_config,
-            speculative_config=speculative_config,
-            load_config=load_config,
-            prompt_adapter_config=prompt_adapter_config,
-            observability_config=self.observability_config,
-        )
+        self.model_executor = executor_class(vllm_config=vllm_config)
         assert self.model_config.task != "embedding"
         self._initialize_kv_caches()
@@ -203,7 +183,7 @@
         executor_class = cls._get_executor_cls(engine_config)
         # Create the LLM engine.
         engine = cls(
-            **engine_config.to_dict(),
+            vllm_config=engine_config,
             executor_class=executor_class,
             log_stats=not engine_args.disable_log_stats,
             usage_context=usage_context,