Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-08 13:41:54 +08:00)
[2/N] executor pass the complete config to worker/modelrunner (#9938)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent 1d4cfe2be1
commit e893795443
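The diff below threads a single VllmConfig (the renamed EngineConfig) from the executors down to every worker and model runner, instead of unpacking it into model_config, parallel_config, scheduler_config, and so on at each call site. A minimal sketch of the resulting construction pattern, based on the test changes in this diff (the model name and the distributed_init_method value are placeholders, not part of the change):

    from vllm.engine.arg_utils import EngineArgs
    from vllm.worker.worker import Worker

    engine_args = EngineArgs(model="facebook/opt-125m")  # example model only
    vllm_config = engine_args.create_engine_config()     # now returns a VllmConfig

    # Old style: Worker(model_config=..., parallel_config=..., scheduler_config=..., ...)
    # New style: pass the complete config object through unchanged.
    worker = Worker(
        vllm_config=vllm_config,
        local_rank=0,
        rank=0,
        distributed_init_method="file:///tmp/vllm_init_method",  # placeholder
    )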
@@ -138,13 +138,7 @@ def test_rotary_emb_replaced(dist_init):
enable_lora=True)
engine_config = engine_args.create_engine_config()
model_runner = ModelRunner(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
vllm_config=engine_config,
is_driver_worker=True,
)
model_runner.load_model()
@@ -4,7 +4,8 @@ import tempfile
from unittest.mock import patch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig)
ModelConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.worker.worker import Worker

@@ -12,7 +13,7 @@ from vllm.worker.worker import Worker

@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
worker = Worker(
vllm_config = VllmConfig(
model_config=ModelConfig(
"meta-llama/Llama-2-7b-hf",
task="auto",

@@ -34,10 +35,13 @@ def test_worker_apply_lora(sql_lora_files):
gpu_memory_utilization=1.,
swap_space=0,
cache_dtype="auto"),
local_rank=0,
rank=0,
lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
max_loras=32),
)
worker = Worker(
vllm_config=vllm_config,
local_rank=0,
rank=0,
distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
)
worker.init_device()
@@ -81,12 +81,7 @@ def create_worker(cls: Callable[..., T],
get_ip(), get_open_port())

worker = cls(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
@@ -19,14 +19,7 @@ def _create_model_runner(model: str, *args,
engine_args = EngineArgs(model, *args, **kwargs)
engine_config = engine_args.create_engine_config()
model_runner = EncoderDecoderModelRunner(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
vllm_config=engine_config,
is_driver_worker=True,
)
return model_runner
@@ -16,15 +16,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
engine_args = EngineArgs(model, *args, **kwargs)
engine_config = engine_args.create_engine_config()
model_runner = ModelRunner(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
observability_config=engine_config.observability_config,
vllm_config=engine_config,
is_driver_worker=True,
)
return model_runner
@@ -24,12 +24,7 @@ def test_gpu_memory_profiling():
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
worker = Worker(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
@@ -19,12 +19,7 @@ def test_swap() -> None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
worker = Worker(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
@@ -1,6 +1,6 @@
import enum
import json
from dataclasses import dataclass, field, fields
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
Mapping, Optional, Set, Tuple, Type, Union)

@@ -1941,9 +1941,9 @@ class ObservabilityConfig:
f"installed. Original error:\n{otel_import_error_traceback}")


@dataclass(frozen=True)
class EngineConfig:
"""Dataclass which contains all engine-related configuration. This
@dataclass
class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase.
"""

@@ -1953,11 +1953,11 @@ class EngineConfig:
scheduler_config: SchedulerConfig
device_config: DeviceConfig
load_config: LoadConfig
lora_config: Optional[LoRAConfig]
speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig]
observability_config: Optional[ObservabilityConfig]
prompt_adapter_config: Optional[PromptAdapterConfig]
lora_config: Optional[LoRAConfig] = None
speculative_config: Optional[SpeculativeConfig] = None
decoding_config: Optional[DecodingConfig] = None
observability_config: Optional[ObservabilityConfig] = None
prompt_adapter_config: Optional[PromptAdapterConfig] = None

def __post_init__(self):
"""Verify configs are valid & consistent with each other.

@@ -1975,9 +1975,3 @@ class EngineConfig:
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)

def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs.
"""
return dict(
(field.name, getattr(self, field.name)) for field in fields(self))
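Note: with the field defaults added in the hunk above (lora_config, speculative_config, decoding_config, observability_config and prompt_adapter_config now default to None), a VllmConfig can be built from just the required sub-configs. A hedged sketch, assuming the individual config objects have already been constructed elsewhere (for example by EngineArgs):

    from vllm.config import VllmConfig

    vllm_config = VllmConfig(
        model_config=model_config,          # assumed to exist already
        cache_config=cache_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        load_config=load_config,
        # optional fields (lora_config, speculative_config, ...) default to None
    )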
@@ -9,10 +9,11 @@ import torch

import vllm.envs as envs
from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TaskOption, TokenizerPoolConfig)
DeviceConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TaskOption, TokenizerPoolConfig,
VllmConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

@@ -955,7 +956,7 @@ class EngineArgs:
ignore_patterns=self.ignore_patterns,
)

def create_engine_config(self) -> EngineConfig:
def create_engine_config(self) -> VllmConfig:
# gguf file needs a specific model loader and doesn't use hf_repo
if check_gguf_file(self.model):
self.quantization = self.load_format = "gguf"

@@ -1167,7 +1168,7 @@ class EngineArgs:
or "all" in detailed_trace_modules,
)

return EngineConfig(
return VllmConfig(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
@@ -7,8 +7,8 @@ from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
from weakref import ReferenceType

import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VllmConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout

@@ -604,7 +604,7 @@ class AsyncLLMEngine(EngineClient):

@classmethod
def _get_executor_cls(
cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]:
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
if isinstance(distributed_executor_backend, type):

@@ -663,7 +663,7 @@ class AsyncLLMEngine(EngineClient):
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[EngineConfig] = None,
engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
@@ -13,8 +13,9 @@ import torch
from typing_extensions import TypeIs, TypeVar

import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig)
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs

@@ -219,7 +220,7 @@ class LLMEngine:

def __init__(
self,
vllm_config: EngineConfig,
vllm_config: VllmConfig,
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,

@@ -500,7 +501,7 @@ class LLMEngine:

@classmethod
def _get_executor_cls(cls,
engine_config: EngineConfig) -> Type[ExecutorBase]:
engine_config: VllmConfig) -> Type[ExecutorBase]:
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class.
@@ -13,7 +13,7 @@ from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket

from vllm import PoolingParams
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block

@@ -78,7 +78,7 @@ class MQLLMEngineClient(EngineClient):
every N seconds, confirming the engine is healthy
"""

def __init__(self, ipc_path: str, engine_config: EngineConfig,
def __init__(self, ipc_path: str, engine_config: VllmConfig,
engine_pid: int):
self.context = zmq.asyncio.Context()
self._errored_with: Optional[BaseException] = None
@@ -138,18 +138,11 @@ class CPUExecutor(ExecutorBase):
assert self.distributed_init_method is not None

kwargs = dict(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=self.distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=rank == 0,
)
wrapper.init_worker(**kwargs)
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple

from vllm.config import EngineConfig
from vllm.config import VllmConfig
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest

@@ -20,7 +20,7 @@ class ExecutorBase(ABC):

def __init__(
self,
vllm_config: EngineConfig,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
@@ -49,21 +49,12 @@ class GPUExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
speculative_config=self.speculative_config,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
observability_config=self.observability_config,
)

def _get_worker_module_and_class(
@@ -29,11 +29,7 @@ class NeuronExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = NeuronWorker(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
vllm_config=self.vllm_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method)
@@ -48,16 +48,10 @@ class OpenVINOExecutor(ExecutorBase):
get_ip(), get_open_port())
self.driver_worker = OpenVINOWorker(
ov_core=self.ov_core,
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
vllm_config=self.vllm_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
@@ -44,12 +44,7 @@ class TPUExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
@@ -17,9 +17,6 @@ except (ModuleNotFoundError, ImportError) as err:
"Draft model speculative decoding currently only supports"
"CUDA and ROCm flash attention backend.") from err

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalInputs
from vllm.sequence import ExecuteModelRequest, IntermediateTensors

@@ -49,40 +46,13 @@ class TP1DraftModelRunner(ModelRunner):
any broadcasting inside execute_model).
"""

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
):
if return_hidden_states:
def __init__(self, *args, **kwargs):
if kwargs.get("return_hidden_states"):
raise ValueError(
"return_hidden_states is not supported for TP1DraftModelRunner."
)

super().__init__(
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
load_config=load_config,
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states,
observability_config=observability_config,
)
super().__init__(*args, **kwargs)

def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Get local_rank/vocab_size from kwargs attribute
|
||||
self.local_rank = kwargs["local_rank"]
|
||||
self.vocab_size = kwargs["model_config"].get_vocab_size()
|
||||
self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size()
|
||||
|
||||
# Lazy initialization list.
|
||||
self._proposer: Top1Proposer
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import copy
|
||||
from collections import defaultdict
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import ParallelConfig, SpeculativeConfig
|
||||
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
|
||||
from vllm.distributed.communication_op import broadcast_tensor_dict
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
@ -45,8 +46,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
|
||||
"""Helper method that is the entrypoint for Executors which use
|
||||
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
|
||||
"""
|
||||
assert "speculative_config" in kwargs
|
||||
speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
|
||||
vllm_config: VllmConfig = kwargs.get("vllm_config")
|
||||
speculative_config: SpeculativeConfig = vllm_config.speculative_config
|
||||
assert speculative_config is not None
|
||||
|
||||
draft_worker_kwargs = kwargs.copy()
|
||||
@ -58,14 +59,16 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
|
||||
target_worker.model_runner.disable_logprobs =\
|
||||
speculative_config.disable_logprobs
|
||||
|
||||
draft_worker_config = copy.deepcopy(vllm_config)
|
||||
draft_worker_config.model_config = speculative_config.draft_model_config
|
||||
draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa
|
||||
# TODO allow draft-model specific load config.
|
||||
|
||||
# Override draft-model specific worker args.
|
||||
draft_worker_kwargs.update(
|
||||
model_config=speculative_config.draft_model_config,
|
||||
parallel_config=speculative_config.draft_parallel_config,
|
||||
vllm_config=draft_worker_config,
|
||||
ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
|
||||
ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
|
||||
# TODO allow draft-model specific load config.
|
||||
#load_config=load_config,
|
||||
)
|
||||
|
||||
spec_decode_worker = SpecDecodeWorker.create_worker(
|
||||
@ -134,29 +137,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
|
||||
ngram_prompt_lookup_min = (
|
||||
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
|
||||
draft_model_config = draft_worker_kwargs["vllm_config"].model_config
|
||||
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
|
||||
'vllm_config'].parallel_config
|
||||
if ngram_prompt_lookup_max > 0:
|
||||
proposer_worker = NGramWorker(**draft_worker_kwargs)
|
||||
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
|
||||
ngram_prompt_lookup_max)
|
||||
else:
|
||||
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
|
||||
'parallel_config']
|
||||
draft_tp = draft_parallel_config.tensor_parallel_size
|
||||
target_tp = scorer_worker.parallel_config.tensor_parallel_size
|
||||
|
||||
if draft_worker_kwargs[
|
||||
"model_config"].hf_config.model_type == "mlp_speculator":
|
||||
if draft_model_config.hf_config.model_type == "mlp_speculator":
|
||||
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
|
||||
elif draft_worker_kwargs[
|
||||
"model_config"].hf_config.model_type == "medusa":
|
||||
elif draft_model_config.hf_config.model_type == "medusa":
|
||||
proposer_worker = MedusaWorker(**draft_worker_kwargs)
|
||||
else:
|
||||
if draft_tp == 1:
|
||||
draft_worker_kwargs[
|
||||
"model_runner_cls"] = TP1DraftModelRunner
|
||||
else:
|
||||
if draft_worker_kwargs[
|
||||
"model_config"].hf_config.model_type == "eagle":
|
||||
if draft_model_config.hf_config.model_type == "eagle":
|
||||
raise NotImplementedError(
|
||||
"EAGLE does not support TP > 1 yet")
|
||||
|
||||
@ -190,8 +191,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
"[Speculative Decoding] Disabling MQA scorer as the "
|
||||
"MQA is only available with flash attn backend.")
|
||||
|
||||
if "model_config" in draft_worker_kwargs and \
|
||||
draft_worker_kwargs["model_config"].max_model_len < \
|
||||
if draft_model_config and \
|
||||
draft_model_config.max_model_len < \
|
||||
scorer_worker.model_config.max_model_len:
|
||||
disable_mqa_scorer = True
|
||||
logger.info(
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ObservabilityConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
|
||||
ModelRunner)
|
||||
@ -20,35 +18,21 @@ class TargetModelRunner(ModelRunner):
|
||||
requested or not.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
cache_config: CacheConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
return_hidden_states: bool = False,
|
||||
observability_config: Optional[ObservabilityConfig] = None):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
return_hidden_states: bool = False,
|
||||
):
|
||||
# An internal boolean member variable to indicate if token log
|
||||
# probabilities are needed or not.
|
||||
self.disable_logprobs = True
|
||||
super().__init__(
|
||||
model_config=model_config,
|
||||
parallel_config=parallel_config,
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=device_config,
|
||||
cache_config=cache_config,
|
||||
load_config=load_config,
|
||||
lora_config=lora_config,
|
||||
vllm_config=vllm_config,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
is_driver_worker=is_driver_worker,
|
||||
prompt_adapter_config=prompt_adapter_config,
|
||||
return_hidden_states=return_hidden_states,
|
||||
observability_config=observability_config,
|
||||
)
|
||||
|
||||
def prepare_model_input(
|
||||
|
||||
@ -2,8 +2,9 @@ import time
|
||||
from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type,
|
||||
Union)
|
||||
|
||||
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
|
||||
ObservabilityConfig, ParallelConfig, SchedulerConfig)
|
||||
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
||||
ObservabilityConfig, ParallelConfig, SchedulerConfig,
|
||||
VllmConfig)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.metrics_types import StatLoggerBase
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
|
||||
@ -32,7 +33,7 @@ class LLMEngine:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: EngineConfig,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[GPUExecutor],
|
||||
log_stats: bool,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
@ -477,7 +478,7 @@ class LLMEngine:
|
||||
return self.lora_config
|
||||
|
||||
@classmethod
|
||||
def _get_executor_cls(cls, engine_config: EngineConfig):
|
||||
def _get_executor_cls(cls, engine_config: VllmConfig):
|
||||
return GPUExecutor
|
||||
|
||||
def is_tracing_enabled(self) -> bool:
|
||||
|
||||
@ -56,19 +56,10 @@ class GPUExecutor:
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
return Worker(
|
||||
model_config=self.model_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
device_config=self.device_config,
|
||||
cache_config=self.cache_config,
|
||||
load_config=self.load_config,
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
speculative_config=self.speculative_config,
|
||||
prompt_adapter_config=self.prompt_adapter_config,
|
||||
observability_config=self.observability_config,
|
||||
)
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
|
||||
@ -7,9 +7,7 @@ import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ObservabilityConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
@ -33,26 +31,25 @@ class GPUModelRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
cache_config: CacheConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
observability_config: Optional[ObservabilityConfig] = None,
|
||||
vllm_config: VllmConfig,
|
||||
):
|
||||
self.model_config = model_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.load_config = load_config
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.observability_config = observability_config
|
||||
# TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config)
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
model_config = self.model_config
|
||||
cache_config = self.cache_config
|
||||
scheduler_config = self.scheduler_config
|
||||
parallel_config = self.parallel_config
|
||||
self.device = self.device_config.device
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
|
||||
@@ -6,10 +6,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)

@@ -30,48 +27,35 @@ class Worker:

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
speculative_config: Optional[SpeculativeConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config

# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config

self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config

if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()

self.model_runner = GPUModelRunner(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config=lora_config,
)
self.model_runner = GPUModelRunner(vllm_config)

def initialize(self):
if self.device_config.device.type == "cuda":
@@ -8,9 +8,7 @@ import torch
from torch import nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, PromptAdapterConfig,
SchedulerConfig)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding

@@ -412,29 +410,18 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
*args,
**kwargs,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
ModelRunnerBase.__init__(self, vllm_config)
# Currently, CPU worker doesn't support chunked prefill.
assert self.scheduler_config.chunked_prefill_enabled is False
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.load_config = load_config
model_config = self.model_config
cache_config = self.cache_config

self.is_driver_worker = is_driver_worker

self.device = self.device_config.device
@@ -6,9 +6,8 @@ import torch.distributed

import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, PromptAdapterConfig,
SchedulerConfig)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger

@@ -18,7 +17,8 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput)

logger = init_logger(__name__)

@@ -121,31 +121,19 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
WorkerBase.__init__(self, vllm_config=vllm_config)

self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config

self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."

@@ -166,15 +154,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
if self._is_encoder_decoder_model():
ModelRunnerClass = CPUEncoderDecoderModelRunner
self.model_runner: CPUModelRunner = ModelRunnerClass(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
@@ -3,9 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger

@@ -36,29 +34,13 @@ class EmbeddingModelRunner(

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
):
super().__init__(model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config=lora_config,
super().__init__(vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
observability_config=observability_config)
is_driver_worker=is_driver_worker)

@torch.inference_mode()
def execute_model(
@@ -11,9 +11,7 @@ from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend,
global_force_attn_backend)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.config import ModelConfig, VllmConfig
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger

@@ -85,17 +83,9 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):

@@ -107,15 +97,10 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
models) but these arguments are present here for compatibility with
the base-class constructor.
'''
self._maybe_force_supported_attention_backend(model_config)
self._maybe_force_supported_attention_backend(vllm_config.model_config)

super().__init__(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config=None,
vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
)
@@ -20,9 +20,7 @@ from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.levels import CompilationLevel
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture

@@ -955,32 +953,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config

ModelRunnerBase.__init__(self, vllm_config)
model_config = self.model_config
cache_config = self.cache_config

self.is_driver_worker = is_driver_worker
self.prompt_adapter_config = prompt_adapter_config
self.return_hidden_states = return_hidden_states
self.observability_config = observability_config

self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
@@ -9,6 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
import torch
from torch import is_tensor

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform

@@ -220,6 +221,22 @@ class ModelRunnerBase(ABC, Generic[T]):
ModelRunnerInputBase subclass.
"""

def __init__(
self,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config

# Map of request_id -> generator used for seeded random sampling
generators: Dict[str, torch.Generator] = {}
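The new ModelRunnerBase.__init__ above is what lets the backend runners in this diff drop their long parameter lists: a subclass passes the single VllmConfig to the base class and then reads the unpacked attributes it sets. A small hypothetical subclass, purely to illustrate the pattern (MyModelRunner and its extra fields are not part of vLLM):

    from typing import Optional
    from vllm.config import VllmConfig
    from vllm.worker.model_runner_base import ModelRunnerBase

    class MyModelRunner(ModelRunnerBase):
        def __init__(self, vllm_config: VllmConfig,
                     kv_cache_dtype: Optional[str] = "auto",
                     is_driver_worker: bool = False):
            ModelRunnerBase.__init__(self, vllm_config)
            # The base class has already set self.model_config, self.cache_config,
            # self.scheduler_config, ... from vllm_config.
            self.kv_cache_dtype = kv_cache_dtype
            self.is_driver_worker = is_driver_worker
            self.block_size = self.cache_config.block_size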
@@ -304,6 +304,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# mypy: enable-error-code=type-var

def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):

super().__init__(*args, **kwargs)

# Check attention backend support.
@@ -27,17 +27,9 @@ class MultiStepWorker(Worker):
# for multi-step model, wrap the model runner with MultiStepModelRunner
self.model_runner = MultiStepModelRunner(
base_model_runner,
base_model_runner.model_config,
base_model_runner.parallel_config,
base_model_runner.scheduler_config,
base_model_runner.device_config,
base_model_runner.cache_config,
load_config=base_model_runner.load_config,
lora_config=self.lora_config,
vllm_config=base_model_runner.vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=base_model_runner.is_driver_worker,
prompt_adapter_config=base_model_runner.prompt_adapter_config,
observability_config=base_model_runner.observability_config,
)

pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
@@ -7,8 +7,7 @@ import torch
from torch import nn
from transformers_neuronx.config import GenerationConfig

from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput

@@ -57,20 +56,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
vllm_config: VllmConfig,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config

ModelRunnerBase.__init__(self, vllm_config)
model_config = self.model_config
if model_config is not None and model_config.get_sliding_window():
logger.warning("Sliding window is not supported on Neuron. "
"The model will run without sliding window.")
self.device_config = (device_config
if device_config is not None else DeviceConfig())
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
@@ -4,15 +4,15 @@ from typing import List, Optional, Tuple
import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput)


class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):

@@ -21,20 +21,12 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
WorkerBase.__init__(self, vllm_config=vllm_config)
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method

@@ -44,7 +36,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
init_cached_hf_modules()

self.model_runner: NeuronModelRunner = NeuronModelRunner(
model_config, parallel_config, scheduler_config, device_config)
vllm_config=vllm_config)
self.is_driver_worker = True

def init_device(self) -> None:
@@ -7,9 +7,7 @@ from torch import nn

from vllm.attention import get_attn_backend
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput

@@ -17,6 +15,7 @@ from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalPlaceholderMap)
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerBase

logger = init_logger(__name__)

@@ -39,33 +38,21 @@ class ModelInput(NamedTuple):
multi_modal_kwargs={})


class OpenVINOModelRunner:
class OpenVINOModelRunner(ModelRunnerBase):

def __init__(
self,
ov_core: ov.Core,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
*args,
**kwargs,
):
self.ov_core = ov_core
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.load_config = load_config
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
cache_config = self.cache_config
model_config = self.model_config
self.is_driver_worker = is_driver_worker

self.device = self.device_config.device

@@ -369,3 +356,9 @@ class OpenVINOModelRunner:
sampling_metadata=sampling_metadata,
)
return output

def prepare_model_input(self, *args, **kwargs):
raise NotImplementedError

def make_model_input_from_broadcasted_tensor_dict(self, *args, **kwargs):
raise NotImplementedError
@@ -7,9 +7,8 @@ import torch.distributed

import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)

@@ -22,7 +21,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.worker.openvino_model_runner import OpenVINOModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

logger = init_logger(__name__)

@@ -212,33 +211,19 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
def __init__(
self,
ov_core: ov.Core,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
is_driver_worker: bool = False,
) -> None:
self.ov_core = ov_core
self.model_config = model_config
self.parallel_config = parallel_config
WorkerBase.__init__(self, vllm_config)
self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."

@@ -250,14 +235,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
init_cached_hf_modules()
self.model_runner = OpenVINOModelRunner(
self.ov_core,
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
vllm_config=self.vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
)
@ -12,8 +12,7 @@ import torch_xla.runtime as xr

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
@ -90,20 +89,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
is_driver_worker: bool = False,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
self.is_driver_worker = is_driver_worker

self.block_size = self.cache_config.block_size

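The deleted per-field assignments are not replaced one-for-one here; the runner now relies on ModelRunnerBase.__init__ to unpack the bundled config, which is why later accesses such as self.cache_config.block_size keep working. The base-class body is not part of this hunk, so the following is only a hedged sketch of what it is assumed to do, mirroring the WorkerBase.__init__ shown further down.

# Assumed behaviour of ModelRunnerBase.__init__ (sketch, not the vLLM source):
from vllm.config import VllmConfig


class _ModelRunnerBaseSketch:  # hypothetical stand-in for ModelRunnerBase
    def __init__(self, vllm_config: VllmConfig) -> None:
        # Keep the bundled config and unpack the sub-configs the runner uses.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
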
@ -6,8 +6,7 @@ import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
@ -16,7 +15,8 @@ from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.worker.tpu_model_runner import TPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput)

logger = init_logger(__name__)

@ -25,24 +25,14 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
WorkerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
@ -56,13 +46,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.cache_config.cache_dtype]

self.model_runner: TPUModelRunner = TPUModelRunner(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
is_driver_worker=is_driver_worker)
vllm_config=vllm_config, is_driver_worker=is_driver_worker)

def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "TPU"

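A short call-site sketch for the TPU worker under the new signature; note that is_driver_worker has no default here, unlike the GPU Worker below. The module path vllm.worker.tpu_worker, the placeholder init method, and the helper name are illustrative assumptions.

# Illustrative sketch only: the TPU worker takes the bundled config directly.
from vllm.config import VllmConfig
from vllm.worker.tpu_worker import TPUWorker


def build_tpu_worker(engine_config: VllmConfig) -> TPUWorker:
    return TPUWorker(
        vllm_config=engine_config,
        local_rank=0,
        rank=0,
        distributed_init_method="tcp://127.0.0.1:29500",  # placeholder
        is_driver_worker=True,  # required: no default in this signature
    )
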
@ -7,10 +7,7 @@ import torch
import torch.distributed

import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
@ -27,7 +24,8 @@ from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)

logger = init_logger(__name__)

@ -42,46 +40,31 @@ class Worker(LocalOrDistributedWorkerBase):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
observability_config: Optional[ObservabilityConfig] = None,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
WorkerBase.__init__(self, vllm_config)
self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.load_config = load_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker
if parallel_config and is_driver_worker:
assert rank % parallel_config.tensor_parallel_size == 0, \
if is_driver_worker:
assert rank % self.parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.observability_config = observability_config

# Return hidden states from target model if the draft model is an
# mlp_speculator
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
@ -97,17 +80,9 @@ class Worker(LocalOrDistributedWorkerBase):
elif self._is_encoder_decoder_model():
ModelRunnerClass = EncoderDecoderModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=load_config,
lora_config=self.lora_config,
vllm_config=self.vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
observability_config=observability_config,
**speculative_args,
)
# Uninitialized cache engine. Will be initialized by

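An end-to-end sketch of the new plumbing for the GPU Worker: a single bundled config is created once and handed to the worker, which stores it via WorkerBase.__init__ and forwards the same object to the model runner it builds. It assumes EngineArgs.create_engine_config() returns that bundled config; the model name, init method, and identity checks are illustrative.

# Illustrative sketch (assumes a machine with vLLM's GPU dependencies and
# network access to fetch the model's Hugging Face config).
from vllm.engine.arg_utils import EngineArgs
from vllm.worker.worker import Worker

engine_config = EngineArgs(model="facebook/opt-125m").create_engine_config()
worker = Worker(
    vllm_config=engine_config,
    local_rank=0,
    rank=0,
    distributed_init_method="file:///tmp/vllm_dist_init",  # placeholder
    is_driver_worker=True,
)
# WorkerBase.__init__ stores the bundle and unpacks it by reference, so the
# worker-level views are the same objects as the fields on the bundle.
assert worker.vllm_config is engine_config
assert worker.cache_config is engine_config.cache_config
assert worker.parallel_config is engine_config.parallel_config
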
@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

import torch

from vllm.config import ObservabilityConfig
from vllm.config import ObservabilityConfig, VllmConfig
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -29,6 +29,22 @@ class WorkerBase(ABC):
communicate request metadata to other workers.
"""

def __init__(
self,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config

@abstractmethod
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device

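The new WorkerBase.__init__ is the piece that lets every worker in this change drop its hand-written config bookkeeping. A minimal sketch of the resulting subclass pattern follows; the class name MyAcceleratorWorker and the argument list are hypothetical, chosen only to mirror the hunks above.

# Sketch of the pattern the GPU/TPU/XPU/OpenVINO workers now share.
from vllm.config import VllmConfig
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase


class MyAcceleratorWorker(LoraNotSupportedWorkerBase):  # hypothetical example

    def __init__(self, vllm_config: VllmConfig, local_rank: int, rank: int,
                 distributed_init_method: str,
                 is_driver_worker: bool = False) -> None:
        # The base class keeps vllm_config and unpacks all sub-configs.
        WorkerBase.__init__(self, vllm_config=vllm_config)
        self.parallel_config.rank = rank  # populated by WorkerBase above
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
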
@ -10,9 +10,7 @@ import torch
import torch.nn as nn

from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
@ -363,33 +361,18 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config

ModelRunnerBase.__init__(self, vllm_config=vllm_config)
model_config = self.model_config
cache_config = self.cache_config
self.is_driver_worker = is_driver_worker
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
if self.observability_config is not None:
print(f"observability_config is {self.observability_config}")
self.return_hidden_states = return_hidden_states

self.device = self.device_config.device

@ -8,10 +8,7 @@ import oneccl_bindings_for_pytorch # noqa: F401
import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
@ -19,7 +16,7 @@ from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
from vllm.worker.xpu_model_runner import XPUModelRunner

logger = init_logger(__name__)
@ -36,53 +33,32 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
device_config = self.device_config
parallel_config = self.parallel_config
assert device_config.device_type == "xpu"
assert current_platform.is_xpu()

self.model_config = model_config
self.parallel_config = parallel_config
self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config

self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker
self.observability_config = observability_config
if parallel_config and is_driver_worker:
assert rank % parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group."

self.model_runner = XPUModelRunner( # type: ignore
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
vllm_config=vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
observability_config=self.observability_config,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.