[INTEL-HPU][v0] Port delayed sampling to upstream (#16949)
Signed-off-by: Michal Adamczyk <michal.adamczyk@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: Michal Adamczyk <madamczyk@habana.ai>
parent e1cf90e099
commit 56a735261c
vllm/envs.py

@@ -98,6 +98,7 @@ if TYPE_CHECKING:
    VLLM_RAY_BUNDLE_INDICES: str = ""
    VLLM_CUDART_SO_PATH: Optional[str] = None
    VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
    VLLM_HPU_USE_DELAYED_SAMPLING: bool = False
    VLLM_DP_RANK: int = 0
    VLLM_DP_RANK_LOCAL: int = -1
    VLLM_DP_SIZE: int = 1
@@ -650,6 +651,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
    lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
    ("1", "true"),

    # Use delayed sampling for HPU to reduce host cpu overhead
    # between each step.
    "VLLM_HPU_USE_DELAYED_SAMPLING":
    lambda: os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in
    ("1", "true"),

    # Rank of the process in the data parallel setting
    "VLLM_DP_RANK":
    lambda: int(os.getenv("VLLM_DP_RANK", "0")),
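The new flag defaults to off and, per the comment above, exists to reduce host CPU overhead between decode steps. Note that the dictionary key added here is VLLM_HPU_USE_DELAYED_SAMPLING while the getter reads the VLLM_DELAYED_SAMPLING environment variable, so the flag is toggled by whichever name the lambda looks up. A minimal standalone sketch (not vLLM code, helper name made up) of how that lambda evaluates:

# Standalone sketch of the flag parsing above; the helper name is hypothetical.
import os

def hpu_delayed_sampling_enabled() -> bool:
    # Mirrors the committed lambda: it reads VLLM_DELAYED_SAMPLING, not the
    # VLLM_HPU_USE_DELAYED_SAMPLING name used as the dictionary key.
    return os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in ("1", "true")

os.environ["VLLM_DELAYED_SAMPLING"] = "1"
assert hpu_delayed_sampling_enabled()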
vllm/worker/hpu_model_runner.py

@@ -74,6 +74,8 @@ _PAD_BLOCK_ID = 0

LORA_WARMUP_RANK = 8

DUMMY_TOKEN_ID = -1


class Singleton(type):
    _instances: Dict[type, object] = {}
@@ -668,6 +670,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):

        # For multi-step scheduling
        self.cached_step_outputs: List[torch.Tensor] = []
        # For delayed sampling
        self.cached_step_inputs: List[
            ModelInputForHPUWithSamplingMetadata] = []

    def _set_gc_threshold(self) -> None:
        # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
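cached_step_outputs already served multi-step scheduling; the commit adds cached_step_inputs so that a deferred step's inputs can later be matched with the tokens sampled for it. A conceptual sketch of that paired FIFO bookkeeping, using plain lists and made-up names rather than the vLLM types:

# Paired FIFO bookkeeping sketch: inputs and outputs of a deferred step are
# appended together and popped together, oldest step first.
cached_step_inputs: list = []
cached_step_outputs: list = []

def defer_step(model_input, sampled_tokens) -> None:
    cached_step_inputs.append(model_input)
    cached_step_outputs.append(sampled_tokens)

def pop_oldest_step():
    # Mirrors _patch_prev_output's pop(0)/pop(0); both lists stay in sync.
    assert len(cached_step_inputs) == len(cached_step_outputs)
    if not cached_step_inputs:
        return None
    return cached_step_inputs.pop(0), cached_step_outputs.pop(0)

defer_step({"step": 0}, [101, 102])
assert pop_oldest_step() == ({"step": 0}, [101, 102])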
@@ -771,6 +776,12 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
        msg = f"Loading model weights took in total {m.get_summary_string()}"
        logger.info(msg)

    def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
        return htorch.hpu.wrap_in_hpu_graph(
            HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
        ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
            *args, **kwargs)

    def get_model(self) -> nn.Module:
        return self.model

@@ -2020,6 +2031,21 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):

        return lora_mask, lora_logits_mask

    def _get_seq_ids(self, model_input):
        return ([
            sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups
        ])

    def _pad_to_max_num_seqs(self, tensor, value):
        padding_needed = self.max_num_seqs - tensor.size(0)
        if padding_needed:
            padding = torch.full((padding_needed, *tensor.shape[1:]),
                                 value,
                                 device=tensor.device,
                                 dtype=tensor.dtype)
            tensor = torch.cat([tensor, padding])
        return tensor

    @torch.inference_mode()
    def execute_model(
        self,
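_pad_to_max_num_seqs pads a sampled-token tensor up to the runner's max_num_seqs rows with a fill value (DUMMY_TOKEN_ID at the call site further down), so cached outputs keep a fixed batch dimension regardless of how many sequences are live. A standalone sketch of the same pattern, with max_num_seqs passed in explicitly instead of read from the runner:

import torch

DUMMY_TOKEN_ID = -1

def pad_to_max_num_seqs(tensor: torch.Tensor, value: int,
                        max_num_seqs: int) -> torch.Tensor:
    # Same logic as the method above, but max_num_seqs is an argument here.
    padding_needed = max_num_seqs - tensor.size(0)
    if padding_needed:
        padding = torch.full((padding_needed, *tensor.shape[1:]),
                             value,
                             device=tensor.device,
                             dtype=tensor.dtype)
        tensor = torch.cat([tensor, padding])
    return tensor

sampled = torch.tensor([[11], [22], [33]])  # three live sequences, one token each
padded = pad_to_max_num_seqs(sampled, DUMMY_TOKEN_ID, max_num_seqs=5)
assert padded.tolist() == [[11], [22], [33], [-1], [-1]]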
@@ -2030,6 +2056,37 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
        warmup_mode=False,
        seqs=None,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING
        use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
        assert not (use_delayed_sampling and num_steps != 1), \
            'Delayed sampling is not compatible with MSS!'
        assert model_input.input_tokens is not None
        if use_delayed_sampling and not model_input.is_prompt and \
                self.is_driver_worker:
            num_cached = len(self.cached_step_outputs)
            assert num_cached > 0
            cur_seq_ids = self._get_seq_ids(model_input)
            cur_seq_id_pos = {
                sid: idx
                for idx, sid in enumerate(cur_seq_ids) if sid >= 0
            }
            htorch.core.mark_step()
            for i in range(num_cached):
                prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i])
                target_indices = [
                    cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids
                ]
                padding = self.cached_step_outputs[i].size(0) - len(
                    target_indices)
                target_indices.extend([-1] * padding)
                target_indices = torch.tensor(
                    target_indices,
                    device=model_input.input_tokens.device,
                    dtype=model_input.input_tokens.dtype)
                model_input.input_tokens.index_copy_(
                    0, target_indices, self.cached_step_outputs[i])
                htorch.core.mark_step()

        if not model_input.is_first_multi_step:
            if not model_input.is_last_step:
                # not first or last multi-step
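This preamble is the heart of delayed sampling on the decode path: before launching step N+1, the driver scatters the tokens it sampled for step N (still sitting on the device in cached_step_outputs) into the new step's input_tokens, matching rows by sequence id; sequences that are no longer running are directed at index -1. A simplified CPU-only sketch of that realignment with toy ids and shapes (the -1 target is spelled out as the trailing padded row, which is presumably what it resolves to on the padded batch, so the example runs on stock PyTorch):

import torch

# Tokens sampled for the previous step, one row per sequence, still on device.
prev_seq_ids = [7, 9, 12]
cached_tokens = torch.tensor([[101], [102], [103]])

# The current decode step runs sequences 9, 12, 15 and is padded to 4 rows;
# the last row is a padding slot.
cur_seq_ids = [9, 12, 15]
input_tokens = torch.tensor([[0], [0], [500], [0]])

cur_seq_id_pos = {sid: idx for idx, sid in enumerate(cur_seq_ids) if sid >= 0}
# Finished sequences (here seq 7) are routed to the trailing padding row.
last_row = input_tokens.size(0) - 1
target_indices = torch.tensor(
    [cur_seq_id_pos.get(sid, last_row) for sid in prev_seq_ids])

input_tokens.index_copy_(0, target_indices, cached_tokens)
print(input_tokens.tolist())  # [[102], [103], [500], [101]]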
@@ -2045,7 +2102,21 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)
        input_tokens = model_input.input_tokens
        # Rank!=0 workers has is_prompt==None
        if use_delayed_sampling and not model_input.is_prompt and \
                model_input.input_tokens.size(1) == 1:
            if self.is_driver_worker:
                model_kwargs_broadcast_data = {
                    "input_tokens": model_input.input_tokens
                }
                broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
                input_tokens = model_input.input_tokens

            else:
                model_kwargs_broadcast_data = broadcast_tensor_dict(src=0)
                input_tokens = model_kwargs_broadcast_data["input_tokens"]
        else:
            input_tokens = model_input.input_tokens
        input_positions = model_input.input_positions
        attn_metadata = model_input.attn_metadata
        sampling_metadata = model_input.sampling_metadata
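Because only the driver patches input_tokens with the previous step's samples, the other ranks must receive the patched tensor before the forward pass; that is what the broadcast_tensor_dict exchange above does on decode steps. A rough sketch of the same hand-off using plain torch.distributed (the single-process "gloo" group and the helper name exist only to make the sketch runnable; vLLM's own broadcast_tensor_dict is not reproduced here):

import torch
import torch.distributed as dist

# One-process group purely so the sketch runs; in vLLM the group already
# exists and spans all workers, with the driver at src=0.
if not dist.is_initialized():
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29517",
                            rank=0, world_size=1)

def exchange_input_tokens(input_tokens, is_driver_worker: bool):
    if is_driver_worker:
        payload = [{"input_tokens": input_tokens}]
        dist.broadcast_object_list(payload, src=0)  # driver publishes
        return input_tokens
    payload = [None]
    dist.broadcast_object_list(payload, src=0)      # non-driver ranks receive
    return payload[0]["input_tokens"]

tokens = exchange_input_tokens(torch.tensor([[7], [8]]), is_driver_worker=True)
print(tokens.tolist())  # [[7], [8]]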
@@ -2092,7 +2163,7 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
                                f"graphs{'T' if use_graphs else 'F'}")
        else:
            model_event_name = 'model_executable'
        if num_steps > 1:
        if num_steps > 1 or use_delayed_sampling:
            # in case of multi-step scheduling
            # we only want to pythonize in the last step
            sampling_metadata.skip_sampler_cpu_output = True
@@ -2152,9 +2223,9 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
            if not self.is_driver_worker:
                continue

            if model_input.async_callback is not None:
                model_input.async_callback()
            # Sample the next token.
            if use_delayed_sampling:
                fake_output = self._delayed_sampler_outputs(model_input)

            with self.profiler.record_event(
                    'internal', ('sample_'
                                 f'{"prompt" if is_prompt else "decode"}_'
@@ -2166,9 +2237,16 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
                )
            if num_steps > 1:
                output = output.sampled_token_ids
                self.cached_step_outputs.append(
                    output.detach().clone())
                self.cached_step_outputs.append(output)
            if use_delayed_sampling and self.is_driver_worker:
                self._patch_prev_output()
                output = self._pad_to_max_num_seqs(
                    output.sampled_token_ids, DUMMY_TOKEN_ID)
                self.cached_step_outputs.append(output)
                self.cached_step_inputs.append(model_input)
            htorch.core.mark_step()
            if model_input.async_callback is not None:
                model_input.async_callback()
            if i < num_steps - 1:
                if i == 0:
                    if model_input.async_callback is not None:
@@ -2241,11 +2319,30 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
                is_prompt=is_prompt)
            self.profiler.record_counter(self.event_start, counters)
        if num_steps == 1:
            if self.return_hidden_states:
                # we only need to pass hidden states of most recent token
                assert model_input.sampling_metadata is not None
                if model_input.is_prompt:
                    output.prefill_hidden_states = hidden_states
                output.hidden_states = hidden_states
            if use_delayed_sampling:
                if self.is_driver_worker:
                    return [fake_output]
                else:
                    return []

            return [output] if self.is_driver_worker else []
        else:
            return []
        return output if type(output) is list else [output]

    def _delayed_sampler_outputs(self, model_input):
        next_token_ids = [[DUMMY_TOKEN_ID]] * len(
            model_input.sampling_metadata.seq_groups)
        sampler_output = self._make_decode_output(
            next_token_ids, model_input.sampling_metadata.seq_groups)
        return sampler_output

    def _decode_sampler_outputs(self, model_input):
        use_async_out_proc = model_input.async_callback is not None
        sampler_outputs = []
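With delayed sampling the driver therefore answers step N immediately with fake_output: _delayed_sampler_outputs builds a normal-looking decode output in which every sequence group carries DUMMY_TOKEN_ID, and _patch_prev_output swaps in the real token ids one step later. A minimal sketch of just the placeholder construction (_make_decode_output and the real SamplerOutput type are omitted):

# Placeholder next-token construction, shaped like the list handed to
# _make_decode_output: one dummy token per sequence group.
DUMMY_TOKEN_ID = -1

def placeholder_next_tokens(num_seq_groups: int) -> list:
    return [[DUMMY_TOKEN_ID]] * num_seq_groups

assert placeholder_next_tokens(4) == [[-1], [-1], [-1], [-1]]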
@@ -2312,3 +2409,32 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):

    def __del__(self):
        self.shutdown_inc()

    def _patch_prev_output(self):
        assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \
            f'''Inputs and outputs are out of sync!
            {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}'''
        if len(self.cached_step_inputs) == 0:
            return
        model_input = self.cached_step_inputs.pop(0)
        delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(
            -1).tolist()
        ctx = model_input.async_callback.keywords["ctx"]  # type: ignore
        # If there's no output to patch with, which is usually the case when
        # we're starting a new request after all requests are completed.
        if len(ctx.output_queue) == 0:
            return
        assert len(
            ctx.output_queue) == 1, 'There should be exactly 1 output waiting!'
        output_data = ctx.output_queue[0]
        assert len(output_data.outputs) == 1
        for fake_out, real_out in zip(output_data.outputs[0], delayed_output):
            fake_out.samples[0].output_token = real_out
        for sg, real_out in zip(output_data.seq_group_metadata_list,
                                delayed_output):
            assert len(sg.seq_data) == 1
            seq_data = list(sg.seq_data.values())[0]
            # This is a hack. Assigning output_token_ids triggers
            # a cache recomputation and we only need to update the last token
            seq_data.output_token_ids_array[-1] = real_out
            seq_data._cached_all_token_ids[-1] = real_out
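_patch_prev_output is where the deferred result finally lands: the oldest cached device tensor is pulled to the CPU and written over the placeholder tokens that were queued for async output processing, and over the per-sequence token caches. An end-to-end toy of that patching step with simplified stand-ins (plain dataclasses, not vLLM's scheduler or SamplerOutput types):

from dataclasses import dataclass
from typing import List

import torch

DUMMY_TOKEN_ID = -1

@dataclass
class Sample:
    output_token: int

@dataclass
class SeqGroupOutput:
    samples: List[Sample]

# Step N was answered with placeholders while the real sample stayed on device.
queued_outputs = [SeqGroupOutput([Sample(DUMMY_TOKEN_ID)]) for _ in range(3)]
cached_step_output = torch.tensor([[11], [22], [33]])

# Step N+1: pull the real token ids to the host and overwrite the placeholders,
# as _patch_prev_output does for ctx.output_queue and seq_data.
delayed_output = cached_step_output.cpu().squeeze(-1).tolist()
for fake_out, real_out in zip(queued_outputs, delayed_output):
    fake_out.samples[0].output_token = real_out

assert [g.samples[0].output_token for g in queued_outputs] == [11, 22, 33]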