# SPDX-License-Identifier: Apache-2.0

###############################################################################
# Copyright (C) 2025 Habana Labs, Ltd. an Intel Company
###############################################################################

import dataclasses
from typing import Dict, Optional, Tuple

import torch

from vllm.distributed import broadcast_tensor_dict
from vllm.sequence import ExecuteModelRequest
from vllm.worker.hpu_model_runner import ModelInputForHPU
from vllm.worker.hpu_worker import HPUWorker
from vllm.worker.worker_base import WorkerInput


class MultiStepHPUWorker(HPUWorker):
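    """HPU worker that supports multi-step scheduling.

    The driver prepares and broadcasts full worker/model inputs only on the
    first of the scheduled steps; every worker caches the model input and,
    on subsequent steps, merely refreshes its per-step flags.
    """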

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
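        # Model input prepared on the first multi-step iteration; reused
        # (with refreshed per-step flags) on subsequent steps.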
        self.cached_model_input: Optional[ModelInputForHPU] = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[ModelInputForHPU, WorkerInput, Dict[str, torch.Tensor]]:
        """
        Get the driver input and broadcast it to other workers.
        """
        assert self.is_driver_worker
        assert execute_model_req.virtual_engine == 0

        is_first_multi_step = execute_model_req.is_first_multi_step
        is_last_step = execute_model_req.is_last_step

        if is_first_multi_step:
            # On the first step, prepare the worker input and model input
            # normally.
            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req)
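            # num_lookahead_slots counts the extra decode steps scheduled
            # ahead, so this batch runs num_lookahead_slots + 1 steps total.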
            worker_input = dataclasses.replace(
                worker_input,
                num_steps=execute_model_req.num_lookahead_slots + 1)
            model_input: ModelInputForHPU = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids))

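            # Propagate the engine's async callback, if set, so outputs can
            # be handled as they become available.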
            if execute_model_req.async_callback:
                model_input = dataclasses.replace(
                    model_input,
                    async_callback=execute_model_req.async_callback)
        else:
            # On subsequent steps, reuse the cached worker input and model
            # input.
            assert self.cached_model_input is not None
            model_input = self.cached_model_input
            worker_input = WorkerInput()
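            # The empty WorkerInput is sufficient here: cache operations were
            # already issued together with the first step's input.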
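        # Refresh the per-step flags on the (possibly cached) model input.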
        model_input = dataclasses.replace(
            model_input,
            is_first_multi_step=is_first_multi_step,
            is_last_step=is_last_step)

        if self.do_metadata_broadcast:
            if is_first_multi_step:
                broadcast_data = worker_input.as_broadcastable_tensor_dict()
                broadcast_data.update(
                    model_input.as_broadcastable_tensor_dict())
                broadcast_tensor_dict(broadcast_data, src=0)
            else:
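                # Only the step flags change between steps; workers rebuild
                # everything else from their cached model input.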
                broadcast_data = {
                    "is_first_multi_step": is_first_multi_step,
                    "is_last_step": is_last_step,
                }
                broadcast_tensor_dict(broadcast_data, src=0)

        # Return an empty dict here to keep this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`.
        return model_input, worker_input, {}

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[ModelInputForHPU, WorkerInput, Dict[str,
                                                            torch.Tensor]]]:
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there are no more requests to process
                    # for now. All workers run an infinite loop around
                    # broadcast_tensor_dict and stop once the driver
                    # broadcasts an empty input, so send an empty input to
                    # notify all other workers to stop their execution loop.
                    broadcast_tensor_dict({}, src=0)
                return None
            model_input, worker_input, _ = self._get_driver_input_and_broadcast(
                execute_model_req)
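            # Cache the fully prepared model input so later steps of this
            # multi-step batch can reuse it.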
            if model_input.is_first_multi_step:
                self.cached_model_input = model_input
            return model_input, worker_input, {}
        else:
            broadcast_data = broadcast_tensor_dict(src=0)
            if not broadcast_data:
                return None

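            # A two-entry dict is the lightweight subsequent-step broadcast:
            # just the step flags, merged into the cached model input below.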
            if len(broadcast_data) == 2:
                assert self.cached_model_input is not None
                self.cached_model_input = dataclasses.replace(
                    self.cached_model_input,
                    is_first_multi_step=broadcast_data["is_first_multi_step"],
                    is_last_step=broadcast_data["is_last_step"])
                empty_worker_input = WorkerInput()
                return self.cached_model_input, empty_worker_input, {}

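            # First-step broadcast: rebuild the full worker and model inputs
            # from the tensor dict and cache the model input for reuse.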
            worker_input = WorkerInput.from_broadcasted_tensor_dict(
                broadcast_data)
            model_input = (
                self.model_runner.
                make_model_input_from_broadcasted_tensor_dict(broadcast_data))
            self.cached_model_input = model_input
            return model_input, worker_input, {}