diff --git a/vllm/envs.py b/vllm/envs.py
index 92dcf1555f223..03a8a2b20f02e 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -98,6 +98,7 @@ if TYPE_CHECKING:
     VLLM_RAY_BUNDLE_INDICES: str = ""
     VLLM_CUDART_SO_PATH: Optional[str] = None
     VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
+    VLLM_HPU_USE_DELAYED_SAMPLING: bool = False
     VLLM_DP_RANK: int = 0
     VLLM_DP_RANK_LOCAL: int = -1
     VLLM_DP_SIZE: int = 1
@@ -650,6 +651,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
     ("1", "true"),
 
+    # Use delayed sampling for HPU to reduce host CPU overhead
+    # between each step.
+    "VLLM_HPU_USE_DELAYED_SAMPLING":
+    lambda: os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in
+    ("1", "true"),
+
     # Rank of the process in the data parallel setting
     "VLLM_DP_RANK":
     lambda: int(os.getenv("VLLM_DP_RANK", "0")),
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 7a346b34cef59..2d31024b47d0a 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -74,6 +74,8 @@
 _PAD_BLOCK_ID = 0
 LORA_WARMUP_RANK = 8
 
+DUMMY_TOKEN_ID = -1
+
 
 class Singleton(type):
     _instances: Dict[type, object] = {}
@@ -668,6 +670,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
 
         # For multi-step scheduling
        self.cached_step_outputs: List[torch.Tensor] = []
+        # For delayed sampling
+        self.cached_step_inputs: List[
+            ModelInputForHPUWithSamplingMetadata] = []
 
     def _set_gc_threshold(self) -> None:
         # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -771,6 +776,12 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         msg = f"Loading model weights took in total {m.get_summary_string()}"
         logger.info(msg)
 
+    def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
+        return htorch.hpu.wrap_in_hpu_graph(
+            HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
+        ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
+            *args, **kwargs)
+
     def get_model(self) -> nn.Module:
         return self.model
 
@@ -2020,6 +2031,21 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
 
         return lora_mask, lora_logits_mask
 
+    def _get_seq_ids(self, model_input):
+        return ([
+            sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups
+        ])
+
+    def _pad_to_max_num_seqs(self, tensor, value):
+        padding_needed = self.max_num_seqs - tensor.size(0)
+        if padding_needed:
+            padding = torch.full((padding_needed, *tensor.shape[1:]),
+                                 value,
+                                 device=tensor.device,
+                                 dtype=tensor.dtype)
+            tensor = torch.cat([tensor, padding])
+        return tensor
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -2030,6 +2056,37 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
         warmup_mode=False,
         seqs=None,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+        VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING
+        use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
+        assert not (use_delayed_sampling and num_steps != 1), \
+            'Delayed sampling is not compatible with MSS!'
+        assert model_input.input_tokens is not None
+        if use_delayed_sampling and not model_input.is_prompt and \
+                self.is_driver_worker:
+            num_cached = len(self.cached_step_outputs)
+            assert num_cached > 0
+            cur_seq_ids = self._get_seq_ids(model_input)
+            cur_seq_id_pos = {
+                sid: idx
+                for idx, sid in enumerate(cur_seq_ids) if sid >= 0
+            }
+            htorch.core.mark_step()
+            for i in range(num_cached):
+                prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i])
+                target_indices = [
+                    cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids
+                ]
+                padding = self.cached_step_outputs[i].size(0) - len(
+                    target_indices)
+                target_indices.extend([-1] * padding)
+                target_indices = torch.tensor(
+                    target_indices,
+                    device=model_input.input_tokens.device,
+                    dtype=model_input.input_tokens.dtype)
+                model_input.input_tokens.index_copy_(
+                    0, target_indices, self.cached_step_outputs[i])
+                htorch.core.mark_step()
+
         if not model_input.is_first_multi_step:
             if not model_input.is_last_step:
                 # not first or last multi-step
@@ -2045,7 +2102,21 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
             assert model_input.lora_mapping is not None
             self.set_active_loras(model_input.lora_requests,
                                   model_input.lora_mapping)
-        input_tokens = model_input.input_tokens
+        # Rank!=0 workers have is_prompt==None
+        if use_delayed_sampling and not model_input.is_prompt and \
+                model_input.input_tokens.size(1) == 1:
+            if self.is_driver_worker:
+                model_kwargs_broadcast_data = {
+                    "input_tokens": model_input.input_tokens
+                }
+                broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
+                input_tokens = model_input.input_tokens
+
+            else:
+                model_kwargs_broadcast_data = broadcast_tensor_dict(src=0)
+                input_tokens = model_kwargs_broadcast_data["input_tokens"]
+        else:
+            input_tokens = model_input.input_tokens
         input_positions = model_input.input_positions
         attn_metadata = model_input.attn_metadata
         sampling_metadata = model_input.sampling_metadata
@@ -2092,7 +2163,7 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
                                 f"graphs{'T' if use_graphs else 'F'}")
         else:
             model_event_name = 'model_executable'
-        if num_steps > 1:
+        if num_steps > 1 or use_delayed_sampling:
             # in case of multi-step scheduling
             # we only want to pythonize in the last step
             sampling_metadata.skip_sampler_cpu_output = True
@@ -2152,9 +2223,9 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
                 if not self.is_driver_worker:
                     continue
 
-                if model_input.async_callback is not None:
-                    model_input.async_callback()
-                # Sample the next token.
+                if use_delayed_sampling:
+                    fake_output = self._delayed_sampler_outputs(model_input)
+
                 with self.profiler.record_event(
                         'internal', ('sample_'
                                      f'{"prompt" if is_prompt else "decode"}_'
@@ -2166,9 +2237,16 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
                    )
                    if num_steps > 1:
                        output = output.sampled_token_ids
-                        self.cached_step_outputs.append(
-                            output.detach().clone())
+                        self.cached_step_outputs.append(output)
+                    if use_delayed_sampling and self.is_driver_worker:
+                        self._patch_prev_output()
+                        output = self._pad_to_max_num_seqs(
+                            output.sampled_token_ids, DUMMY_TOKEN_ID)
+                        self.cached_step_outputs.append(output)
+                        self.cached_step_inputs.append(model_input)
                    htorch.core.mark_step()
+                    if model_input.async_callback is not None:
+                        model_input.async_callback()
                    if i < num_steps - 1:
                        if i == 0:
                            if model_input.async_callback is not None:
@@ -2241,11 +2319,30 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
                 is_prompt=is_prompt)
             self.profiler.record_counter(self.event_start, counters)
         if num_steps == 1:
+            if self.return_hidden_states:
+                # we only need to pass hidden states of most recent token
+                assert model_input.sampling_metadata is not None
+                if model_input.is_prompt:
+                    output.prefill_hidden_states = hidden_states
+                output.hidden_states = hidden_states
+            if use_delayed_sampling:
+                if self.is_driver_worker:
+                    return [fake_output]
+                else:
+                    return []
+
             return [output] if self.is_driver_worker else []
         else:
             return []
        return output if type(output) is list else [output]
 
+    def _delayed_sampler_outputs(self, model_input):
+        next_token_ids = [[DUMMY_TOKEN_ID]] * len(
+            model_input.sampling_metadata.seq_groups)
+        sampler_output = self._make_decode_output(
+            next_token_ids, model_input.sampling_metadata.seq_groups)
+        return sampler_output
+
     def _decode_sampler_outputs(self, model_input):
         use_async_out_proc = model_input.async_callback is not None
         sampler_outputs = []
@@ -2312,3 +2409,32 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
 
     def __del__(self):
         self.shutdown_inc()
+
+    def _patch_prev_output(self):
+        assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \
+            f'''Inputs and outputs are out of sync!
+            {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}'''
+        if len(self.cached_step_inputs) == 0:
+            return
+        model_input = self.cached_step_inputs.pop(0)
+        delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(
+            -1).tolist()
+        ctx = model_input.async_callback.keywords["ctx"]  # type: ignore
+        # There may be no output to patch with; this is usually the case when
+        # we're starting a new request after all previous requests completed.
+        if len(ctx.output_queue) == 0:
+            return
+        assert len(
+            ctx.output_queue) == 1, 'There should be exactly 1 output waiting!'
+        output_data = ctx.output_queue[0]
+        assert len(output_data.outputs) == 1
+        for fake_out, real_out in zip(output_data.outputs[0], delayed_output):
+            fake_out.samples[0].output_token = real_out
+        for sg, real_out in zip(output_data.seq_group_metadata_list,
+                                delayed_output):
+            assert len(sg.seq_data) == 1
+            seq_data = list(sg.seq_data.values())[0]
+            # This is a hack. Assigning output_token_ids triggers
+            # a cache recomputation and we only need to update the last token
+            seq_data.output_token_ids_array[-1] = real_out
+            seq_data._cached_all_token_ids[-1] = real_out
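
Reviewer note (not part of the diff): a minimal CPU-only sketch of the seq-id remapping and `index_copy_` pattern that the delayed-sampling path above uses to splice the previous step's cached sampled tokens into the current step's `input_tokens`. The seq ids and token values are made up, and the `-1` rows produced for sequences that are no longer scheduled are filtered out here, since this sketch runs on plain CPU tensors rather than the padded HPU batch.

```python
import torch

DUMMY_TOKEN_ID = -1  # same sentinel the diff pads cached outputs with

# Current decode batch schedules seq ids [5, 7, 8]; input_tokens has the
# decode-phase shape [batch, 1], as checked by the size(1) == 1 branch.
cur_seq_ids = [5, 7, 8]
input_tokens = torch.zeros(len(cur_seq_ids), 1, dtype=torch.long)

# The previous step sampled tokens for seq ids [7, 9]; seq 9 has since
# finished, so it has no slot in the current batch.
prev_seq_ids = [7, 9]
cached_step_output = torch.tensor([[101], [202]], dtype=torch.long)

# Map each cached row to its row in the current batch (-1 = no slot),
# mirroring cur_seq_id_pos / target_indices in execute_model.
cur_seq_id_pos = {sid: idx for idx, sid in enumerate(cur_seq_ids)}
target_indices = torch.tensor(
    [cur_seq_id_pos.get(sid, -1) for sid in prev_seq_ids], dtype=torch.long)

# index_copy_ scatters cached rows into input_tokens by row index. Unlike
# the padded HPU path, this sketch drops the -1 entries because CPU
# index_copy_ requires in-range indices.
valid = target_indices >= 0
input_tokens.index_copy_(0, target_indices[valid], cached_step_output[valid])
print(input_tokens.squeeze(1).tolist())  # -> [0, 101, 0]
```

The sketch only illustrates the row-to-row copy; on the real path `_pad_to_max_num_seqs` pads each cached output to `max_num_seqs` with `DUMMY_TOKEN_ID` before it is cached, and `_patch_prev_output` later writes the real token ids back into the fake outputs returned to the engine.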