mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Hardware][Intel-Gaudi] Multi-step scheduling implementation for HPU (#12779)

Signed-off-by: Tomasz Zielinski <tomasz.zielinski@intel.com>

parent 9e90c9f73f
commit 34b2cf3b33
vllm/platforms/hpu.py
@@ -46,15 +46,15 @@ class HpuPlatform(Platform):
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         scheduler_config = vllm_config.scheduler_config
+        parallel_config = vllm_config.parallel_config
         if scheduler_config.is_multi_step:
-            raise NotImplementedError(
-                "Multi-step execution is not implemented for HPU")
+            parallel_config.worker_cls = \
+                "vllm.worker.multi_step_hpu_worker.MultiStepHPUWorker"

         if vllm_config.speculative_config is not None:
             raise NotImplementedError(
                 "Speculative decoding is not implemented for HPU")

-        parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
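For context on how this routing is reached: `scheduler_config.is_multi_step` becomes true when the engine is configured with more than one scheduler step. A minimal launch sketch, assuming the `num_scheduler_steps` engine argument of this vLLM generation (the argument name and model are assumptions, not part of this diff):

# Hypothetical usage sketch: num_scheduler_steps > 1 is assumed to set
# scheduler_config.is_multi_step, which makes check_and_update_config()
# above select MultiStepHPUWorker instead of raising NotImplementedError.
from vllm import LLM

llm = LLM(model="facebook/opt-125m", num_scheduler_steps=4)
out = llm.generate("Hello, my name is")
print(out[0].outputs[0].text)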
vllm/worker/hpu_model_runner.py
@@ -32,6 +32,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import DeviceConfig, VllmConfig
+from vllm.distributed import broadcast_tensor_dict
 from vllm.distributed.parallel_state import get_world_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
@@ -44,11 +45,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.sampling_metadata import SequenceGroupToSample
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalKwargs)
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (IntermediateTensors, SequenceData,
-                           SequenceGroupMetadata)
+from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
+                           Logprob, SequenceData, SequenceGroupMetadata,
+                           SequenceOutput)
 from vllm.utils import (bind_kv_cache, is_pin_memory_available,
                         make_tensor_with_pad)
 from vllm.worker.model_runner_base import (
@@ -100,7 +103,10 @@ def subtuple(obj: object,
     if to_override is None:
         to_override = {}
     fields = set(to_copy) | set(to_override.keys())
-    values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
+    if type(obj) is dict:
+        values = {key: obj[key] for key in fields if key in obj}
+    else:
+        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
     if typename not in _TYPE_CACHE:
         _TYPE_CACHE[typename] = collections.namedtuple(typename,
                                                        ' '.join(fields))
@@ -533,6 +539,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
     virtual_engine: int = 0
     lora_ids: Optional[List[int]] = None
     async_callback: Optional[Callable] = None
+    is_first_multi_step: bool = True
+    is_last_step: bool = True

     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -545,6 +553,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
             "batch_size_padded": self.batch_size_padded,
             "virtual_engine": self.virtual_engine,
             "lora_ids": self.lora_ids,
+            "is_first_multi_step": self.is_first_multi_step,
+            "is_last_step": self.is_last_step,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
         return tensor_dict
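The two new flags travel with the rest of the broadcastable model input, so non-driver replicas can tell which phase of a multi-step batch they are in. A standalone toy sketch of that round trip (only the field names come from the diff; the class and helpers here are illustrative, not vLLM's):

# Toy round trip: driver serializes the flags, a replica rebuilds the input.
import dataclasses
from typing import Any, Dict

@dataclasses.dataclass
class ToyModelInput:
    virtual_engine: int = 0
    is_first_multi_step: bool = True
    is_last_step: bool = True

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        # stand-in for the real method extended above
        return dataclasses.asdict(self)

    @classmethod
    def from_broadcasted_tensor_dict(cls, d: Dict[str, Any]) -> "ToyModelInput":
        return cls(**d)

sent = ToyModelInput(is_first_multi_step=False, is_last_step=True)
received = ToyModelInput.from_broadcasted_tensor_dict(
    sent.as_broadcastable_tensor_dict())
assert received == sent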
@@ -656,6 +666,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         self._set_gc_threshold()
         self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH

+        # For multi-step scheduling
+        self.cached_step_outputs: List[torch.Tensor] = []
+
     def _set_gc_threshold(self) -> None:
         # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
         # for comprehensive description of gc generations.
@@ -1005,6 +1018,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
+        output=None,
     ) -> PrepareDecodeMetadata:
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -1035,8 +1049,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):

         for seq_id in seq_ids:
             seq_data = seq_group_metadata.seq_data[seq_id]
-            generation_token = seq_data.get_last_token_id()
-            input_tokens.append([generation_token])
+            if output is None:
+                generation_token = seq_data.get_last_token_id()
+                input_tokens.append([generation_token])

             seq_len = seq_data.get_len()
             position = seq_len - 1
@@ -1047,6 +1062,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             seq_lens.append(seq_len)

             block_table = seq_group_metadata.block_tables[seq_id]
+            num_fully_occupied_blocks = position // self.block_size
+            block_table = block_table[:num_fully_occupied_blocks + 1]
+
             if len(block_table) == 0:
                 block_number = _PAD_BLOCK_ID
             else:
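A worked example of the trimming above, with assumed numbers (the block size and physical block ids are made up):

# position = seq_len - 1 = 200, so 200 // 128 = 1 block is fully occupied
# and the table is cut to the first 1 + 1 = 2 entries; the second block
# holds the partially filled tail of the sequence.
block_size = 128
position = 200
block_table = [7, 12, 31, 44]  # toy physical block ids
num_fully_occupied_blocks = position // block_size
trimmed = block_table[:num_fully_occupied_blocks + 1]
assert trimmed == [7, 12]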
@@ -1066,9 +1084,14 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                 block_table = block_table[-sliding_window_blocks:]
             block_tables.append(block_table)

-        input_tokens = torch.tensor(input_tokens,
-                                    dtype=torch.long,
-                                    device=self.device)
+        if output is None:
+            input_tokens = torch.tensor(input_tokens,
+                                        dtype=torch.long,
+                                        device=self.device)
+        else:
+            real_batch_size = len(seq_group_metadata_list)
+            input_tokens = output[:real_batch_size]

         input_positions = torch.tensor(input_positions,
                                        dtype=torch.long,
                                        device=self.device)
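The new `else` branch is what lets later multi-step iterations avoid materializing tokens on the host: the previous step's sampled ids, already a device tensor, are sliced down to the real batch and reused as the next step's input. A toy sketch with stand-in tensors (the shapes and values are illustrative):

import torch

# Padded sampler output from step i; the last row is batch padding.
batch_size_padded, real_batch_size = 4, 3
prev_step_output = torch.tensor([[11], [22], [33], [0]])
# Step i+1 input, as in the diff: slice off padding, no .cpu() round trip.
input_tokens = prev_step_output[:real_batch_size]
assert input_tokens.tolist() == [[11], [22], [33]]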
@@ -1462,7 +1485,27 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             profiler.start()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
-            self.execute_model(inputs, None, warmup_mode=True)
+            is_single_step = \
+                self.vllm_config.scheduler_config.num_scheduler_steps == 1
+            if is_prompt or is_single_step:
+                self.execute_model(inputs, None, warmup_mode=True)
+            else:  # decode with multi-step
+                inputs = dataclasses.replace(inputs,
+                                             is_first_multi_step=True,
+                                             is_last_step=False)
+                self.execute_model(inputs,
+                                   None,
+                                   warmup_mode=True,
+                                   num_steps=2,
+                                   seqs=seqs)
+                inputs = dataclasses.replace(inputs,
+                                             is_first_multi_step=False,
+                                             is_last_step=True)
+                self.execute_model(inputs,
+                                   None,
+                                   warmup_mode=True,
+                                   num_steps=2,
+                                   seqs=seqs)
             torch.hpu.synchronize()
             if profiler:
                 profiler.step()
@@ -1985,115 +2028,273 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
         warmup_mode=False,
+        seqs=None,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
-        if num_steps > 1:
-            raise ValueError(
-                "num_steps > 1 is not supported in HPUModelRunner")
+        if not model_input.is_first_multi_step:
+            if not model_input.is_last_step:
+                # not first or last multi-step
+                return []
+            # last multi-step
+            output = self._decode_sampler_outputs(
+                model_input) if self.is_driver_worker else []
+            torch.hpu.synchronize()
+        if model_input.is_first_multi_step:
+            # first multi-step
+            if self.lora_config:
+                assert model_input.lora_requests is not None
+                assert model_input.lora_mapping is not None
+                self.set_active_loras(model_input.lora_requests,
+                                      model_input.lora_mapping)
+            input_tokens = model_input.input_tokens
+            input_positions = model_input.input_positions
+            attn_metadata = model_input.attn_metadata
+            sampling_metadata = model_input.sampling_metadata
+            real_batch_size = model_input.real_batch_size
+            batch_size_padded = model_input.batch_size_padded
+            assert input_tokens is not None
+            assert input_positions is not None
+            assert sampling_metadata is not None
+            assert attn_metadata is not None
+            is_prompt = attn_metadata.is_prompt
+            assert is_prompt is not None
+            batch_size = input_tokens.size(0)
+            seq_len = self._seq_len(attn_metadata)
+            use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
+            self._check_config(batch_size, seq_len, is_prompt, warmup_mode)

-        if self.lora_config:
-            assert model_input.lora_requests is not None
-            assert model_input.lora_mapping is not None
-            self.set_active_loras(model_input.lora_requests,
-                                  model_input.lora_mapping)
-        input_tokens = model_input.input_tokens
-        input_positions = model_input.input_positions
-        attn_metadata = model_input.attn_metadata
-        sampling_metadata = model_input.sampling_metadata
-        real_batch_size = model_input.real_batch_size
-        batch_size_padded = model_input.batch_size_padded
-        assert input_tokens is not None
-        assert input_positions is not None
-        assert sampling_metadata is not None
-        assert attn_metadata is not None
-        is_prompt = attn_metadata.is_prompt
-        assert is_prompt is not None
-        batch_size = input_tokens.size(0)
-        seq_len = self._seq_len(attn_metadata)
-        use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
-        self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
+            lora_mask: torch.Tensor = None
+            lora_logits_mask: torch.Tensor = None
+            if self.lora_config:
+                assert model_input.lora_ids is not None
+                lora_mask, lora_logits_mask = self.create_lora_mask(
+                    input_tokens, model_input.lora_ids,
+                    attn_metadata.is_prompt)

-        lora_mask: torch.Tensor = None
-        lora_logits_mask: torch.Tensor = None
-        if self.lora_config:
-            assert model_input.lora_ids is not None
-            lora_mask, lora_logits_mask = self.create_lora_mask(
-                input_tokens, model_input.lora_ids, attn_metadata.is_prompt)
+            execute_model_kwargs = {
+                "input_ids": input_tokens,
+                "positions": input_positions,
+                "attn_metadata": self.trim_attn_metadata(attn_metadata),
+                "intermediate_tensors": intermediate_tensors,
+                "lora_mask": lora_mask,
+                "virtual_engine": model_input.virtual_engine,
+                **(model_input.multi_modal_kwargs or {}),
+            }
+            if htorch.utils.internal.is_lazy():
+                execute_model_kwargs.update(
+                    {"bypass_hpu_graphs": not use_graphs})

-        execute_model_kwargs = {
-            "input_ids": input_tokens,
-            "positions": input_positions,
-            "attn_metadata": self.trim_attn_metadata(attn_metadata),
-            "intermediate_tensors": intermediate_tensors,
-            "lora_mask": lora_mask,
-            "virtual_engine": model_input.virtual_engine,
-            **(model_input.multi_modal_kwargs or {}),
-        }
-        if htorch.utils.internal.is_lazy():
-            execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs})
+            htorch.core.mark_step()
+            if self.is_driver_worker:
+                model_event_name = ("model_"
+                                    f"{'prompt' if is_prompt else 'decode'}_"
+                                    f"bs{batch_size}_"
+                                    f"seq{seq_len}_"
+                                    f"graphs{'T' if use_graphs else 'F'}")
+            else:
+                model_event_name = 'model_executable'
+            if num_steps > 1:
+                # in case of multi-step scheduling
+                # we only want to pythonize in the last step
+                sampling_metadata.skip_sampler_cpu_output = True
+                self.model.model.sampler.include_gpu_probs_tensor = True
+                cache_orig_output_tokens_len: List[Dict] = []

-        htorch.core.mark_step()
-        if self.is_driver_worker:
-            model_event_name = ("model_"
-                                f"{'prompt' if is_prompt else 'decode'}_"
-                                f"bs{batch_size}_"
-                                f"seq{seq_len}_"
-                                f"graphs{'T' if use_graphs else 'F'}")
+            def try_revert_dummy_output_tokens():
+                if len(cache_orig_output_tokens_len) > 0:
+                    # Reuse the original output token ids length
+                    for i, seq_group_metadata in enumerate(
+                            seq_group_metadata_list):
+                        for j, data in seq_group_metadata.seq_data.items():
+                            orig_output_tokens_len = \
+                                cache_orig_output_tokens_len[i][j]
+                            data.output_token_ids = \
+                                data.output_token_ids[:orig_output_tokens_len]
+
+            for i in range(num_steps):
+                if i != 0 and not self.is_driver_worker:
+                    broadcast_data = broadcast_tensor_dict(src=0)
+                    if 'early_exit' in broadcast_data and broadcast_data[
+                            'early_exit']:
+                        return [output] if num_steps == 1 else []
+                    execute_model_kwargs.update({
+                        "input_ids":
+                        broadcast_data["input_ids"],
+                        "positions":
+                        broadcast_data["positions"],
+                        "attn_metadata":
+                        self.trim_attn_metadata(
+                            broadcast_data["attn_metadata"])
+                    })
+                with self.profiler.record_event('internal', model_event_name):
+                    hidden_states = self.model.forward(
+                        **execute_model_kwargs,
+                        selected_token_indices=sampling_metadata.
+                        selected_token_indices)
+
+                if self.lora_config:
+                    LoraMask.setLoraMask(
+                        lora_logits_mask.index_select(
+                            0, sampling_metadata.selected_token_indices))
+
+                # Compute the logits.
+                with self.profiler.record_event(
+                        'internal',
+                        ('compute_logits_'
+                         f'{"prompt" if is_prompt else "decode"}_bs'
+                         f'{batch_size}_'
+                         f'seq{seq_len}')):
+                    if num_steps == 1:
+                        sampling_metadata.selected_token_indices = None
+                    logits = self.model.compute_logits(hidden_states,
+                                                       sampling_metadata)
+                htorch.core.mark_step()
+                # Only perform sampling in the driver worker.
+                if not self.is_driver_worker:
+                    continue
+
+                if model_input.async_callback is not None:
+                    model_input.async_callback()
+                # Sample the next token.
+                with self.profiler.record_event(
+                        'internal', ('sample_'
+                                     f'{"prompt" if is_prompt else "decode"}_'
+                                     f'bs{batch_size}_'
+                                     f'seq{seq_len}')):
+                    output = self.model.sample(
+                        logits=logits,
+                        sampling_metadata=sampling_metadata,
+                    )
+                    if num_steps > 1:
+                        output = output.sampled_token_ids
+                        self.cached_step_outputs.append(
+                            output.detach().clone())
+                htorch.core.mark_step()
+                if i < num_steps - 1:
+                    if i == 0:
+                        if model_input.async_callback is not None:
+                            ctx = model_input.async_callback.keywords[  # type: ignore
+                                "ctx"]
+                            seq_group_metadata_list = \
+                                ctx.seq_group_metadata_list
+                        elif seqs is not None:
+                            seq_group_metadata_list = seqs
+                        else:
+                            raise RuntimeError(
+                                "seq_group_metadata_list is uninitialized")
+                        for i, seq_group_metadata in enumerate(
+                                seq_group_metadata_list):
+                            # Skip empty steps
+                            seq_group_metadata.state.current_step += (
+                                num_steps - 2)
+                            # Cache the original output token ids
+                            cache_orig_output_tokens_len.append({})
+                            for j, data in seq_group_metadata.seq_data.items():
+                                cache_orig_output_tokens_len[i][j] = \
+                                    len(data.output_token_ids)
+                    for seq_group_metadata in seq_group_metadata_list:
+                        for data in seq_group_metadata.seq_data.values():
+                            max_output_len = sampling_metadata.seq_groups[
+                                0].sampling_params.max_tokens
+                            if len(data.output_token_ids) < max_output_len - 1:
+                                # add a place holder for prepare_decode
+                                # arbitrary value, this could be any token
+                                dummy_token = (540, )
+                                data.output_token_ids += (dummy_token)
+                            else:
+                                broadcast_tensor_dict({'early_exit': True},
+                                                      src=0)
+                                if num_steps == 1:
+                                    return [output]
+                                else:
+                                    try_revert_dummy_output_tokens()
+                                    return []
+
+                    result = self._prepare_decode(seq_group_metadata_list,
+                                                  output=output)
+                    execute_model_kwargs.update({
+                        "input_ids":
+                        result.input_tokens,
+                        "positions":
+                        result.input_positions,
+                        "attn_metadata":
+                        self.trim_attn_metadata(result.attn_metadata)
+                    })
+                    model_kwargs_broadcast_data = {
+                        "input_ids": result.input_tokens,
+                        "positions": result.input_positions,
+                        "attn_metadata": vars(result.attn_metadata)
+                    }
+                    broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
+                else:
+                    try_revert_dummy_output_tokens()
+
+            if self.is_driver_worker and self.profiler.enabled:
+                # Stop recording 'execute_model' event
+                self.profiler.end()
+                event_end = self.profiler.get_timestamp_us()
+                counters = self.profiler_counter_helper.get_counter_dict(
+                    cache_config=self.cache_config,
+                    duration=event_end - self.event_start,
+                    seq_len=seq_len,
+                    batch_size_padded=batch_size_padded,
+                    real_batch_size=real_batch_size,
+                    is_prompt=is_prompt)
+                self.profiler.record_counter(self.event_start, counters)
+            if num_steps == 1:
+                return [output] if self.is_driver_worker else []
+            else:
+                return []
+        return output if type(output) is list else [output]
+
+    def _decode_sampler_outputs(self, model_input):
+        use_async_out_proc = model_input.async_callback is not None
+        sampler_outputs = []
+        num_outputs = len(self.cached_step_outputs)
+        for i in range(num_outputs):
+            next_token_ids = self.cached_step_outputs.pop(0)
+            next_token_ids = next_token_ids.cpu().tolist()
+            sampler_output = self._make_decode_output(
+                next_token_ids, model_input.sampling_metadata.seq_groups)
+            sampler_outputs.append(sampler_output)
+
+            if i < num_outputs - 1 and use_async_out_proc:
+                assert model_input.async_callback is not None
+                ctx = model_input.async_callback.keywords[  # type: ignore
+                    "ctx"]
+                ctx.append_output(
+                    outputs=[sampler_output],
+                    seq_group_metadata_list=ctx.seq_group_metadata_list,
+                    scheduler_outputs=ctx.scheduler_outputs,
+                    is_async=False,
+                    is_last_step=False,
+                    is_first_step_output=False)
+                model_input.async_callback()
+
+        if use_async_out_proc:
+            return [sampler_outputs[-1]]
         else:
-            model_event_name = 'model_executable'
-        with self.profiler.record_event('internal', model_event_name):
-            hidden_states = self.model.forward(
-                **execute_model_kwargs,
-                selected_token_indices=sampling_metadata.selected_token_indices
-            )
+            return sampler_outputs

-        if self.lora_config:
-            LoraMask.setLoraMask(
-                lora_logits_mask.index_select(
-                    0, sampling_metadata.selected_token_indices))
-
-        # Compute the logits.
-        with self.profiler.record_event(
-                'internal', ('compute_logits_'
-                             f'{"prompt" if is_prompt else "decode"}_bs'
-                             f'{batch_size}_'
-                             f'seq{seq_len}')):
-            sampling_metadata.selected_token_indices = None
-            logits = self.model.compute_logits(hidden_states,
-                                               sampling_metadata)
-        htorch.core.mark_step()
-        # Only perform sampling in the driver worker.
-        if not self.is_driver_worker:
-            return []
-
-        if model_input.async_callback is not None:
-            model_input.async_callback()
-
-        # Sample the next token.
-        with self.profiler.record_event(
-                'internal', ('sample_'
-                             f'{"prompt" if is_prompt else "decode"}_'
-                             f'bs{batch_size}_'
-                             f'seq{seq_len}')):
-            output = self.model.sample(
-                logits=logits,
-                sampling_metadata=sampling_metadata,
-            )
-            output.outputs = output.outputs[:real_batch_size]
-        htorch.core.mark_step()
-
-        if self.is_driver_worker and self.profiler.enabled:
-            # Stop recording 'execute_model' event
-            self.profiler.end()
-            event_end = self.profiler.get_timestamp_us()
-            counters = self.profiler_counter_helper.get_counter_dict(
-                cache_config=self.cache_config,
-                duration=event_end - self.event_start,
-                seq_len=seq_len,
-                batch_size_padded=batch_size_padded,
-                real_batch_size=real_batch_size,
-                is_prompt=is_prompt)
-            self.profiler.record_counter(self.event_start, counters)
-        return [output]
+    def _make_decode_output(
+        self,
+        next_token_ids: List[List[int]],
+        seq_groups: List[SequenceGroupToSample],
+    ) -> SamplerOutput:
+        zero_logprob = Logprob(0.0)
+        sampler_outputs = []
+        batch_idx = 0
+        for seq_group in seq_groups:
+            seq_ids = seq_group.seq_ids
+            seq_outputs = []
+            for seq_id in seq_ids:
+                next_token_id = next_token_ids[batch_idx][0]
+                seq_outputs.append(
+                    SequenceOutput(seq_id, next_token_id,
+                                   {next_token_id: zero_logprob}))
+                batch_idx += 1
+            sampler_outputs.append(
+                CompletionSequenceGroupOutput(seq_outputs, None))
+        return SamplerOutput(sampler_outputs)

     def shutdown_inc(self):
         can_finalize_inc = False
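The dummy-token dance in the loop above is easy to miss: between steps, the driver appends a throwaway token to every sequence so `_prepare_decode` sees the right lengths and block tables for the upcoming step, and `try_revert_dummy_output_tokens()` trims the ids back once real outputs exist. A standalone sketch of just that bookkeeping (the token values are toy; only `dummy_token = (540, )` comes from the diff):

output_token_ids = (5, 9)
orig_len = len(output_token_ids)

dummy_token = (540, )            # arbitrary placeholder, as in the diff
output_token_ids += dummy_token  # pad so prepare_decode sees the new length

# ... next step's positions/block tables are computed against the padded ids ...

output_token_ids = output_token_ids[:orig_len]  # revert, as in the diff
assert output_token_ids == (5, 9)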
vllm/worker/multi_step_hpu_worker.py (new file, 122 lines)
@@ -0,0 +1,122 @@
+# SPDX-License-Identifier: Apache-2.0
+
+###############################################################################
+# Copyright (C) 2025 Habana Labs, Ltd. an Intel Company
+###############################################################################
+
+import dataclasses
+from typing import Dict, Optional, Tuple
+
+import torch
+
+from vllm.distributed import broadcast_tensor_dict
+from vllm.sequence import ExecuteModelRequest
+from vllm.worker.hpu_model_runner import ModelInputForHPU
+from vllm.worker.hpu_worker import HPUWorker
+from vllm.worker.worker_base import WorkerInput
+
+
+class MultiStepHPUWorker(HPUWorker):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.cached_model_input: Optional[ModelInputForHPU] = None
+
+    def _get_driver_input_and_broadcast(
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Tuple[ModelInputForHPU, WorkerInput, Dict[str, torch.Tensor]]:
+        """
+        Get the driver input and broadcast it to other workers.
+        """
+        assert self.is_driver_worker
+        assert execute_model_req.virtual_engine == 0
+
+        is_first_multi_step = execute_model_req.is_first_multi_step
+        is_last_step = execute_model_req.is_last_step
+
+        if is_first_multi_step:
+            # on first step we prepare the worker input and model input normally
+            worker_input: WorkerInput = self.prepare_worker_input(
+                execute_model_req=execute_model_req)
+            worker_input = dataclasses.replace(
+                worker_input,
+                num_steps=execute_model_req.num_lookahead_slots + 1)
+            model_input: ModelInputForHPU = (
+                self.model_runner.prepare_model_input(
+                    execute_model_req.seq_group_metadata_list,
+                    execute_model_req.virtual_engine,
+                    execute_model_req.finished_requests_ids))
+
+            if execute_model_req.async_callback:
+                model_input = dataclasses.replace(
+                    model_input,
+                    async_callback=execute_model_req.async_callback)
+        else:
+            # on subsequent steps we reuse the worker input and model input
+            assert self.cached_model_input is not None
+            model_input = self.cached_model_input
+            worker_input = WorkerInput()
+
+        model_input = dataclasses.replace(
+            model_input,
+            is_first_multi_step=is_first_multi_step,
+            is_last_step=is_last_step)
+
+        if self.do_metadata_broadcast:
+            if is_first_multi_step:
+                broadcast_data = worker_input.as_broadcastable_tensor_dict()
+                broadcast_data.update(
+                    model_input.as_broadcastable_tensor_dict())
+                broadcast_tensor_dict(broadcast_data, src=0)
+            else:
+                broadcast_data = {
+                    "is_first_multi_step": is_first_multi_step,
+                    "is_last_step": is_last_step,
+                }
+                broadcast_tensor_dict(broadcast_data, src=0)
+
+        # Returning empty dict here to keep this compatible with
+        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
+        return model_input, worker_input, {}
+
+    def prepare_input(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None,
+    ) -> Optional[Tuple[ModelInputForHPU, WorkerInput, Dict[str,
+                                                            torch.Tensor]]]:
+        if self.is_driver_worker:
+            if execute_model_req is None:
+                if self.do_metadata_broadcast:
+                    # This signals that there's no more requests to process for
+                    # now. All workers are running infinite loop with
+                    # broadcast_tensor_dict, and it stops the loop when the
+                    # driver broadcasts an empty input. Send an empty input to
+                    # notify all other workers to stop their execution loop.
+                    broadcast_tensor_dict({}, src=0)
+                return None
+            model_input, worker_input, _ = self._get_driver_input_and_broadcast(
+                execute_model_req)
+            if model_input.is_first_multi_step:
+                self.cached_model_input = model_input
+            return model_input, worker_input, {}
+        else:
+            broadcast_data = broadcast_tensor_dict(src=0)
+            if not broadcast_data:
+                return None
+
+            if len(broadcast_data) == 2:
+                assert self.cached_model_input is not None
+                self.cached_model_input = dataclasses.replace(
+                    self.cached_model_input,
+                    is_first_multi_step=broadcast_data["is_first_multi_step"],
+                    is_last_step=broadcast_data["is_last_step"])
+                empty_worker_input = WorkerInput()
+                return self.cached_model_input, empty_worker_input, {}
+
+            worker_input = WorkerInput.from_broadcasted_tensor_dict(
+                broadcast_data)
+            model_input = (
+                self.model_runner.
+                make_model_input_from_broadcasted_tensor_dict(broadcast_data))
+            self.cached_model_input = model_input
+            return model_input, worker_input, {}
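On the replica side, `prepare_input` distinguishes three broadcast shapes: an empty dict (shutdown), a two-key dict carrying only the step flags (continue a cached multi-step batch), and a full payload (a fresh first step). A toy classifier mirroring those branches (the helper function is illustrative, not part of the worker):

from typing import Any, Dict

def classify_broadcast(data: Dict[str, Any]) -> str:
    if not data:
        return "stop"         # driver sent {}: exit the worker loop
    if len(data) == 2:
        return "flags-only"   # update is_first_multi_step/is_last_step only
    return "first-step"       # full worker input + model input payload

assert classify_broadcast({}) == "stop"
assert classify_broadcast(
    {"is_first_multi_step": False, "is_last_step": True}) == "flags-only"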