[INTEL-HPU][v0] Port delayed sampling to upstream (#16949)

Signed-off-by: Michal Adamczyk <michal.adamczyk@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Michal Adamczyk <madamczyk@habana.ai>
This commit is contained in:
Chendi.Xue 2025-04-22 22:14:11 -05:00 committed by GitHub
parent e1cf90e099
commit 56a735261c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 140 additions and 7 deletions

View File

@ -98,6 +98,7 @@ if TYPE_CHECKING:
VLLM_RAY_BUNDLE_INDICES: str = ""
VLLM_CUDART_SO_PATH: Optional[str] = None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
VLLM_HPU_USE_DELAYED_SAMPLING: bool = False
VLLM_DP_RANK: int = 0
VLLM_DP_RANK_LOCAL: int = -1
VLLM_DP_SIZE: int = 1
@ -650,6 +651,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
("1", "true"),
# Use delayed sampling for HPU to reduce host cpu overhead
# between each step.
"VLLM_HPU_USE_DELAYED_SAMPLING":
lambda: os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in
("1", "true"),
# Rank of the process in the data parallel setting
"VLLM_DP_RANK":
lambda: int(os.getenv("VLLM_DP_RANK", "0")),

View File

@ -74,6 +74,8 @@ _PAD_BLOCK_ID = 0
LORA_WARMUP_RANK = 8
DUMMY_TOKEN_ID = -1
class Singleton(type):
_instances: Dict[type, object] = {}
@ -668,6 +670,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
# For multi-step scheduling
self.cached_step_outputs: List[torch.Tensor] = []
# For delayed sampling
self.cached_step_inputs: List[
ModelInputForHPUWithSamplingMetadata] = []
def _set_gc_threshold(self) -> None:
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@ -771,6 +776,12 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
msg = f"Loading model weights took in total {m.get_summary_string()}"
logger.info(msg)
def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
return htorch.hpu.wrap_in_hpu_graph(
HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
*args, **kwargs)
def get_model(self) -> nn.Module:
return self.model
@ -2020,6 +2031,21 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
return lora_mask, lora_logits_mask
def _get_seq_ids(self, model_input):
return ([
sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups
])
def _pad_to_max_num_seqs(self, tensor, value):
padding_needed = self.max_num_seqs - tensor.size(0)
if padding_needed:
padding = torch.full((padding_needed, *tensor.shape[1:]),
value,
device=tensor.device,
dtype=tensor.dtype)
tensor = torch.cat([tensor, padding])
return tensor
@torch.inference_mode()
def execute_model(
self,
@ -2030,6 +2056,37 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
warmup_mode=False,
seqs=None,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING
use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
assert not (use_delayed_sampling and num_steps != 1), \
'Delayed sampling is not compatible with MSS!'
assert model_input.input_tokens is not None
if use_delayed_sampling and not model_input.is_prompt and \
self.is_driver_worker:
num_cached = len(self.cached_step_outputs)
assert num_cached > 0
cur_seq_ids = self._get_seq_ids(model_input)
cur_seq_id_pos = {
sid: idx
for idx, sid in enumerate(cur_seq_ids) if sid >= 0
}
htorch.core.mark_step()
for i in range(num_cached):
prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i])
target_indices = [
cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids
]
padding = self.cached_step_outputs[i].size(0) - len(
target_indices)
target_indices.extend([-1] * padding)
target_indices = torch.tensor(
target_indices,
device=model_input.input_tokens.device,
dtype=model_input.input_tokens.dtype)
model_input.input_tokens.index_copy_(
0, target_indices, self.cached_step_outputs[i])
htorch.core.mark_step()
if not model_input.is_first_multi_step:
if not model_input.is_last_step:
# not first or last multi-step
@ -2045,7 +2102,21 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
input_tokens = model_input.input_tokens
# Rank!=0 workers has is_prompt==None
if use_delayed_sampling and not model_input.is_prompt and \
model_input.input_tokens.size(1) == 1:
if self.is_driver_worker:
model_kwargs_broadcast_data = {
"input_tokens": model_input.input_tokens
}
broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
input_tokens = model_input.input_tokens
else:
model_kwargs_broadcast_data = broadcast_tensor_dict(src=0)
input_tokens = model_kwargs_broadcast_data["input_tokens"]
else:
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
sampling_metadata = model_input.sampling_metadata
@ -2092,7 +2163,7 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
f"graphs{'T' if use_graphs else 'F'}")
else:
model_event_name = 'model_executable'
if num_steps > 1:
if num_steps > 1 or use_delayed_sampling:
# in case of multi-step scheduling
# we only want to pythonize in the last step
sampling_metadata.skip_sampler_cpu_output = True
@ -2152,9 +2223,9 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
if not self.is_driver_worker:
continue
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
if use_delayed_sampling:
fake_output = self._delayed_sampler_outputs(model_input)
with self.profiler.record_event(
'internal', ('sample_'
f'{"prompt" if is_prompt else "decode"}_'
@ -2166,9 +2237,16 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
)
if num_steps > 1:
output = output.sampled_token_ids
self.cached_step_outputs.append(
output.detach().clone())
self.cached_step_outputs.append(output)
if use_delayed_sampling and self.is_driver_worker:
self._patch_prev_output()
output = self._pad_to_max_num_seqs(
output.sampled_token_ids, DUMMY_TOKEN_ID)
self.cached_step_outputs.append(output)
self.cached_step_inputs.append(model_input)
htorch.core.mark_step()
if model_input.async_callback is not None:
model_input.async_callback()
if i < num_steps - 1:
if i == 0:
if model_input.async_callback is not None:
@ -2241,11 +2319,30 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
is_prompt=is_prompt)
self.profiler.record_counter(self.event_start, counters)
if num_steps == 1:
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None
if model_input.is_prompt:
output.prefill_hidden_states = hidden_states
output.hidden_states = hidden_states
if use_delayed_sampling:
if self.is_driver_worker:
return [fake_output]
else:
return []
return [output] if self.is_driver_worker else []
else:
return []
return output if type(output) is list else [output]
def _delayed_sampler_outputs(self, model_input):
next_token_ids = [[DUMMY_TOKEN_ID]] * len(
model_input.sampling_metadata.seq_groups)
sampler_output = self._make_decode_output(
next_token_ids, model_input.sampling_metadata.seq_groups)
return sampler_output
def _decode_sampler_outputs(self, model_input):
use_async_out_proc = model_input.async_callback is not None
sampler_outputs = []
@ -2312,3 +2409,32 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
def __del__(self):
self.shutdown_inc()
def _patch_prev_output(self):
assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \
f'''Inputs and outputs are out of sync!
{len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}'''
if len(self.cached_step_inputs) == 0:
return
model_input = self.cached_step_inputs.pop(0)
delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(
-1).tolist()
ctx = model_input.async_callback.keywords["ctx"] # type: ignore
# If there's no output to patch with, which is usually the case when
# we're starting a new request after all requests are completed.
if len(ctx.output_queue) == 0:
return
assert len(
ctx.output_queue) == 1, 'There should be exactly 1 output waiting!'
output_data = ctx.output_queue[0]
assert len(output_data.outputs) == 1
for fake_out, real_out in zip(output_data.outputs[0], delayed_output):
fake_out.samples[0].output_token = real_out
for sg, real_out in zip(output_data.seq_group_metadata_list,
delayed_output):
assert len(sg.seq_data) == 1
seq_data = list(sg.seq_data.values())[0]
# This is a hack. Assigning output_token_ids triggers
# a cache recomputation and we only need to update the last token
seq_data.output_token_ids_array[-1] = real_out
seq_data._cached_all_token_ids[-1] = real_out