mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 10:47:05 +08:00
[TPU] Implement multi-step scheduling (#8489)
This commit is contained in:
parent
47790f3e32
commit
50e9ec41fc
@ -379,7 +379,7 @@ class ModelConfig:
|
|||||||
self.use_async_output_proc = False
|
self.use_async_output_proc = False
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.enforce_eager:
|
if device_config.device_type == "cuda" and self.enforce_eager:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"To see benefits of async output processing, enable CUDA "
|
"To see benefits of async output processing, enable CUDA "
|
||||||
"graph. Since, enforce-eager is enabled, async output "
|
"graph. Since, enforce-eager is enabled, async output "
|
||||||
|
|||||||
@ -68,8 +68,12 @@ class RayTPUExecutor(TPUExecutor):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert self.speculative_config is None
|
assert self.speculative_config is None
|
||||||
worker_module_name = "vllm.worker.tpu_worker"
|
if self.scheduler_config.is_multi_step:
|
||||||
worker_class_name = "TPUWorker"
|
worker_module_name = "vllm.worker.multi_step_tpu_worker"
|
||||||
|
worker_class_name = "MultiStepTPUWorker"
|
||||||
|
else:
|
||||||
|
worker_module_name = "vllm.worker.tpu_worker"
|
||||||
|
worker_class_name = "TPUWorker"
|
||||||
|
|
||||||
# GKE does not fetch environment information from metadata server
|
# GKE does not fetch environment information from metadata server
|
||||||
# and instead sets these from within the Ray process. Therefore we
|
# and instead sets these from within the Ray process. Therefore we
|
||||||
|
|||||||
@ -62,11 +62,17 @@ class TPUExecutor(ExecutorBase):
|
|||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
distributed_init_method: Optional[str] = None,
|
distributed_init_method: Optional[str] = None,
|
||||||
):
|
):
|
||||||
from vllm.worker.tpu_worker import TPUWorker
|
if self.scheduler_config.is_multi_step:
|
||||||
|
from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker
|
||||||
|
worker = MultiStepTPUWorker(**self._get_worker_kwargs(
|
||||||
|
local_rank, rank, distributed_init_method))
|
||||||
|
return worker
|
||||||
|
else:
|
||||||
|
from vllm.worker.tpu_worker import TPUWorker
|
||||||
|
|
||||||
worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank,
|
worker = TPUWorker(**self._get_worker_kwargs(
|
||||||
distributed_init_method))
|
local_rank, rank, distributed_init_method))
|
||||||
return worker
|
return worker
|
||||||
|
|
||||||
def initialize_cache(
|
def initialize_cache(
|
||||||
self,
|
self,
|
||||||
|
|||||||
105
vllm/worker/multi_step_tpu_worker.py
Normal file
105
vllm/worker/multi_step_tpu_worker.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
import dataclasses
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.distributed import broadcast_tensor_dict
|
||||||
|
from vllm.sequence import ExecuteModelRequest
|
||||||
|
from vllm.worker.tpu_model_runner import ModelInputForTPU
|
||||||
|
from vllm.worker.tpu_worker import TPUWorker
|
||||||
|
from vllm.worker.worker_base import WorkerInput
|
||||||
|
|
||||||
|
|
||||||
|
class MultiStepTPUWorker(TPUWorker):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.cached_model_input: Optional[ModelInputForTPU] = None
|
||||||
|
|
||||||
|
def _get_driver_input_and_broadcast(
|
||||||
|
self, execute_model_req: ExecuteModelRequest
|
||||||
|
) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]:
|
||||||
|
assert self.is_driver_worker
|
||||||
|
assert execute_model_req.virtual_engine == 0
|
||||||
|
|
||||||
|
is_first_multi_step = execute_model_req.is_first_multi_step
|
||||||
|
is_last_step = execute_model_req.is_last_step
|
||||||
|
if is_first_multi_step:
|
||||||
|
worker_input: WorkerInput = self.prepare_worker_input(
|
||||||
|
execute_model_req=execute_model_req)
|
||||||
|
worker_input = dataclasses.replace(
|
||||||
|
worker_input,
|
||||||
|
num_steps=execute_model_req.num_lookahead_slots + 1)
|
||||||
|
model_input: ModelInputForTPU = (
|
||||||
|
self.model_runner.prepare_model_input(
|
||||||
|
execute_model_req.seq_group_metadata_list,
|
||||||
|
execute_model_req.virtual_engine,
|
||||||
|
execute_model_req.finished_requests_ids))
|
||||||
|
|
||||||
|
if execute_model_req.async_callback:
|
||||||
|
model_input = dataclasses.replace(
|
||||||
|
model_input,
|
||||||
|
async_callback=execute_model_req.async_callback)
|
||||||
|
else:
|
||||||
|
assert self.cached_model_input is not None
|
||||||
|
model_input = self.cached_model_input
|
||||||
|
worker_input = WorkerInput()
|
||||||
|
model_input = dataclasses.replace(
|
||||||
|
model_input,
|
||||||
|
is_first_multi_step=is_first_multi_step,
|
||||||
|
is_last_step=is_last_step)
|
||||||
|
|
||||||
|
if self.do_metadata_broadcast:
|
||||||
|
if is_first_multi_step:
|
||||||
|
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
||||||
|
broadcast_data.update(
|
||||||
|
model_input.as_broadcastable_tensor_dict())
|
||||||
|
broadcast_tensor_dict(broadcast_data, src=0)
|
||||||
|
else:
|
||||||
|
broadcast_data = {
|
||||||
|
"is_first_multi_step": is_first_multi_step,
|
||||||
|
"is_last_step": is_last_step,
|
||||||
|
}
|
||||||
|
broadcast_tensor_dict(broadcast_data, src=0)
|
||||||
|
|
||||||
|
# Retuning empty dict here to keep this compatible with
|
||||||
|
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
|
||||||
|
return model_input, worker_input, {}
|
||||||
|
|
||||||
|
def prepare_input(
|
||||||
|
self,
|
||||||
|
execute_model_req: Optional[ExecuteModelRequest] = None,
|
||||||
|
) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str,
|
||||||
|
torch.Tensor]]]:
|
||||||
|
if self.is_driver_worker:
|
||||||
|
if execute_model_req is None:
|
||||||
|
if self.do_metadata_broadcast:
|
||||||
|
broadcast_tensor_dict({}, src=0)
|
||||||
|
return None
|
||||||
|
|
||||||
|
model_input, worker_input, _ = self._get_driver_input_and_broadcast(
|
||||||
|
execute_model_req)
|
||||||
|
if model_input.is_first_multi_step:
|
||||||
|
self.cached_model_input = model_input
|
||||||
|
return model_input, worker_input, {}
|
||||||
|
else:
|
||||||
|
broadcast_data = broadcast_tensor_dict(src=0)
|
||||||
|
if not broadcast_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(broadcast_data) == 2:
|
||||||
|
assert self.cached_model_input is not None
|
||||||
|
self.cached_model_input = dataclasses.replace(
|
||||||
|
self.cached_model_input,
|
||||||
|
is_first_multi_step=broadcast_data["is_first_multi_step"],
|
||||||
|
is_last_step=broadcast_data["is_last_step"])
|
||||||
|
empty_worker_input = WorkerInput()
|
||||||
|
return self.cached_model_input, empty_worker_input, {}
|
||||||
|
|
||||||
|
worker_input = WorkerInput.from_broadcasted_tensor_dict(
|
||||||
|
broadcast_data)
|
||||||
|
model_input = (
|
||||||
|
self.model_runner.
|
||||||
|
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
|
||||||
|
self.cached_model_input = model_input
|
||||||
|
return model_input, worker_input, {}
|
||||||
@ -51,6 +51,8 @@ class ModelInputForTPU(ModelRunnerInputBase):
|
|||||||
num_samples: int
|
num_samples: int
|
||||||
best_of: List[int]
|
best_of: List[int]
|
||||||
seq_groups: List[List[int]]
|
seq_groups: List[List[int]]
|
||||||
|
is_first_multi_step: bool = True
|
||||||
|
is_last_step: bool = True
|
||||||
virtual_engine: int = 0
|
virtual_engine: int = 0
|
||||||
async_callback: Optional[Callable] = None
|
async_callback: Optional[Callable] = None
|
||||||
|
|
||||||
@ -65,6 +67,8 @@ class ModelInputForTPU(ModelRunnerInputBase):
|
|||||||
"num_samples": self.num_samples,
|
"num_samples": self.num_samples,
|
||||||
"best_of": self.best_of,
|
"best_of": self.best_of,
|
||||||
"seq_groups": self.seq_groups,
|
"seq_groups": self.seq_groups,
|
||||||
|
"is_first_multi_step": self.is_first_multi_step,
|
||||||
|
"is_last_step": self.is_last_step,
|
||||||
"virtual_engine": self.virtual_engine,
|
"virtual_engine": self.virtual_engine,
|
||||||
}
|
}
|
||||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||||
@ -118,6 +122,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
|
|||||||
self.block_size,
|
self.block_size,
|
||||||
False,
|
False,
|
||||||
)
|
)
|
||||||
|
self.cached_step_outputs: List[torch.Tensor] = []
|
||||||
|
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
self.device = self.device_config.device
|
self.device = self.device_config.device
|
||||||
@ -518,97 +523,159 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
|
|||||||
num_steps: int = 1,
|
num_steps: int = 1,
|
||||||
) -> List[SamplerOutput]:
|
) -> List[SamplerOutput]:
|
||||||
assert intermediate_tensors is None
|
assert intermediate_tensors is None
|
||||||
if num_steps > 1:
|
if not model_input.is_first_multi_step:
|
||||||
raise ValueError(
|
if not model_input.is_last_step:
|
||||||
"TPUModelRunner does not support multi-step execution.")
|
return []
|
||||||
|
|
||||||
def _execute_model(*args):
|
use_async_out_proc = model_input.async_callback is not None
|
||||||
"""Move input args from CPU to device and execute the model."""
|
sampler_outputs = []
|
||||||
|
num_outputs = len(self.cached_step_outputs)
|
||||||
|
for i in range(num_outputs):
|
||||||
|
next_token_ids = self.cached_step_outputs.pop(0)
|
||||||
|
next_token_ids = next_token_ids.cpu().tolist()
|
||||||
|
sampler_output = _make_decode_output(next_token_ids,
|
||||||
|
model_input.seq_groups)
|
||||||
|
sampler_outputs.append(sampler_output)
|
||||||
|
|
||||||
new_args = []
|
if i < num_outputs - 1 and use_async_out_proc:
|
||||||
for arg in args:
|
assert model_input.async_callback is not None
|
||||||
if isinstance(arg, torch.Tensor):
|
ctx = model_input.async_callback.keywords[ # type: ignore
|
||||||
arg = arg.to(self.device)
|
"ctx"]
|
||||||
elif isinstance(arg, AttentionMetadata):
|
ctx.append_output(
|
||||||
arg.slot_mapping = arg.slot_mapping.to(self.device)
|
outputs=[sampler_output],
|
||||||
if getattr(arg, "block_tables", None) is not None:
|
seq_group_metadata_list=ctx.seq_group_metadata_list,
|
||||||
arg.block_tables = arg.block_tables.to(self.device)
|
scheduler_outputs=ctx.scheduler_outputs,
|
||||||
if getattr(arg, "context_lens", None) is not None:
|
is_async=False,
|
||||||
arg.context_lens = arg.context_lens.to(self.device)
|
is_last_step=False)
|
||||||
new_args.append(arg)
|
model_input.async_callback()
|
||||||
return self.model(*new_args, is_prompt=is_prompt)
|
if use_async_out_proc:
|
||||||
|
return [sampler_outputs[-1]]
|
||||||
|
else:
|
||||||
|
return sampler_outputs
|
||||||
|
|
||||||
num_prefills = model_input.attn_metadata.num_prefills
|
is_prompt = model_input.attn_metadata.num_prefills > 0
|
||||||
is_prompt = num_prefills > 0
|
|
||||||
if is_prompt:
|
if is_prompt:
|
||||||
|
assert num_steps == 1
|
||||||
# NOTE(woosuk): Since the FlashAttention kernel does not support
|
# NOTE(woosuk): Since the FlashAttention kernel does not support
|
||||||
# ragged inputs, we split the prompts into different batches and
|
# ragged inputs, we split the prompts into different batches and
|
||||||
# process them separately. This is a temporary hack that should be
|
# process them separately. This is a temporary hack that should be
|
||||||
# optimized by using SplashAttention.
|
# optimized by using SplashAttention.
|
||||||
next_token_ids = []
|
|
||||||
orig_slot_mapping = model_input.attn_metadata.slot_mapping
|
orig_slot_mapping = model_input.attn_metadata.slot_mapping
|
||||||
batch_size = model_input.input_lens.shape[0]
|
batch_size = model_input.input_lens.shape[0]
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
|
next_token_ids = []
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
# Get the actual prefill_len.
|
# Get the actual prefill_len.
|
||||||
prefill_len = model_input.input_lens[i:i + 1].item()
|
prefill_len = model_input.input_lens[i:i + 1].item()
|
||||||
prefill_len = _get_padded_prefill_len(prefill_len)
|
prefill_len = _get_padded_prefill_len(prefill_len)
|
||||||
end_idx = start_idx + prefill_len
|
end_idx = start_idx + prefill_len
|
||||||
|
|
||||||
model_input.attn_metadata.slot_mapping = orig_slot_mapping[
|
token_ids = model_input.token_ids[None, start_idx:end_idx].to(
|
||||||
None, start_idx:end_idx]
|
self.device)
|
||||||
model_input.attn_metadata.num_prefills = 1
|
position_ids = model_input.position_ids[None,
|
||||||
output_token_ids = _execute_model(
|
start_idx:end_idx].to(
|
||||||
model_input.token_ids[None, start_idx:end_idx],
|
self.device)
|
||||||
model_input.position_ids[None, start_idx:end_idx],
|
attn_metadata = model_input.attn_metadata
|
||||||
model_input.attn_metadata, model_input.input_lens[i:i + 1],
|
attn_metadata.num_prefills = 1
|
||||||
model_input.t[i:i + 1], model_input.p[i:i + 1],
|
attn_metadata.slot_mapping = orig_slot_mapping[
|
||||||
model_input.num_samples, kv_caches)
|
None, start_idx:end_idx].to(self.device)
|
||||||
if i == 0 and model_input.async_callback is not None:
|
input_lens = model_input.input_lens[i:i + 1].to(self.device)
|
||||||
model_input.async_callback()
|
t = model_input.t[i:i + 1].to(self.device)
|
||||||
# Retrieve the outputs to CPU.
|
p = model_input.p[i:i + 1].to(self.device)
|
||||||
next_token_ids += output_token_ids.cpu().tolist()
|
output_token_ids = self.model(token_ids,
|
||||||
|
position_ids,
|
||||||
|
attn_metadata,
|
||||||
|
input_lens,
|
||||||
|
t,
|
||||||
|
p,
|
||||||
|
model_input.num_samples,
|
||||||
|
kv_caches,
|
||||||
|
is_prompt=True)
|
||||||
|
next_token_ids.append(output_token_ids[0])
|
||||||
start_idx = end_idx
|
start_idx = end_idx
|
||||||
else:
|
|
||||||
# Execute the model.
|
|
||||||
output_token_ids = _execute_model(
|
|
||||||
model_input.token_ids, model_input.position_ids,
|
|
||||||
model_input.attn_metadata, model_input.input_lens,
|
|
||||||
model_input.t, model_input.p, model_input.num_samples,
|
|
||||||
kv_caches)
|
|
||||||
if model_input.async_callback is not None:
|
if model_input.async_callback is not None:
|
||||||
model_input.async_callback()
|
model_input.async_callback()
|
||||||
# Retrieve the outputs to CPU.
|
# Retrieve the outputs to CPU.
|
||||||
next_token_ids = output_token_ids.cpu().tolist()
|
next_token_ids = [
|
||||||
|
output_token_ids.cpu().tolist()
|
||||||
|
for output_token_ids in next_token_ids
|
||||||
|
]
|
||||||
|
|
||||||
# NOTE(woosuk): Minimal code to construct the sampler outputs.
|
# NOTE(woosuk): Minimal code to construct the sampler outputs.
|
||||||
# The TPU backend does not reuse the sampler, since the TPU backend
|
# The TPU backend does not reuse the sampler, since the TPU backend
|
||||||
# does not support the advanced sampling parameters such as logprobs.
|
# does not support advanced sampling parameters such as logprobs.
|
||||||
zero_logprob = Logprob(0.0)
|
zero_logprob = Logprob(0.0)
|
||||||
batch_idx = 0
|
sampler_outputs = []
|
||||||
sampler_outputs = []
|
for i, seq_group in enumerate(model_input.seq_groups):
|
||||||
for seq_group in model_input.seq_groups:
|
seq_ids = seq_group
|
||||||
seq_ids = seq_group
|
|
||||||
seq_outputs = []
|
|
||||||
if is_prompt:
|
|
||||||
assert len(seq_ids) == 1
|
assert len(seq_ids) == 1
|
||||||
seq_id = seq_ids[0]
|
seq_id = seq_ids[0]
|
||||||
for i in range(model_input.best_of[batch_idx]):
|
seq_outputs = []
|
||||||
next_token_id = next_token_ids[batch_idx][i]
|
for j in range(model_input.best_of[i]):
|
||||||
|
next_token_id = next_token_ids[i][j]
|
||||||
seq_outputs.append(
|
seq_outputs.append(
|
||||||
SequenceOutput(seq_id, next_token_id,
|
SequenceOutput(seq_id, next_token_id,
|
||||||
{next_token_id: zero_logprob}))
|
{next_token_id: zero_logprob}))
|
||||||
batch_idx += 1
|
sampler_outputs.append(
|
||||||
else:
|
CompletionSequenceGroupOutput(seq_outputs, None))
|
||||||
for seq_id in seq_ids:
|
return [SamplerOutput(sampler_outputs)]
|
||||||
next_token_id = next_token_ids[batch_idx]
|
else:
|
||||||
seq_outputs.append(
|
token_ids = model_input.token_ids.to(self.device)
|
||||||
SequenceOutput(seq_id, next_token_id,
|
position_ids = model_input.position_ids.to(self.device)
|
||||||
{next_token_id: zero_logprob}))
|
attn_metadata = model_input.attn_metadata
|
||||||
batch_idx += 1
|
attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
|
||||||
sampler_outputs.append(
|
self.device)
|
||||||
CompletionSequenceGroupOutput(seq_outputs, None))
|
attn_metadata.block_tables = attn_metadata.block_tables.to(
|
||||||
return [SamplerOutput(sampler_outputs)]
|
self.device)
|
||||||
|
attn_metadata.context_lens = attn_metadata.context_lens.to(
|
||||||
|
self.device)
|
||||||
|
t = model_input.t.to(self.device)
|
||||||
|
p = model_input.p.to(self.device)
|
||||||
|
input_lens = model_input.input_lens.to(self.device)
|
||||||
|
for i in range(num_steps):
|
||||||
|
slot_mapping = attn_metadata.slot_mapping
|
||||||
|
output_token_ids = self.model(token_ids,
|
||||||
|
position_ids,
|
||||||
|
attn_metadata,
|
||||||
|
input_lens,
|
||||||
|
t,
|
||||||
|
p,
|
||||||
|
model_input.num_samples,
|
||||||
|
kv_caches,
|
||||||
|
is_prompt=False)
|
||||||
|
self.cached_step_outputs.append(output_token_ids)
|
||||||
|
|
||||||
|
if i < num_steps - 1:
|
||||||
|
# Prepare the inputs for the next step.
|
||||||
|
token_ids = output_token_ids.unsqueeze(dim=1).int()
|
||||||
|
position_ids = position_ids + 1
|
||||||
|
attn_metadata.context_lens = attn_metadata.context_lens + 1
|
||||||
|
|
||||||
|
block_tables = attn_metadata.block_tables
|
||||||
|
block_number = block_tables.gather(
|
||||||
|
1,
|
||||||
|
position_ids.long() // self.block_size)
|
||||||
|
block_offset = position_ids % self.block_size
|
||||||
|
|
||||||
|
is_padding = slot_mapping == _PAD_SLOT_ID
|
||||||
|
slot_mapping = block_number * self.block_size + block_offset
|
||||||
|
slot_mapping = slot_mapping.long()
|
||||||
|
slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
|
||||||
|
slot_mapping)
|
||||||
|
attn_metadata.slot_mapping = slot_mapping
|
||||||
|
|
||||||
|
if model_input.async_callback is not None:
|
||||||
|
model_input.async_callback()
|
||||||
|
|
||||||
|
if num_steps > 1:
|
||||||
|
return []
|
||||||
|
# Retrieve the outputs to CPU.
|
||||||
|
next_token_ids = self.cached_step_outputs.pop(0)
|
||||||
|
next_token_ids = next_token_ids.cpu().tolist()
|
||||||
|
sampler_output = _make_decode_output(next_token_ids,
|
||||||
|
model_input.seq_groups)
|
||||||
|
return [sampler_output]
|
||||||
|
|
||||||
|
|
||||||
class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
|
class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
|
||||||
@ -756,3 +823,24 @@ def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
|
|||||||
cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
|
cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
|
||||||
logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
|
logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def _make_decode_output(
|
||||||
|
next_token_ids: List[int],
|
||||||
|
seq_groups: List[List[int]],
|
||||||
|
) -> SamplerOutput:
|
||||||
|
zero_logprob = Logprob(0.0)
|
||||||
|
sampler_outputs = []
|
||||||
|
batch_idx = 0
|
||||||
|
for seq_group in seq_groups:
|
||||||
|
seq_ids = seq_group
|
||||||
|
seq_outputs = []
|
||||||
|
for seq_id in seq_ids:
|
||||||
|
next_token_id = next_token_ids[batch_idx]
|
||||||
|
seq_outputs.append(
|
||||||
|
SequenceOutput(seq_id, next_token_id,
|
||||||
|
{next_token_id: zero_logprob}))
|
||||||
|
batch_idx += 1
|
||||||
|
sampler_outputs.append(CompletionSequenceGroupOutput(
|
||||||
|
seq_outputs, None))
|
||||||
|
return SamplerOutput(sampler_outputs)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user