[Hardware] [Intel GPU] refactor xpu worker/executor (#7686)

Kunshang Ji 2024-08-21 00:54:10 +08:00 committed by GitHub
parent aae6927be0
commit c42590f97a
3 changed files with 26 additions and 28 deletions

vllm/executor/xpu_executor.py

@@ -1,16 +1,16 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
 import torch
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, PromptAdapterConfig,
-                         SchedulerConfig, SpeculativeConfig)
+                         ModelConfig, ObservabilityConfig, ParallelConfig,
+                         PromptAdapterConfig, SchedulerConfig,
+                         SpeculativeConfig)
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
-from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
 from vllm.utils import make_async
-from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -30,6 +30,7 @@ class XPUExecutor(GPUExecutor):
         lora_config: Optional[LoRAConfig],
         prompt_adapter_config: Optional[PromptAdapterConfig],
         speculative_config: Optional[SpeculativeConfig],
+        observability_config: Optional[ObservabilityConfig],
     ) -> None:
         assert device_config.device_type == "xpu"
         assert (not speculative_config
@@ -46,32 +47,23 @@ class XPUExecutor(GPUExecutor):
         self.device_config = device_config
         self.prompt_adapter_config = prompt_adapter_config
         self.speculative_config = None
+        self.observability_config = observability_config
 
         # Instantiate the worker and load the model to GPU.
         self._init_executor()
 
-    def _create_worker(self,
-                       local_rank: int = 0,
-                       rank: int = 0,
-                       distributed_init_method: Optional[str] = None):
-        if self.speculative_config is None:
-            worker_module_name = "vllm.worker.xpu_worker"
-            worker_class_name = "XPUWorker"
-        else:
+    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+        if self.speculative_config is not None:
             raise NotImplementedError(
                 "XPU does not support speculative decoding")
-
-        wrapper = WorkerWrapperBase(
-            worker_module_name=worker_module_name,
-            worker_class_name=worker_class_name,
-        )
-        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
-                                                      distributed_init_method))
-        return wrapper.worker
+        else:
+            worker_module_name = "vllm.worker.xpu_worker"
+            worker_class_name = "XPUWorker"
+        return (worker_module_name, worker_class_name)
 
     def execute_model(
-            self,
-            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
         output = self.driver_worker.execute_model(execute_model_req)
         return output
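
Note: the executor hunk above drops the XPU-specific _create_worker override (and the WorkerWrapperBase import it needed) in favor of the _get_worker_module_and_class hook that the base GPUExecutor consults when it builds workers. A minimal sketch of that hook pattern follows; the class names BaseExecutorSketch and XPUExecutorSketch are hypothetical stand-ins and not the real vLLM classes.

# Illustrative sketch of the worker-selection hook; class names here are
# hypothetical and only mirror the idea, not the actual vLLM implementation.
import importlib
from typing import Any, Optional, Tuple


class BaseExecutorSketch:
    """Owns worker construction; subclasses only say which worker to load."""

    def _get_worker_module_and_class(self) -> Tuple[str, str]:
        # Default worker; platform-specific executors override this hook.
        return ("vllm.worker.worker", "Worker")

    def _create_worker(self,
                       local_rank: int = 0,
                       rank: int = 0,
                       distributed_init_method: Optional[str] = None) -> Any:
        module_name, class_name = self._get_worker_module_and_class()
        # The real code goes through WorkerWrapperBase and passes the worker
        # kwargs; a dynamic import is enough to show the shape of the hook.
        worker_cls = getattr(importlib.import_module(module_name), class_name)
        return worker_cls


class XPUExecutorSketch(BaseExecutorSketch):
    """Only the module/class pair differs for the XPU backend."""

    def _get_worker_module_and_class(self) -> Tuple[str, str]:
        return ("vllm.worker.xpu_worker", "XPUWorker")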

vllm/worker/xpu_model_runner.py

@@ -137,7 +137,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
                 device_config=self.device_config,
                 load_config=self.load_config,
                 lora_config=self.lora_config,
-                multimodal_config=self.multimodal_config,
                 parallel_config=self.parallel_config,
                 scheduler_config=self.scheduler_config,
                 cache_config=self.cache_config,

vllm/worker/xpu_worker.py

@@ -9,8 +9,8 @@ import torch
 import torch.distributed
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, MultiModalConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
+                         ModelConfig, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig)
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -50,6 +50,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         speculative_config: Optional[SpeculativeConfig] = None,
         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         is_driver_worker: bool = False,
+        observability_config: Optional[ObservabilityConfig] = None,
     ) -> None:
         assert device_config.device_type == "xpu"
         assert is_xpu()
@@ -67,8 +68,10 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         self.lora_config = lora_config
         self.prompt_adapter_config = prompt_adapter_config
         self.is_driver_worker = is_driver_worker
-        if self.is_driver_worker:
-            assert self.rank == 0, "The driver worker must have rank 0."
+        self.observability_config = observability_config
+        if parallel_config and is_driver_worker:
+            assert rank % parallel_config.tensor_parallel_size == 0, \
+                "Driver worker should be rank 0 of tensor parallel group."
 
         self.multimodal_config = multimodal_config
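
Note: the hunk above relaxes the old "driver must be global rank 0" assertion to "driver must be the first rank of its tensor-parallel group", so a driver can exist per tensor-parallel group when several groups are present. A tiny standalone illustration of the relaxed condition; the helper name is ours, only the modulo check mirrors the diff.

# Standalone illustration of the relaxed driver-rank condition; the function
# name is hypothetical, the `rank % tensor_parallel_size == 0` check is from
# the diff above.
def is_tp_group_driver(rank: int, tensor_parallel_size: int) -> bool:
    # The first rank of each tensor-parallel group may act as a driver.
    return rank % tensor_parallel_size == 0


if __name__ == "__main__":
    # With TP=2 and 4 ranks, ranks 0 and 2 qualify as drivers, whereas the
    # old assertion only allowed global rank 0.
    for rank in range(4):
        print(rank, is_tp_group_driver(rank, tensor_parallel_size=2))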
@@ -183,7 +186,11 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         # dependency (libdrm and drm headers) on your system.
         ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE",
                                             "sockets")
+        ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
+                                         str(parallel_config.world_size))
         os.environ['CCL_ZE_IPC_EXCHANGE'] = ENV_CCL_ZE_IPC_EXCHANGE
+        os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
+        os.environ["LOCAL_RANK"] = str(self.local_rank)
         init_distributed_environment(
             world_size=parallel_config.world_size,
             rank=rank,
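
Note: the final hunk exports distributed-setup defaults before init_distributed_environment runs on XPU: CCL_ZE_IPC_EXCHANGE falls back to sockets (the oneCCL default drmfd needs libdrm and drm headers installed), and LOCAL_WORLD_SIZE plus LOCAL_RANK are set so the per-node rank-to-device mapping is available. A minimal sketch of the same default-then-export pattern, written as a hypothetical standalone helper:

# Hypothetical standalone helper mirroring the environment defaults set in
# the hunk above; only the variable names and fallbacks come from the diff.
import os


def set_xpu_distributed_env(world_size: int, local_rank: int) -> None:
    # Use 'sockets' for the Level Zero IPC exchange unless already overridden;
    # the oneCCL default ('drmfd') requires extra system dependencies.
    os.environ["CCL_ZE_IPC_EXCHANGE"] = os.getenv("CCL_ZE_IPC_EXCHANGE",
                                                  "sockets")
    # Default the local world size to the global world size (single node),
    # but respect a launcher-provided value if one is already set.
    os.environ["LOCAL_WORLD_SIZE"] = os.getenv("LOCAL_WORLD_SIZE",
                                               str(world_size))
    os.environ["LOCAL_RANK"] = str(local_rank)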