diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py
index 4c842b5251105..456b054b2b43a 100644
--- a/vllm/platforms/hpu.py
+++ b/vllm/platforms/hpu.py
@@ -46,15 +46,15 @@ class HpuPlatform(Platform):
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
         scheduler_config = vllm_config.scheduler_config
+        parallel_config = vllm_config.parallel_config
         if scheduler_config.is_multi_step:
-            raise NotImplementedError(
-                "Multi-step execution is not implemented for HPU")
+            parallel_config.worker_cls = \
+                "vllm.worker.multi_step_hpu_worker.MultiStepHPUWorker"
 
         if vllm_config.speculative_config is not None:
             raise NotImplementedError(
                 "Speculative decoding is not implemented for HPU")
 
-        parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
 
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 6b1593eb8235c..7a346b34cef59 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -32,6 +32,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import DeviceConfig, VllmConfig
+from vllm.distributed import broadcast_tensor_dict
 from vllm.distributed.parallel_state import get_world_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
@@ -44,11 +45,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.sampling_metadata import SequenceGroupToSample
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalKwargs)
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (IntermediateTensors, SequenceData,
-                           SequenceGroupMetadata)
+from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
+                           Logprob, SequenceData, SequenceGroupMetadata,
+                           SequenceOutput)
 from vllm.utils import (bind_kv_cache, is_pin_memory_available,
                         make_tensor_with_pad)
 from vllm.worker.model_runner_base import (
@@ -100,7 +103,10 @@ def subtuple(obj: object,
     if to_override is None:
         to_override = {}
     fields = set(to_copy) | set(to_override.keys())
-    values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
+    if type(obj) is dict:
+        values = {key: obj[key] for key in fields if key in obj}
+    else:
+        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
     if typename not in _TYPE_CACHE:
         _TYPE_CACHE[typename] = collections.namedtuple(typename,
                                                        ' '.join(fields))
@@ -533,6 +539,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
     virtual_engine: int = 0
     lora_ids: Optional[List[int]] = None
     async_callback: Optional[Callable] = None
+    is_first_multi_step: bool = True
+    is_last_step: bool = True
 
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -545,6 +553,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
             "batch_size_padded": self.batch_size_padded,
             "virtual_engine": self.virtual_engine,
             "lora_ids": self.lora_ids,
+            "is_first_multi_step": self.is_first_multi_step,
+            "is_last_step": self.is_last_step,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
         return tensor_dict
@@ -656,6 +666,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         self._set_gc_threshold()
         self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
 
+        # For multi-step scheduling
+        self.cached_step_outputs: List[torch.Tensor] = []
+
     def _set_gc_threshold(self) -> None:
         # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
         # for comprehensive description of gc generations.
@@ -1005,6 +1018,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
+        output=None,
     ) -> PrepareDecodeMetadata:
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -1035,8 +1049,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
 
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
-                generation_token = seq_data.get_last_token_id()
-                input_tokens.append([generation_token])
+                if output is None:
+                    generation_token = seq_data.get_last_token_id()
+                    input_tokens.append([generation_token])
 
                 seq_len = seq_data.get_len()
                 position = seq_len - 1
@@ -1047,6 +1062,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                 seq_lens.append(seq_len)
 
                 block_table = seq_group_metadata.block_tables[seq_id]
+                num_fully_occupied_blocks = position // self.block_size
+                block_table = block_table[:num_fully_occupied_blocks + 1]
+
                 if len(block_table) == 0:
                     block_number = _PAD_BLOCK_ID
                 else:
@@ -1066,9 +1084,14 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                     block_table = block_table[-sliding_window_blocks:]
                 block_tables.append(block_table)
 
-        input_tokens = torch.tensor(input_tokens,
-                                    dtype=torch.long,
-                                    device=self.device)
+        if output is None:
+            input_tokens = torch.tensor(input_tokens,
+                                        dtype=torch.long,
+                                        device=self.device)
+        else:
+            real_batch_size = len(seq_group_metadata_list)
+            input_tokens = output[:real_batch_size]
+
         input_positions = torch.tensor(input_positions,
                                        dtype=torch.long,
                                        device=self.device)
@@ -1462,7 +1485,27 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             profiler.start()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
-            self.execute_model(inputs, None, warmup_mode=True)
+            is_single_step = \
+                self.vllm_config.scheduler_config.num_scheduler_steps == 1
+            if is_prompt or is_single_step:
+                self.execute_model(inputs, None, warmup_mode=True)
+            else:  # decode with multi-step
+                inputs = dataclasses.replace(inputs,
+                                             is_first_multi_step=True,
+                                             is_last_step=False)
+                self.execute_model(inputs,
+                                   None,
+                                   warmup_mode=True,
+                                   num_steps=2,
+                                   seqs=seqs)
+                inputs = dataclasses.replace(inputs,
+                                             is_first_multi_step=False,
+                                             is_last_step=True)
+                self.execute_model(inputs,
+                                   None,
+                                   warmup_mode=True,
+                                   num_steps=2,
+                                   seqs=seqs)
             torch.hpu.synchronize()
             if profiler:
                 profiler.step()
@@ -1985,115 +2028,273 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
         warmup_mode=False,
+        seqs=None,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
-        if num_steps > 1:
-            raise ValueError(
-                "num_steps > 1 is not supported in HPUModelRunner")
+        if not model_input.is_first_multi_step:
+            if not model_input.is_last_step:
+                # not first or last multi-step
+                return []
+            # last multi-step
+            output = self._decode_sampler_outputs(
+                model_input) if self.is_driver_worker else []
+            torch.hpu.synchronize()
+        if model_input.is_first_multi_step:
+            # first multi-step
+            if self.lora_config:
+                assert model_input.lora_requests is not None
+                assert model_input.lora_mapping is not None
+                self.set_active_loras(model_input.lora_requests,
+                                      model_input.lora_mapping)
+            input_tokens = model_input.input_tokens
+            input_positions = model_input.input_positions
+            attn_metadata = model_input.attn_metadata
+            sampling_metadata = model_input.sampling_metadata
+            real_batch_size = model_input.real_batch_size
+            batch_size_padded = model_input.batch_size_padded
+            assert input_tokens is not None
+            assert input_positions is not None
+            assert sampling_metadata is not None
+            assert attn_metadata is not None
+            is_prompt = attn_metadata.is_prompt
+            assert is_prompt is not None
+            batch_size = input_tokens.size(0)
+            seq_len = self._seq_len(attn_metadata)
+            use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
+            self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
 
-        if self.lora_config:
-            assert model_input.lora_requests is not None
-            assert model_input.lora_mapping is not None
-            self.set_active_loras(model_input.lora_requests,
-                                  model_input.lora_mapping)
-        input_tokens = model_input.input_tokens
-        input_positions = model_input.input_positions
-        attn_metadata = model_input.attn_metadata
-        sampling_metadata = model_input.sampling_metadata
-        real_batch_size = model_input.real_batch_size
-        batch_size_padded = model_input.batch_size_padded
-        assert input_tokens is not None
-        assert input_positions is not None
-        assert sampling_metadata is not None
-        assert attn_metadata is not None
-        is_prompt = attn_metadata.is_prompt
-        assert is_prompt is not None
-        batch_size = input_tokens.size(0)
-        seq_len = self._seq_len(attn_metadata)
-        use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
-        self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
+            lora_mask: torch.Tensor = None
+            lora_logits_mask: torch.Tensor = None
+            if self.lora_config:
+                assert model_input.lora_ids is not None
+                lora_mask, lora_logits_mask = self.create_lora_mask(
+                    input_tokens, model_input.lora_ids,
+                    attn_metadata.is_prompt)
 
-        lora_mask: torch.Tensor = None
-        lora_logits_mask: torch.Tensor = None
-        if self.lora_config:
-            assert model_input.lora_ids is not None
-            lora_mask, lora_logits_mask = self.create_lora_mask(
-                input_tokens, model_input.lora_ids, attn_metadata.is_prompt)
+            execute_model_kwargs = {
+                "input_ids": input_tokens,
+                "positions": input_positions,
+                "attn_metadata": self.trim_attn_metadata(attn_metadata),
+                "intermediate_tensors": intermediate_tensors,
+                "lora_mask": lora_mask,
+                "virtual_engine": model_input.virtual_engine,
+                **(model_input.multi_modal_kwargs or {}),
+            }
+            if htorch.utils.internal.is_lazy():
+                execute_model_kwargs.update(
+                    {"bypass_hpu_graphs": not use_graphs})
 
-        execute_model_kwargs = {
-            "input_ids": input_tokens,
-            "positions": input_positions,
-            "attn_metadata": self.trim_attn_metadata(attn_metadata),
-            "intermediate_tensors": intermediate_tensors,
-            "lora_mask": lora_mask,
-            "virtual_engine": model_input.virtual_engine,
-            **(model_input.multi_modal_kwargs or {}),
-        }
-        if htorch.utils.internal.is_lazy():
-            execute_model_kwargs.update({"bypass_hpu_graphs": not use_graphs})
+            htorch.core.mark_step()
+            if self.is_driver_worker:
+                model_event_name = ("model_"
+                                    f"{'prompt' if is_prompt else 'decode'}_"
+                                    f"bs{batch_size}_"
+                                    f"seq{seq_len}_"
+                                    f"graphs{'T' if use_graphs else 'F'}")
+            else:
+                model_event_name = 'model_executable'
+            if num_steps > 1:
+                # in case of multi-step scheduling
+                # we only want to pythonize in the last step
+                sampling_metadata.skip_sampler_cpu_output = True
+                self.model.model.sampler.include_gpu_probs_tensor = True
+                cache_orig_output_tokens_len: List[Dict] = []
 
-        htorch.core.mark_step()
-        if self.is_driver_worker:
-            model_event_name = ("model_"
-                                f"{'prompt' if is_prompt else 'decode'}_"
-                                f"bs{batch_size}_"
-                                f"seq{seq_len}_"
-                                f"graphs{'T' if use_graphs else 'F'}")
+            def try_revert_dummy_output_tokens():
+                if len(cache_orig_output_tokens_len) > 0:
+                    # Reuse the original output token ids length
+                    for i, seq_group_metadata in enumerate(
+                            seq_group_metadata_list):
+                        for j, data in seq_group_metadata.seq_data.items():
+                            orig_output_tokens_len = \
+                                cache_orig_output_tokens_len[i][j]
+                            data.output_token_ids = \
+                                data.output_token_ids[:orig_output_tokens_len]
+
+            for i in range(num_steps):
+                if i != 0 and not self.is_driver_worker:
+                    broadcast_data = broadcast_tensor_dict(src=0)
+                    if 'early_exit' in broadcast_data and broadcast_data[
+                            'early_exit']:
+                        return [output] if num_steps == 1 else []
+                    execute_model_kwargs.update({
+                        "input_ids":
+                        broadcast_data["input_ids"],
+                        "positions":
+                        broadcast_data["positions"],
+                        "attn_metadata":
+                        self.trim_attn_metadata(
+                            broadcast_data["attn_metadata"])
+                    })
+                with self.profiler.record_event('internal', model_event_name):
+                    hidden_states = self.model.forward(
+                        **execute_model_kwargs,
+                        selected_token_indices=sampling_metadata.
+                        selected_token_indices)
+
+                if self.lora_config:
+                    LoraMask.setLoraMask(
+                        lora_logits_mask.index_select(
+                            0, sampling_metadata.selected_token_indices))
+
+                # Compute the logits.
+                with self.profiler.record_event(
+                        'internal',
+                    ('compute_logits_'
+                     f'{"prompt" if is_prompt else "decode"}_bs'
+                     f'{batch_size}_'
+                     f'seq{seq_len}')):
+                    if num_steps == 1:
+                        sampling_metadata.selected_token_indices = None
+                    logits = self.model.compute_logits(hidden_states,
+                                                       sampling_metadata)
+                htorch.core.mark_step()
+                # Only perform sampling in the driver worker.
+                if not self.is_driver_worker:
+                    continue
+
+                if model_input.async_callback is not None:
+                    model_input.async_callback()
+                # Sample the next token.
+                with self.profiler.record_event(
+                        'internal', ('sample_'
+                                     f'{"prompt" if is_prompt else "decode"}_'
+                                     f'bs{batch_size}_'
+                                     f'seq{seq_len}')):
+                    output = self.model.sample(
+                        logits=logits,
+                        sampling_metadata=sampling_metadata,
+                    )
+                    if num_steps > 1:
+                        output = output.sampled_token_ids
+                        self.cached_step_outputs.append(
+                            output.detach().clone())
+                htorch.core.mark_step()
+                if i < num_steps - 1:
+                    if i == 0:
+                        if model_input.async_callback is not None:
+                            ctx = model_input.async_callback.keywords[  # type: ignore
+                                "ctx"]
+                            seq_group_metadata_list = \
+                                ctx.seq_group_metadata_list
+                        elif seqs is not None:
+                            seq_group_metadata_list = seqs
+                        else:
+                            raise RuntimeError(
+                                "seq_group_metadata_list is uninitialized")
+                        for i, seq_group_metadata in enumerate(
+                                seq_group_metadata_list):
+                            # Skip empty steps
+                            seq_group_metadata.state.current_step += (
+                                num_steps - 2)
+                            # Cache the original output token ids
+                            cache_orig_output_tokens_len.append({})
+                            for j, data in seq_group_metadata.seq_data.items():
+                                cache_orig_output_tokens_len[i][j] = \
+                                    len(data.output_token_ids)
+                    for seq_group_metadata in seq_group_metadata_list:
+                        for data in seq_group_metadata.seq_data.values():
+                            max_output_len = sampling_metadata.seq_groups[
+                                0].sampling_params.max_tokens
+                            if len(data.output_token_ids) < max_output_len - 1:
+                                # add a place holder for prepare_decode
+                                # arbitrary value, this could be any token
+                                dummy_token = (540, )
+                                data.output_token_ids += (dummy_token)
+                            else:
+                                broadcast_tensor_dict({'early_exit': True},
+                                                      src=0)
+                                if num_steps == 1:
+                                    return [output]
+                                else:
+                                    try_revert_dummy_output_tokens()
+                                    return []
+
+                    result = self._prepare_decode(seq_group_metadata_list,
+                                                  output=output)
+                    execute_model_kwargs.update({
+                        "input_ids":
+                        result.input_tokens,
+                        "positions":
+                        result.input_positions,
+                        "attn_metadata":
+                        self.trim_attn_metadata(result.attn_metadata)
+                    })
+                    model_kwargs_broadcast_data = {
+                        "input_ids": result.input_tokens,
+                        "positions": result.input_positions,
+                        "attn_metadata": vars(result.attn_metadata)
+                    }
+                    broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
+                else:
+                    try_revert_dummy_output_tokens()
+
+            if self.is_driver_worker and self.profiler.enabled:
+                # Stop recording 'execute_model' event
+                self.profiler.end()
+                event_end = self.profiler.get_timestamp_us()
+                counters = self.profiler_counter_helper.get_counter_dict(
+                    cache_config=self.cache_config,
+                    duration=event_end - self.event_start,
+                    seq_len=seq_len,
+                    batch_size_padded=batch_size_padded,
+                    real_batch_size=real_batch_size,
+                    is_prompt=is_prompt)
+                self.profiler.record_counter(self.event_start, counters)
+            if num_steps == 1:
+                return [output] if self.is_driver_worker else []
+            else:
+                return []
+        return output if type(output) is list else [output]
+
+    def _decode_sampler_outputs(self, model_input):
+        use_async_out_proc = model_input.async_callback is not None
+        sampler_outputs = []
+        num_outputs = len(self.cached_step_outputs)
+        for i in range(num_outputs):
+            next_token_ids = self.cached_step_outputs.pop(0)
+            next_token_ids = next_token_ids.cpu().tolist()
+            sampler_output = self._make_decode_output(
+                next_token_ids, model_input.sampling_metadata.seq_groups)
+            sampler_outputs.append(sampler_output)
+
+            if i < num_outputs - 1 and use_async_out_proc:
+                assert model_input.async_callback is not None
+                ctx = model_input.async_callback.keywords[  # type: ignore
+                    "ctx"]
+                ctx.append_output(
+                    outputs=[sampler_output],
+                    seq_group_metadata_list=ctx.seq_group_metadata_list,
+                    scheduler_outputs=ctx.scheduler_outputs,
+                    is_async=False,
+                    is_last_step=False,
+                    is_first_step_output=False)
+                model_input.async_callback()
+
+        if use_async_out_proc:
+            return [sampler_outputs[-1]]
         else:
-            model_event_name = 'model_executable'
-        with self.profiler.record_event('internal', model_event_name):
-            hidden_states = self.model.forward(
-                **execute_model_kwargs,
-                selected_token_indices=sampling_metadata.selected_token_indices
-            )
+            return sampler_outputs
 
-        if self.lora_config:
-            LoraMask.setLoraMask(
-                lora_logits_mask.index_select(
-                    0, sampling_metadata.selected_token_indices))
-
-        # Compute the logits.
-        with self.profiler.record_event(
-                'internal', ('compute_logits_'
-                             f'{"prompt" if is_prompt else "decode"}_bs'
-                             f'{batch_size}_'
-                             f'seq{seq_len}')):
-            sampling_metadata.selected_token_indices = None
-            logits = self.model.compute_logits(hidden_states,
-                                               sampling_metadata)
-        htorch.core.mark_step()
-        # Only perform sampling in the driver worker.
-        if not self.is_driver_worker:
-            return []
-
-        if model_input.async_callback is not None:
-            model_input.async_callback()
-
-        # Sample the next token.
-        with self.profiler.record_event(
-                'internal', ('sample_'
-                             f'{"prompt" if is_prompt else "decode"}_'
-                             f'bs{batch_size}_'
-                             f'seq{seq_len}')):
-            output = self.model.sample(
-                logits=logits,
-                sampling_metadata=sampling_metadata,
-            )
-        output.outputs = output.outputs[:real_batch_size]
-        htorch.core.mark_step()
-
-        if self.is_driver_worker and self.profiler.enabled:
-            # Stop recording 'execute_model' event
-            self.profiler.end()
-            event_end = self.profiler.get_timestamp_us()
-            counters = self.profiler_counter_helper.get_counter_dict(
-                cache_config=self.cache_config,
-                duration=event_end - self.event_start,
-                seq_len=seq_len,
-                batch_size_padded=batch_size_padded,
-                real_batch_size=real_batch_size,
-                is_prompt=is_prompt)
-            self.profiler.record_counter(self.event_start, counters)
-        return [output]
+    def _make_decode_output(
+        self,
+        next_token_ids: List[List[int]],
+        seq_groups: List[SequenceGroupToSample],
+    ) -> SamplerOutput:
+        zero_logprob = Logprob(0.0)
+        sampler_outputs = []
+        batch_idx = 0
+        for seq_group in seq_groups:
+            seq_ids = seq_group.seq_ids
+            seq_outputs = []
+            for seq_id in seq_ids:
+                next_token_id = next_token_ids[batch_idx][0]
+                seq_outputs.append(
+                    SequenceOutput(seq_id, next_token_id,
+                                   {next_token_id: zero_logprob}))
+                batch_idx += 1
+            sampler_outputs.append(
+                CompletionSequenceGroupOutput(seq_outputs, None))
+        return SamplerOutput(sampler_outputs)
 
     def shutdown_inc(self):
         can_finalize_inc = False
diff --git a/vllm/worker/multi_step_hpu_worker.py b/vllm/worker/multi_step_hpu_worker.py
new file mode 100644
index 0000000000000..2c5e2eac75898
--- /dev/null
+++ b/vllm/worker/multi_step_hpu_worker.py
@@ -0,0 +1,122 @@
+# SPDX-License-Identifier: Apache-2.0
+
+###############################################################################
+# Copyright (C) 2025 Habana Labs, Ltd. an Intel Company
+###############################################################################
+
+import dataclasses
+from typing import Dict, Optional, Tuple
+
+import torch
+
+from vllm.distributed import broadcast_tensor_dict
+from vllm.sequence import ExecuteModelRequest
+from vllm.worker.hpu_model_runner import ModelInputForHPU
+from vllm.worker.hpu_worker import HPUWorker
+from vllm.worker.worker_base import WorkerInput
+
+
+class MultiStepHPUWorker(HPUWorker):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.cached_model_input: Optional[ModelInputForHPU] = None
+
+    def _get_driver_input_and_broadcast(
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Tuple[ModelInputForHPU, WorkerInput, Dict[str, torch.Tensor]]:
+        """
+        Get the driver input and broadcast it to other workers.
+        """
+        assert self.is_driver_worker
+        assert execute_model_req.virtual_engine == 0
+
+        is_first_multi_step = execute_model_req.is_first_multi_step
+        is_last_step = execute_model_req.is_last_step
+
+        if is_first_multi_step:
+            # on first step we prepare the worker input and model input normally
+            worker_input: WorkerInput = self.prepare_worker_input(
+                execute_model_req=execute_model_req)
+            worker_input = dataclasses.replace(
+                worker_input,
+                num_steps=execute_model_req.num_lookahead_slots + 1)
+            model_input: ModelInputForHPU = (
+                self.model_runner.prepare_model_input(
+                    execute_model_req.seq_group_metadata_list,
+                    execute_model_req.virtual_engine,
+                    execute_model_req.finished_requests_ids))
+
+            if execute_model_req.async_callback:
+                model_input = dataclasses.replace(
+                    model_input,
+                    async_callback=execute_model_req.async_callback)
+        else:
+            # on subsequent steps we reuse the worker input and model input
+            assert self.cached_model_input is not None
+            model_input = self.cached_model_input
+            worker_input = WorkerInput()
+
+        model_input = dataclasses.replace(
+            model_input,
+            is_first_multi_step=is_first_multi_step,
+            is_last_step=is_last_step)
+
+        if self.do_metadata_broadcast:
+            if is_first_multi_step:
+                broadcast_data = worker_input.as_broadcastable_tensor_dict()
+                broadcast_data.update(
+                    model_input.as_broadcastable_tensor_dict())
+                broadcast_tensor_dict(broadcast_data, src=0)
+            else:
+                broadcast_data = {
+                    "is_first_multi_step": is_first_multi_step,
+                    "is_last_step": is_last_step,
+                }
+                broadcast_tensor_dict(broadcast_data, src=0)
+
+        # Returning empty dict here to keep this compatible with
+        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
+        return model_input, worker_input, {}
+
+    def prepare_input(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None,
+    ) -> Optional[Tuple[ModelInputForHPU, WorkerInput, Dict[str,
+                                                            torch.Tensor]]]:
+        if self.is_driver_worker:
+            if execute_model_req is None:
+                if self.do_metadata_broadcast:
+                    # This signals that there's no more requests to process for
+                    # now. All workers are running infinite loop with
+                    # broadcast_tensor_dict, and it stops the loop when the
+                    # driver broadcasts an empty input. Send an empty input to
+                    # notify all other workers to stop their execution loop.
+                    broadcast_tensor_dict({}, src=0)
+                return None
+            model_input, worker_input, _ = self._get_driver_input_and_broadcast(
+                execute_model_req)
+            if model_input.is_first_multi_step:
+                self.cached_model_input = model_input
+            return model_input, worker_input, {}
+        else:
+            broadcast_data = broadcast_tensor_dict(src=0)
+            if not broadcast_data:
+                return None
+
+            if len(broadcast_data) == 2:
+                assert self.cached_model_input is not None
+                self.cached_model_input = dataclasses.replace(
+                    self.cached_model_input,
+                    is_first_multi_step=broadcast_data["is_first_multi_step"],
+                    is_last_step=broadcast_data["is_last_step"])
+                empty_worker_input = WorkerInput()
+                return self.cached_model_input, empty_worker_input, {}
+
+            worker_input = WorkerInput.from_broadcasted_tensor_dict(
+                broadcast_data)
+            model_input = (
+                self.model_runner.
+                make_model_input_from_broadcasted_tensor_dict(broadcast_data))
+            self.cached_model_input = model_input
+            return model_input, worker_input, {}
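
Usage note: multi-step scheduling is keyed off the engine's `num_scheduler_steps` argument; `scheduler_config.is_multi_step` is true when it exceeds 1, which is the condition the `check_and_update_config` change above uses to select `MultiStepHPUWorker`, and which the warmup path checks via `num_scheduler_steps == 1`. A minimal sketch of exercising this path is below; the model name is illustrative, and this assumes a vLLM build with HPU support that includes this change.

```python
# Hedged sketch, not part of the patch: with num_scheduler_steps > 1,
# scheduler_config.is_multi_step becomes true and
# HpuPlatform.check_and_update_config swaps the worker class to
# vllm.worker.multi_step_hpu_worker.MultiStepHPUWorker.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model choice
    num_scheduler_steps=4,  # >1 enables the multi-step decode path
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```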