# SPDX-License-Identifier: Apache-2.0 ############################################################################### # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company ############################################################################### import collections import contextlib import dataclasses import functools import gc import itertools import math import os import time from array import array from enum import Enum, IntEnum from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union) import habana_frameworks.torch as htorch import habana_frameworks.torch.internal.bridge_config as bc import torch import torch.nn as nn import vllm_hpu_extension.environment as environment from vllm_hpu_extension.bucketing.common import get_bucketing_context from vllm_hpu_extension.ops import LoraMask as LoraMask from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, HabanaMemoryProfiler, format_bytes) import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import DeviceConfig, VllmConfig from vllm.distributed import broadcast_tensor_dict from vllm.distributed.parallel_state import get_world_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader import get_model from vllm.model_executor.sampling_metadata import SequenceGroupToSample from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs from vllm.sampling_params import SamplingParams from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceData, SequenceGroupMetadata, SequenceOutput) from vllm.utils import (bind_kv_cache, is_pin_memory_available, make_tensor_with_pad) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend logger = init_logger(__name__) _TYPE_CACHE = {} # These values are assumed to be zero in several places. # Use caution when updating them! _PAD_SLOT_ID = 0 _PAD_BLOCK_ID = 0 LORA_WARMUP_RANK = 8 DUMMY_TOKEN_ID = -1 class PhaseType(Enum): PREFILL = 'prefill' PREFIX_PREFILL = 'prefix_prefill' DECODE = 'decode' def subtuple(obj: object, typename: str, to_copy: List[str], to_override: Optional[Dict[str, object]] = None): if obj is None: return None if to_override is None: to_override = {} fields = set(to_copy) | set(to_override.keys()) if type(obj) is dict: values = {key: obj[key] for key in fields if key in obj} else: values = {f: to_override.get(f, getattr(obj, f)) for f in fields} if typename not in _TYPE_CACHE: _TYPE_CACHE[typename] = collections.namedtuple(typename, ' '.join(fields)) return _TYPE_CACHE[typename](**values) def round_up(value: int, k: int): return (value + k - 1) // k * k def align_workers(value, op): group = get_world_group().cpu_group world_size = torch.distributed.get_world_size() if world_size <= 1: return value value_t = torch.tensor(value, device='cpu') torch.distributed.all_reduce(value_t, op=op, group=group) return value_t.item() def setup_profiler(): schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1) DEVICE = 'hpu' activities = [torch.profiler.ProfilerActivity.CPU] activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE == 'hpu' else []) #from habana_frameworks.torch.activity_profiler import DebugActivity #debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS] profiler = torch.profiler.profile( schedule=schedule, activities=activities, #debug_activities=debug_activities, on_trace_ready=torch.profiler.tensorboard_trace_handler('.', use_gzip=True), record_shapes=False, with_stack=True) return profiler def pad_list(input, k, v): input_len = len(input) target_len = round_up(input_len, k) padding = target_len - input_len return input + [v] * padding def gather_list(input, indices, v): return [input[i] if i is not None else v for i in indices] def flatten(in_list): return list(itertools.chain(*in_list)) def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt): slot_mapping = slot_mapping.flatten() indices = torch.div(slot_mapping, block_size, rounding_mode="floor") if is_prompt: indices = indices.unflatten(0, (-1, block_size))[:, 0] offsets = None else: offsets = torch.fmod(slot_mapping, block_size) return indices, offsets def modify_decoder_layer(module: torch.nn.Module, suffix="DecoderLayer"): if module.__class__.__name__.endswith(suffix): def forward_hook(module, args, output): htorch.core.mark_step() return output module.register_forward_hook(forward_hook) for child_name, child_module in module.named_children(): modify_decoder_layer(child_module) class HpuModelAdapter: def __init__(self, model, vllm_config): self.model = model self.sampler = get_sampler() self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', '0').lower() in ['1', 'true'] self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size self.dtype = vllm_config.model_config.dtype enforce_eager = vllm_config.model_config.enforce_eager if not htorch.utils.internal.is_lazy() and not enforce_eager: if os.getenv('VLLM_REGIONAL_COMPILATION', 'true').lower() == 'true': self.regional_compilation_layers_list = [ RMSNorm, VocabParallelEmbedding ] self._regional_compilation(self.model) else: self.model = torch.compile(self.model, backend='hpu_backend', dynamic=False) def _regional_compilation(self, module, parent_module=None, module_name=None): if isinstance(module, torch.nn.ModuleList): for children_name, children_module in module.named_children(): self._compile_region(module, children_name, children_module) elif any( isinstance(module, layer) for layer in self.regional_compilation_layers_list): self._compile_region(parent_module, module_name, module) else: for children_name, children_module in module.named_children(): self._regional_compilation(children_module, module, children_name) def _compile_region(self, model, name, module): module = torch.compile(module, backend='hpu_backend', dynamic=False) setattr(model, name, module) def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): if (attn_metadata is None or (self.prefill_use_fusedsdpa \ and attn_metadata.block_list is None) or not attn_metadata.is_prompt): return attn_metadata prefill_metadata = attn_metadata seq_lens_t = prefill_metadata.seq_lens_tensor context_lens_t = prefill_metadata.context_lens_tensor query_lens_t = seq_lens_t - context_lens_t block_list = attn_metadata.block_list max_context_len = (block_list.size(-1) // batch_size if block_list is not None else 0) max_context_len = max_context_len * self.block_size past_mask = torch.arange(0, max_context_len, dtype=torch.int32, device=device) past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge( context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand( batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1)) len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32).view(1, seq_len).ge( query_lens_t.unsqueeze(-1)).view( batch_size, 1, 1, seq_len)) causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool), diagonal=1) mask = causal_mask.logical_or(len_mask) mask = torch.concat((past_mask, mask), dim=-1) attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( mask, -math.inf)) attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) return attn_metadata def _set_block_mapping(self, metadata, batch_size, device, dtype): mask = torch.arange(0, self.block_size, device=device, dtype=torch.int32).unsqueeze(0) mask = mask >= metadata.block_usage.unsqueeze(-1) attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( mask, -math.inf)) if os.environ.get('VLLM_USE_FAKE_HPU', '0') == '0' and htorch.utils.internal.is_lazy(): block_mapping = torch.nn.functional.one_hot(metadata.block_groups, num_classes=batch_size) else: # Unfortunately one_hot on CPU/torch.compile mode/eager mode # doesn't handle out of bounds classes so we need to convert # all negative values to 0 (block_mapping) or bs (block_groups) block_groups = metadata.block_groups.to(torch.long) block_mapping = torch.nn.functional.relu(block_groups) block_mapping = torch.nn.functional.one_hot(block_mapping, num_classes=batch_size) oob_values = block_groups.lt(0) block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0) block_groups.masked_fill_(oob_values, batch_size) metadata = metadata._replace(block_groups=block_groups) block_mapping = block_mapping.to(dtype) metadata = metadata._replace(block_mapping=block_mapping, attn_bias=attn_bias) return metadata def _update_metadata(self, attn_metadata, batch_size, seq_len, device, dtype): if attn_metadata.is_prompt: meta = attn_metadata attn_metadata = self._set_attn_bias(meta, batch_size, seq_len, device, dtype) else: meta = attn_metadata attn_metadata = self._set_block_mapping(meta, batch_size, device, dtype) return attn_metadata def forward(self, *args, **kwargs): kwargs = kwargs.copy() selected_token_indices = kwargs.pop('selected_token_indices') if 'warmup_mode' in kwargs: kwargs.pop('warmup_mode') virtual_engine = 0 if 'virtual_engine' in kwargs: virtual_engine = kwargs.pop('virtual_engine') input_ids = kwargs['input_ids'] attn_metadata = self._update_metadata(kwargs.pop('attn_metadata'), input_ids.size(0), input_ids.size(1), input_ids.device, self.dtype) LoraMask.setLoraMask(kwargs.pop('lora_mask')) with set_forward_context(attn_metadata, self.vllm_config, virtual_engine): hidden_states = self.model(*args, **kwargs) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = hidden_states.index_select(0, selected_token_indices) return hidden_states def compute_logits(self, *args, **kwargs): return self.model.compute_logits(*args, **kwargs) def sample(self, *args, **kwargs): return self.sampler(*args, **kwargs) class PreparePromptMetadata(NamedTuple): input_tokens: torch.Tensor input_positions: List[List[int]] attn_metadata: Optional[AttentionMetadata] seq_lens: List[int] query_lens: List[int] lora_index_mapping: List[List[int]] lora_prompt_mapping: List[List[int]] lora_requests: Set[LoRARequest] multi_modal_kwargs: Optional[Dict[str, BatchedTensorInputs]] slot_mapping: List[List[int]] lora_ids: List[int] @classmethod def empty(cls): return PreparePromptMetadata(input_tokens=[], input_positions=[], attn_metadata=None, seq_lens=[], query_lens=[], lora_index_mapping=[], lora_prompt_mapping=[], lora_requests=set(), multi_modal_kwargs=None, slot_mapping=[], lora_ids=[]) class PrepareDecodeMetadata(NamedTuple): input_tokens: torch.Tensor input_positions: List[List[int]] attn_metadata: Optional[AttentionMetadata] lora_index_mapping: List[List[int]] lora_prompt_mapping: List[List[int]] lora_requests: Set[LoRARequest] slot_mapping: List[List[int]] lora_ids: List[int] @classmethod def empty(cls): return PrepareDecodeMetadata(input_tokens=[], input_positions=[], attn_metadata=None, lora_index_mapping=[], lora_prompt_mapping=[], lora_requests=set(), slot_mapping=[], lora_ids=[]) # How batches are constructed. class BatchType(IntEnum): # Every batch is prefill. PREFILL = 0 # Every batch is decode. DECODE = 1 # Batch is a mixture of prefill and decode. MIXED = 2 TModelInputForHPU = TypeVar('TModelInputForHPU', bound="ModelInputForHPU") @dataclasses.dataclass(frozen=True) class ModelInputForHPU(ModelRunnerInputBase): """ This base class contains metadata needed for the base model forward pass but not metadata for possible additional steps, e.g., sampling. Model runners that run additional steps should subclass this method to add additional fields. """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None lora_mapping: Optional["LoRAMapping"] = None lora_requests: Optional[Set[LoRARequest]] = None attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None real_batch_size: Optional[int] = None batch_size_padded: Optional[int] = None virtual_engine: int = 0 lora_ids: Optional[List[int]] = None async_callback: Optional[Callable] = None is_first_multi_step: bool = True is_last_step: bool = True def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, "real_batch_size": self.real_batch_size, "batch_size_padded": self.batch_size_padded, "virtual_engine": self.virtual_engine, "lora_ids": self.lora_ids, "is_first_multi_step": self.is_first_multi_step, "is_last_step": self.is_last_step, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict @classmethod def from_broadcasted_tensor_dict( cls: Type[TModelInputForHPU], tensor_dict: Dict[str, Any], attn_backend: Optional["AttentionBackend"] = None, ) -> TModelInputForHPU: if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( attn_backend, tensor_dict) return cls(**tensor_dict) @dataclasses.dataclass(frozen=True) class ModelInputForHPUWithSamplingMetadata(ModelInputForHPU): """ Used by the ModelRunner. """ sampling_metadata: Optional["SamplingMetadata"] = None # Used for speculative decoding. We do not broadcast it because it is only # used by the driver worker. is_prompt: Optional[bool] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, "lora_ids": self.lora_ids, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, self.sampling_metadata) return tensor_dict @classmethod def from_broadcasted_tensor_dict( cls, tensor_dict: Dict[str, Any], attn_backend: Optional["AttentionBackend"] = None, ) -> "ModelInputForHPUWithSamplingMetadata": tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) # FIXME(kzawora): this fails for whatever reason - why? if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( attn_backend, tensor_dict) return cls(**tensor_dict) class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): """ Helper class for shared methods between GPU model runners. """ _model_input_cls: Type[TModelInputForHPU] def __init__( self, vllm_config: VllmConfig, is_driver_worker: bool = False, return_hidden_states: bool = False, ): ModelRunnerBase.__init__(self, vllm_config=vllm_config) environment.set_model_config(self.model_config) self.is_driver_worker = is_driver_worker self.return_hidden_states = return_hidden_states self.sliding_window = (self.model_config.get_sliding_window() if self.model_config is not None else None) self.device_config = (self.device_config if self.device_config is not None else DeviceConfig()) self.device = self.device_config.device self.enforce_eager = self.model_config.enforce_eager self.max_num_seqs = self.scheduler_config.max_num_seqs # NOTE(kzawora): Change that to scheduler_config.max_num_prefill_seqs # once padding-aware scheduling gets merged self.max_num_prefill_seqs = 64 self.max_model_len = self.scheduler_config.max_model_len self.max_num_batched_tokens = \ self.scheduler_config.max_num_batched_tokens self.block_size = self.cache_config.block_size self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = self.cache_config.cache_dtype self.attn_backend = get_attn_backend( self.model_config.get_head_size(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, self.model_config.is_attention_free, ) # Lazy initialization self.lora_manager: LRUCacheWorkerLoRAManager = None self.model: torch.nn.Module = None self.inc_initialized_successfully = False # Profiler stats self.profiler = HabanaHighLevelProfiler() self.profiler_counter_helper = HabanaProfilerCounterHelper() self.seen_configs: set = set() self._mem_margin: Optional[int] = None HPUBucketingContext = get_bucketing_context() self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs, self.max_num_prefill_seqs, self.block_size, self.max_num_batched_tokens, False, self.max_model_len) self.graphed_buckets: Set[Any] = set() self._set_gc_threshold() if self.vllm_config.cache_config.enable_prefix_caching: os.environ.setdefault("VLLM_CONTIGUOUS_PA", "False") assert os.environ.get( "VLLM_CONTIGUOUS_PA", "").lower() != "true", "Contiguous PA doesn't support APC" self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH # For multi-step scheduling self.cached_step_outputs: List[torch.Tensor] = [] # For delayed sampling self.cached_step_inputs: List[ ModelInputForHPUWithSamplingMetadata] = [] def _set_gc_threshold(self) -> None: # Read https://docs.python.org/3/library/gc.html#gc.set_threshold # for comprehensive description of gc generations. # We can either use VLLM_GC_THR_GEN[0-2] (this has higher priority) # to set particular generation threshold or use simpler # VLLM_GC_THR_MULTIPLIER to multiply default values. default_gc_thrs = list(gc.get_threshold()) requested_gc_thrs = [0] * len(default_gc_thrs) for i in range(len(default_gc_thrs)): requested_gc_thrs[i] = int( os.environ.get(f'VLLM_GC_THR_GEN{i}', default_gc_thrs[i])) if requested_gc_thrs == default_gc_thrs: gc_thr_multiplier = int(os.environ.get('VLLM_GC_THR_MULTIPLIER', 2)) requested_gc_thrs = [ t * gc_thr_multiplier for t in default_gc_thrs ] gc.set_threshold(*requested_gc_thrs) self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true' def load_model(self) -> None: import habana_frameworks.torch.core as htcore if self.model_config.quantization == 'inc' or \ self.model_config.quantization == 'fp8': htcore.hpu_set_env() with HabanaMemoryProfiler() as m: with HabanaMemoryProfiler() as m_getmodel: self.model = get_model(vllm_config=self.vllm_config) msg = ("Pre-loading model weights on " f"{next(self.model.parameters()).device} " f"took {m_getmodel.get_summary_string()}") logger.info(msg) if self.lora_config: assert hasattr(self.model, "embedding_modules" ), "Model does not have embedding_modules" assert hasattr( self.model, "embedding_padding_modules" ), "Model does not have embedding_padding_modules" assert not self.lora_config.bias_enabled, \ "Bias support in LoRA is not enabled in HPU yet." assert not self.lora_config.fully_sharded_loras, \ "Fully sharded LoRAs is not enabled in HPU yet." # Use get_text_config() in case of multimodal models text_config = self.model_config.hf_config.get_text_config() self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, self.vocab_size, self.lora_config, self.device, self.model.embedding_modules, self.model.embedding_padding_modules, max_position_embeddings=text_config. max_position_embeddings, ) self.model = self.lora_manager.create_lora_manager(self.model) if self.model_config.quantization == 'inc': logger.info("Preparing model with INC..") with HabanaMemoryProfiler() as m_inc: from neural_compressor.torch.quantization import ( FP8Config, convert, prepare) config = FP8Config.from_json_file( os.getenv("QUANT_CONFIG", "")) if config.measure: self.model = prepare(self.model, config) elif config.quantize: self.model = convert(self.model, config) htcore.hpu_initialize(self.model, mark_only_scales_as_const=True) self.inc_initialized_successfully = True logger.info("Preparing model with INC took %s", m_inc.get_summary_string()) else: self.model = self.model.to("hpu") htcore.mark_step() modify_decoder_layer(self.model) torch.hpu.synchronize() with HabanaMemoryProfiler() as m_wrap: self.model = _maybe_wrap_in_hpu_graph( self.model, vllm_config=self.vllm_config) msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}" logger.info(msg) self.model_memory_usage = m.consumed_device_memory msg = f"Loading model weights took in total {m.get_summary_string()}" logger.info(msg) def _add_dummy_seq(self, seq_group_metadata_list, is_prompt): real_batch_size = len(seq_group_metadata_list) batch_size_padded = self.bucketing_ctx.get_padded_batch_size( real_batch_size, is_prompt) batch_size_padding = batch_size_padded - real_batch_size seq_group_metadata_list = seq_group_metadata_list.copy() if batch_size_padding > 0: dummy_seq_group_metadata = self.create_dummy_seq_group_metadata( 0, 0, is_prompt) seq_group_metadata_list.extend(dummy_seq_group_metadata for _ in range(batch_size_padding)) return seq_group_metadata_list, real_batch_size, batch_size_padded def _maybe_wrap_in_hpu_graph(self, *args, **kwargs): return htorch.hpu.wrap_in_hpu_graph( HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True ) if htorch.utils.internal.is_lazy() else HpuModelAdapter( *args, **kwargs) def get_model(self) -> nn.Module: return self.model def _use_graphs(self, batch_size, seq_len, is_prompt): if self.enforce_eager: return False if self.skip_warmup: return True return (batch_size, seq_len, is_prompt) in self.graphed_buckets def _is_valid_bucket(self, bucket): return bucket[0] * bucket[1] <= self.max_num_batched_tokens def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> PreparePromptMetadata: input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] lora_index_mapping: List[List[int]] = [] lora_prompt_mapping: List[List[int]] = [] lora_requests: Set[LoRARequest] = set() seq_lens: List[int] = [] context_lens: List[int] = [] query_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] multi_modal_kwargs_list: List[MultiModalKwargs] = [] if len(seq_group_metadata_list) == 0: return PreparePromptMetadata.empty() for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) assert len(seq_ids) == 1 seq_id = seq_ids[0] computed_block_nums = seq_group_metadata.computed_block_nums if (self.scheduler_config is not None and self.scheduler_config.chunked_prefill_enabled and not (computed_block_nums is None or computed_block_nums == [])): raise RuntimeError( "chunked prefill cannot be used with prefix caching " "now.") token_chunk_size = seq_group_metadata.token_chunk_size seq_data = seq_group_metadata.seq_data[seq_id] context_len = seq_data.get_num_computed_tokens() # We should use get_len here because in case of preemption # it contains output tokens. seq_len = min(seq_data.get_len(), context_len + token_chunk_size) prompt_tokens = seq_data.get_token_ids()[context_len:seq_len] seq_lens.append(seq_len) # NOTE: This only works for oooooooxxx style attention. if computed_block_nums is not None and len( computed_block_nums) > 0 and self.sliding_window is None: # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size if context_len == seq_len \ and self.vllm_config.cache_config.enable_prefix_caching: # Fully cached prompt - compute only last token context_len = context_len - 1 prompt_tokens = prompt_tokens[context_len:] prefix_block_tables.append(computed_block_nums) elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: # Prefill has chunked before. block_table = seq_group_metadata.block_tables[seq_id] prefix_block_tables.append(block_table) else: # The first prefill. prefix_block_tables.append([]) else: prefix_block_tables.append([]) # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. assert context_len == 0 # actual prompt lens context_lens.append(context_len) query_lens.append(seq_len - context_len) input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. input_positions.append(list(range(context_len, seq_len))) mm_kwargs = seq_group_metadata.multi_modal_data if mm_kwargs: multi_modal_kwargs_list.append(mm_kwargs) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. slot_mapping.append([_PAD_SLOT_ID] * seq_len) continue # Compute the slot mapping. slot_mapping.append([]) block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # where start_idx is max(0, seq_len - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and # block size is 4, the first two tokens are masked and the slot # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. start_idx = 0 if self.sliding_window is not None: assert context_len == 0, ( "Prefix caching is currently not supported with " "sliding window attention") start_idx = max(0, seq_len - self.sliding_window) for i in range(context_len, seq_len): if i < start_idx: slot_mapping[-1].append(_PAD_SLOT_ID) continue block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset slot_mapping[-1].append(slot) max_query_len = max(query_lens) sum_query_len = sum(query_lens) real_num_seqs = len(query_lens) assert max_query_len > 0 max_prompt_len = max( self.bucketing_ctx.get_padded_prompt_seq_len(max_query_len), self.block_size) lora_ids: List[int] = [] for seq_group_metadata, context_len in zip(seq_group_metadata_list, context_lens): lora_id = seq_group_metadata.lora_int_id lora_ids.append(lora_id) if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) lora_index_mapping += [lora_id] * max_prompt_len lora_prompt_mapping.extend( [lora_id] * (max_prompt_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if any(context_lens): assert not self.scheduler_config.chunked_prefill_enabled # prefix caching max_num_block = max(len(bt) for bt in prefix_block_tables) prefix_block_list = list( itertools.chain.from_iterable( bt if len(bt) == max_num_block else bt + ([_PAD_BLOCK_ID] * (max_num_block - len(bt))) for bt in prefix_block_tables)) pad_len = len(prefix_block_list) prefix_block_list = pad_list(prefix_block_list, pad_len, _PAD_BLOCK_ID) prefix_block_list_tensor = torch.tensor(prefix_block_list, dtype=torch.long, device=self.device) else: prefix_block_list_tensor = None input_tokens = make_tensor_with_pad(input_tokens, max_len=max_prompt_len, pad=0, dtype=torch.long, device=self.device) input_positions = make_tensor_with_pad(input_positions, max_len=max_prompt_len, pad=0, dtype=torch.long, device=self.device) slot_mapping = make_tensor_with_pad(slot_mapping, max_len=max_prompt_len, pad=_PAD_SLOT_ID, dtype=torch.long, device=self.device) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long, device=self.device) context_lens_tensor = torch.tensor(context_lens, dtype=torch.long, device=self.device) block_indices, block_offsets = precompute_indices_and_offsets( self.block_size, slot_mapping, True) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, block_list=prefix_block_list_tensor, block_mapping=None, block_usage=None, block_indices=block_indices, block_offsets=block_offsets, block_groups=None, attn_bias=None, seq_lens_tensor=seq_lens_tensor, context_lens_tensor=context_lens_tensor, num_prefills=real_num_seqs, num_prefill_tokens=sum_query_len, num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps= None, # FIXME(kzawora): mutli-modality will not work here enable_kv_scales_calculation=False, ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return PreparePromptMetadata(input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, seq_lens=seq_lens, query_lens=query_lens, lora_index_mapping=lora_index_mapping, lora_prompt_mapping=lora_prompt_mapping, lora_requests=lora_requests, multi_modal_kwargs=multi_modal_kwargs, slot_mapping=slot_mapping, lora_ids=lora_ids) def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], output=None, ) -> PrepareDecodeMetadata: input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] seq_lens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[List[int]] = [] lora_prompt_mapping: List[List[int]] = [] lora_requests: Set[LoRARequest] = set() if len(seq_group_metadata_list) == 0: return PrepareDecodeMetadata.empty() lora_ids: List[int] = [] dummy_slots = itertools.cycle( range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size)) for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt assert seq_group_metadata.token_chunk_size == 1 seq_ids = list(seq_group_metadata.seq_data.keys()) lora_id = seq_group_metadata.lora_int_id lora_ids.append(lora_id) if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] if output is None: generation_token = seq_data.get_last_token_id() input_tokens.append([generation_token]) seq_len = seq_data.get_len() position = seq_len - 1 input_positions.append([position]) seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] num_fully_occupied_blocks = position // self.block_size block_table = block_table[:num_fully_occupied_blocks + 1] if len(block_table) == 0: block_number = _PAD_BLOCK_ID else: block_number = block_table[position // self.block_size] if block_number == _PAD_BLOCK_ID: slot = next(dummy_slots) else: block_offset = position % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append([slot]) lora_index_mapping.append(lora_id) lora_prompt_mapping.append(lora_id) if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window // self.block_size) block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) if output is None: input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) else: real_batch_size = len(seq_group_metadata_list) input_tokens = output[:real_batch_size] input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device) num_decode_tokens = sum(seq_lens) last_block_usage = [ slot[0] % self.block_size + 1 for slot in slot_mapping ] block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)] block_usage = [[self.block_size] * (len(bt) - 1) + [lbu] for bt, lbu in zip(block_tables, last_block_usage) if bt] block_list = flatten(block_tables) block_groups = flatten(block_groups) block_usage = flatten(block_usage) assert len(block_list) == len(block_groups) assert len(block_list) == len(block_usage) padding_fn = None if self.use_contiguous_pa: block_bucket_size = max(max(block_list) + 1, len(block_list)) block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( block_bucket_size) indices: List[Any] indices = [None] * block_bucket_size for i, bid in enumerate(block_list): indices[bid] = i padding_fn = lambda tensor, pad_value: gather_list( tensor, indices, pad_value) else: block_bucket_size = \ self.bucketing_ctx.get_padded_decode_num_blocks( len(block_list)) padding_fn = lambda tensor, pad_value: pad_list( tensor, block_bucket_size, pad_value) block_list = padding_fn(block_list, _PAD_BLOCK_ID) block_groups = padding_fn(block_groups, -1) block_usage = padding_fn(block_usage, 1) block_list = torch.tensor(block_list, dtype=torch.int, device=self.device) block_groups = torch.tensor(block_groups, dtype=torch.int, device=self.device) block_usage = torch.tensor(block_usage, dtype=self.model_config.dtype, device=self.device) slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) block_indices, block_offsets = precompute_indices_and_offsets( self.block_size, slot_mapping, False) attn_metadata = self.attn_backend.make_metadata( is_prompt=False, block_list=block_list, block_mapping=None, block_usage=block_usage, block_indices=block_indices, block_offsets=block_offsets, block_groups=block_groups, attn_bias=None, seq_lens_tensor=None, context_lens_tensor=None, num_prefills=0, num_prefill_tokens=0, num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=False, ) return PrepareDecodeMetadata(input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, lora_index_mapping=lora_index_mapping, lora_prompt_mapping=lora_prompt_mapping, lora_requests=lora_requests, slot_mapping=slot_mapping, lora_ids=lora_ids) def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[TModelInputForHPU, SamplingMetadata]: if len(seq_group_metadata_list) == 0: return self._model_input_cls(), None input_tokens = None input_positions = None lora_mapping = None lora_requests = None multi_modal_kwargs = None batch_type = None seq_lens = None query_lens = None real_batch_size = None batch_size_padded = None self.event_start = self.profiler.get_timestamp_us() is_prompt = seq_group_metadata_list[0].is_prompt base_event_name = 'prompt' if is_prompt else 'decode' self.profiler.start('internal', base_event_name) seq_group_metadata_list, real_batch_size, batch_size_padded = ( self._add_dummy_seq(seq_group_metadata_list, is_prompt)) prefill_reqs = [] decode_reqs = [] for seq_group_meta in seq_group_metadata_list: if seq_group_meta.is_prompt: prefill_reqs.append(seq_group_meta) else: decode_reqs.append(seq_group_meta) # Prepare input tensors. ( input_tokens, input_positions, prefill_attn_metadata, seq_lens, query_lens, lora_index_mapping, lora_prompt_mapping, lora_requests, multi_modal_kwargs, slot_mapping, lora_ids, ) = self._prepare_prompt(prefill_reqs) ( decode_input_tokens, decode_input_positions, decode_attn_metadata, decode_lora_index_mapping, decode_lora_prompt_mapping, decode_lora_requests, decode_slot_mapping, decode_lora_ids, ) = self._prepare_decode(decode_reqs) sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, seq_lens, query_lens, self.device, self.pin_memory) if not self.scheduler_config.chunked_prefill_enabled: assert (len(prefill_reqs) and len(decode_reqs)) == 0 num_prefills = len(seq_lens) num_prefill_tokens = len(input_tokens) num_decode_tokens = len(decode_input_tokens) # NOTE(kzawora): Here we diverge from GPU code - we don't # support mixed batches, so we either use decode or prefill # inputs, without coalescing. assert (num_prefills == 0 and num_decode_tokens > 0) or ( num_prefills > 0 and num_decode_tokens == 0), "HPU does not support mixed batches!" if num_decode_tokens > 0: input_tokens = decode_input_tokens input_positions = decode_input_positions slot_mapping = decode_slot_mapping lora_index_mapping = decode_lora_index_mapping lora_prompt_mapping = decode_lora_prompt_mapping lora_requests = decode_lora_requests lora_ids = decode_lora_ids # FIXME: We need to adjust selected_token_indices to accommodate # for padding max_len = input_tokens.size(1) paddings = [max_len - q for q in query_lens] paddings = [0] + paddings[:-1] paddings = list(itertools.accumulate(paddings)) paddings_prompt_logprobs = [] for i, seq_group_metadata in enumerate(seq_group_metadata_list): if seq_group_metadata.sampling_params.prompt_logprobs is not None \ and seq_group_metadata.is_prompt: paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) paddings = torch.tensor( paddings_prompt_logprobs if paddings_prompt_logprobs else paddings, dtype=sampling_metadata.selected_token_indices.dtype, device=sampling_metadata.selected_token_indices.device) sampling_metadata.selected_token_indices.add_(paddings) if self.lora_config: lora_mapping = LoRAMapping( **dict(index_mapping=lora_index_mapping, prompt_mapping=lora_prompt_mapping, is_prefill=(num_prefills > 0))) else: lora_mapping = None if (prefill_attn_metadata is not None and decode_attn_metadata is not None): batch_type = BatchType.MIXED raise NotImplementedError("Mixed batch is not supported on HPU") elif prefill_attn_metadata is not None: batch_type = BatchType.PREFILL else: batch_type = BatchType.DECODE metadata_dict = { "input_tokens": input_tokens, "input_positions": input_positions, "selected_token_indices": sampling_metadata.selected_token_indices, "lora_requests": lora_requests, "lora_mapping": lora_mapping, "multi_modal_kwargs": multi_modal_kwargs, "num_prefill_tokens": num_prefill_tokens, "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, "num_prefills": num_prefills, "batch_type": batch_type, "seq_lens": seq_lens, "query_lens": query_lens } if prefill_attn_metadata is not None: metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) else: assert decode_attn_metadata is not None metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) attn_metadata = prefill_attn_metadata if \ prefill_attn_metadata is not None else decode_attn_metadata return self._model_input_cls(input_tokens=input_tokens, seq_lens=seq_lens, query_lens=query_lens, input_positions=input_positions, attn_metadata=attn_metadata, lora_requests=lora_requests, lora_mapping=lora_mapping, multi_modal_kwargs=multi_modal_kwargs, real_batch_size=real_batch_size, batch_size_padded=batch_size_padded, lora_ids=lora_ids), \ sampling_metadata def _seq_len(self, attn_metadata): if attn_metadata.num_prefills != 0: return attn_metadata.slot_mapping.size(1) else: return attn_metadata.block_list.numel() def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: # NOTE(kzawora): To anyone working on this in the future: # Trimming metadata is required when using HPUGraphs. # Attention metadata is going to be hashed by PT bridge, and # appropriate HPUGraphs will be matched based on all inputs' hash. # Before you put more keys in here, make sure you know their # value type and make sure you know how it's going to be hashed. # You can find that information in input_hash function # in habana_frameworks/torch/hpu/graphs.py. You can also hash # it manually with torch.hpu.graphs.input_hash(attention_metadata) # If you use primitive types here - they will get hashed based # on their value. You *will* get lots of excessive graph captures # (and an OOM eventually) if you decide to put something like # seq_len int here. # If you absolutely need a scalar, put it in a tensor. Tensors # get hashed using their metadata, not their values: # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321)) # input_hash(123) != input_hash(321) # input_hash("abc") != input_hash("cba") attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ 'attn_bias', 'seq_lens_tensor', 'context_lens_tensor', 'block_list', 'block_mapping', 'block_usage', 'slot_mapping', 'is_prompt', 'block_indices', 'block_offsets', 'block_groups', ]) return attention_metadata def create_dummy_seq_group_metadata(self, group_id, seq_len, is_prompt, lora_request=None): sampling_params = SamplingParams(temperature=0) num_blocks = math.ceil(seq_len / self.block_size) seq_len = max(seq_len, 1) if is_prompt: input_len = seq_len output_len = 0 block_tables = None else: input_len = seq_len - 1 output_len = 1 block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks} prompt_token_ids = [0] * input_len output_token_ids = [1] * output_len prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821 seq_data = SequenceData(prompt_token_ids_array) seq_data.output_token_ids = output_token_ids return SequenceGroupMetadata(request_id=str(group_id), is_prompt=(output_len == 0), seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=block_tables, lora_request=lora_request) def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers bind_kv_cache( self.vllm_config.compilation_config.static_forward_context, [kv_caches]) _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() max_batch_size = min(self.max_num_seqs, self.max_num_batched_tokens // max_seq_len) self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, False, True) return def warmup_scenario(self, batch_size, seq_len, is_prompt, kv_caches, is_pt_profiler_run=False, is_lora_profile_run=False) -> None: use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) scenario_name = ("warmup_" f"{'prompt' if is_prompt else 'decode'}_" f"bs{batch_size}_" f"seq{seq_len}_" f"graphs{'T' if use_graphs else 'F'}") # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory # consumption create dummy lora request copies from the lora request # passed in, which contains a lora from the lora warmup path. dummy_lora_requests: List[LoRARequest] = [] dummy_lora_requests_per_seq: List[LoRARequest] = [] if self.lora_config and is_lora_profile_run: assert self.lora_manager is not None with self.lora_manager.dummy_lora_cache(): for idx in range(self.lora_config.max_loras): lora_id = idx + 1 dummy_lora_request = LoRARequest( lora_name=f"warmup_{lora_id}", lora_int_id=lora_id, lora_local_path="/not/a/real/path", ) self.lora_manager.add_dummy_lora(dummy_lora_request, rank=LORA_WARMUP_RANK) dummy_lora_requests.append(dummy_lora_request) dummy_lora_requests_per_seq = [ dummy_lora_requests[idx % len(dummy_lora_requests)] for idx in range(batch_size) ] self.profiler.start('internal', scenario_name) times = 3 if use_graphs or is_pt_profiler_run else 1 if is_prompt: seqs = [ self.create_dummy_seq_group_metadata( i, seq_len, is_prompt, lora_request=dummy_lora_requests_per_seq[i] if dummy_lora_requests_per_seq else None) for i in range(batch_size) ] else: # FIXME: seq_len is actually number of blocks blocks = [seq_len // batch_size for _ in range(batch_size)] blocks[0] += seq_len % batch_size seqs = [ self.create_dummy_seq_group_metadata( i, b * self.block_size - 1, is_prompt, lora_request=dummy_lora_requests_per_seq[i] if dummy_lora_requests_per_seq else None) for i, b in enumerate(blocks) ] torch.hpu.synchronize() profiler = None if is_pt_profiler_run and self.is_driver_worker: profiler = setup_profiler() profiler.start() for _ in range(times): inputs = self.prepare_model_input(seqs) is_single_step = \ self.vllm_config.scheduler_config.num_scheduler_steps == 1 if is_prompt or is_single_step: self.execute_model(inputs, None, warmup_mode=True) else: # decode with multi-step inputs = dataclasses.replace(inputs, is_first_multi_step=True, is_last_step=False) self.execute_model(inputs, None, warmup_mode=True, num_steps=2, seqs=seqs) inputs = dataclasses.replace(inputs, is_first_multi_step=False, is_last_step=True) self.execute_model(inputs, None, warmup_mode=True, num_steps=2, seqs=seqs) torch.hpu.synchronize() if profiler: profiler.step() if profiler: profiler.stop() self.profiler.end() gc.collect() def remove_all_loras(self): if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") self.lora_manager.remove_all_adapters() def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") self.lora_manager.set_active_adapters(lora_requests, lora_mapping) def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.add_adapter(lora_request) def remove_lora(self, lora_id: int) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.remove_adapter(lora_id) def pin_lora(self, lora_id: int) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.pin_adapter(lora_id) def list_loras(self) -> Set[int]: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.list_adapters() def log_warmup(self, phase, i, max_i, batch_size, seq_len): free_mem = format_bytes( HabanaMemoryProfiler.current_free_device_memory()) dim = "num_blocks" if phase == "Prompt": dim = "seq_len" msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " f"batch_size:{batch_size} " f"{dim}:{seq_len} " f"free_mem:{free_mem}") logger.info(msg) def warmup_all_buckets(self, buckets, is_prompt, kv_caches): for i, (batch_size, seq_len) in enumerate(reversed(buckets)): self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len) self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches, available_mem, starting_mem=0, total_batch_seq=0.001): total_mem = starting_mem idx = 0 phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' num_candidates = len(buckets) ordering : Union[Callable[[Any], Tuple[Any, Any]], \ Callable[[Any], Tuple[Any, Any, Any]]] if strategy == 'min_tokens': ordering = lambda b: (b[0] * b[1], b[1], b[0]) elif strategy == 'max_bs': ordering = lambda b: (-b[0], b[1]) else: raise NotImplementedError( f'Unsupported graph allocation strategy: {strategy}') buckets = list(sorted(buckets, key=ordering)) captured_all = True for idx, (batch_size, seq_len) in enumerate(buckets): # Graph memory usage is proportional to seq dimension in a batch batch_seq = batch_size * seq_len if is_prompt else batch_size mem_estimate = batch_seq / total_batch_seq * total_mem if mem_estimate >= available_mem: captured_all = False continue graphed_bucket = (batch_size, seq_len, is_prompt) if graphed_bucket in self.graphed_buckets: continue self.graphed_buckets.add(graphed_bucket) self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) with HabanaMemoryProfiler() as mem_prof: self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX) available_mem -= used_mem total_mem += used_mem total_batch_seq += batch_seq return total_mem, total_batch_seq, captured_all def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): num_candidates = len(buckets) phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' graphed = list(c[:2] for c in self.graphed_buckets if c[2] == is_prompt) if num_candidates == 0: num_candidates = 1 msg = (f'{phase} captured:{len(graphed)} ' f'({100 * len(graphed) / num_candidates:.1f}%) ' f'used_mem:{format_bytes(total_mem)} ' f'buckets:{sorted(list(graphed))}') logger.info(msg) @torch.inference_mode() def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: max_blocks = kv_caches[0][0].size(0) self.bucketing_ctx.generate_decode_buckets(max_blocks) if profile := os.environ.get('VLLM_PT_PROFILE', None): phase, bs, seq_len, graph = profile.split('_') is_prompt = phase == 'prompt' graphs = graph == 't' if graphs: self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, True) raise AssertionError("Finished profiling") if not htorch.utils.internal.is_lazy() and not self.enforce_eager: cache_size_limit = 1 + 3 * ( len(self.bucketing_ctx.prompt_buckets) + len(self.bucketing_ctx.decode_buckets)) torch._dynamo.config.cache_size_limit = max( cache_size_limit, torch._dynamo.config.cache_size_limit) # Multiply by 8 to follow the original default ratio between # the cache_size_limit and accumulated_cache_size_limit torch._dynamo.config.accumulated_cache_size_limit = max( cache_size_limit * 8, torch._dynamo.config.accumulated_cache_size_limit) if self.skip_warmup: logger.info("Skipping warmup...") return self.profiler.start('internal', 'warmup') start_mem = HabanaMemoryProfiler.current_device_memory_usage() start_time = time.perf_counter() compile_only_mode_context = functools.partial(bc.env_setting, "PT_COMPILE_ONLY_MODE", True) can_use_compile_only_mode = True try: with compile_only_mode_context(): pass logger.debug("Using PT_COMPILE_ONLY_MODE.") except KeyError: can_use_compile_only_mode = False logger.warning('Cannot use PT_COMPILE_ONLY_MODE. ' 'Warmup time will be negatively impacted. ' 'Please update Gaudi Software Suite.') with compile_only_mode_context( ) if can_use_compile_only_mode else contextlib.nullcontext(): self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True, kv_caches) self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False, kv_caches) if not self.enforce_eager and htorch.utils.internal.is_lazy(): assert self.mem_margin is not None, \ ("HabanaWorker.determine_num_available_blocks needs " "to be called before warming up the model.") free_mem = HabanaMemoryProfiler.current_free_device_memory() graph_free_mem = free_mem - self.mem_margin graph_free_mem = align_workers(graph_free_mem, torch.distributed.ReduceOp.MIN) prompt_graph_mem_ratio = float( os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.3')) prompt_available_memory = (prompt_graph_mem_ratio * graph_free_mem) decode_available_memory = (graph_free_mem - prompt_available_memory) msg = ( f"Using {format_bytes(graph_free_mem)}" f"/{format_bytes(free_mem)} " "of free device memory for HPUGraphs, " f"{format_bytes(prompt_available_memory)} for prompt and " f"{format_bytes(decode_available_memory)} for decode " f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})") logger.info(msg) prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY', 'min_tokens') decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY', 'max_bs') mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ self.warmup_graphs( prompt_strategy, self.bucketing_ctx.prompt_buckets, True, kv_caches, prompt_available_memory) mem_post_decode, decode_batch_seq, decode_captured_all = \ self.warmup_graphs( decode_strategy, self.bucketing_ctx.decode_buckets, False, kv_caches, decode_available_memory) # Not all prompt buckets were captured, but all decode buckets # were captured and we have some free graph-allocated space # left. Let's try to use it for capturing more prompt buckets. if (mem_post_decode + mem_post_prompt < graph_free_mem and not prompt_captured_all and decode_captured_all): mem_post_prompt, _, prompt_captured_all = ( self.warmup_graphs( prompt_strategy, self.bucketing_ctx.prompt_buckets, True, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_prompt, prompt_batch_seq)) # Not all decode buckets were captured, but all prompt buckets # were captured and we have some free graph-allocated space # left. Let's try to use it for capturing more decode buckets. if mem_post_decode + mem_post_prompt < graph_free_mem \ and not decode_captured_all \ and prompt_captured_all: mem_post_decode, _, _ = self.warmup_graphs( decode_strategy, self.bucketing_ctx.decode_buckets, False, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_decode, decode_batch_seq) self.log_graph_warmup_summary( self.bucketing_ctx.prompt_buckets, True, mem_post_prompt) self.log_graph_warmup_summary( self.bucketing_ctx.decode_buckets, False, mem_post_decode) end_time = time.perf_counter() end_mem = HabanaMemoryProfiler.current_device_memory_usage() elapsed_time = end_time - start_time msg = ( f"Warmup finished in {elapsed_time:.0f} secs, " f"allocated {format_bytes(end_mem - start_mem)} of device memory") logger.info(msg) self.profiler.end() @property def vocab_size(self) -> int: return self.model_config.get_vocab_size() @property def mem_margin(self) -> Optional[int]: return self._mem_margin @mem_margin.setter def mem_margin(self, value): self._mem_margin = value def _maybe_wrap_in_hpu_graph(*args, **kwargs): return htorch.hpu.wrap_in_hpu_graph( HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs) class HabanaProfilerCounterHelper: def __init__(self): self.niter = 0 self.average_real_throughput = None self.logged_once = False self.real_seq_lens = [] self.prompt_seq_lens = [] def capture_seq_group_metadata_stats(self, seq_group_metadata_list): self.real_seq_lens = [ len(seq_data.prompt_token_ids) + len(seq_data.output_token_ids) for seq_group_metadata in seq_group_metadata_list for seq_data in seq_group_metadata.seq_data.values() ] self.prompt_seq_lens = [ len(seq_data.prompt_token_ids) for seq_group_metadata in seq_group_metadata_list for seq_data in seq_group_metadata.seq_data.values() ] def get_counter_dict(self, cache_config, duration, seq_len, batch_size_padded, real_batch_size, is_prompt): throughput = batch_size_padded / (duration / 1e6) throughput_effective = real_batch_size / (duration / 1e6) real_max_seq_len = max(self.real_seq_lens) real_num_tokens = sum(self.real_seq_lens) padded_num_tokens = batch_size_padded * seq_len batch_token_utilization = real_num_tokens / padded_num_tokens if self.average_real_throughput is None: self.average_real_throughput = throughput_effective else: # https://www.heikohoffmann.de/htmlthesis/node134.html self.average_real_throughput = self.average_real_throughput + 1 / ( self.niter + 1) * (throughput_effective - self.average_real_throughput) phase = "prompt" if is_prompt else "decode" counters = { f'{phase}_bucket_batch_size': batch_size_padded, f'{phase}_batch_size': real_batch_size, f'{phase}_bucket_seq_len': seq_len, f'{phase}_seq_len': real_max_seq_len, f'{phase}_bucket_gen_throughput': throughput, f'{phase}_real_gen_throughput': throughput_effective, f'{phase}_batch_token_utilization': batch_token_utilization, 'average_real_throughput': self.average_real_throughput, 'engine_iteration': self.niter, } self.niter += 1 if is_prompt: prompt_bucket_in_throughput = (seq_len * batch_size_padded) / ( duration / 1e6) prompt_real_in_throughput = sum( self.prompt_seq_lens) / (duration / 1e6) counters[ f'{phase}_bucket_in_throughput'] = prompt_bucket_in_throughput counters[f'{phase}_real_in_throughput'] = prompt_real_in_throughput # KV cache might not be created yet (e.g. for profiling run) if cache_config.num_gpu_blocks is not None and \ cache_config.num_gpu_blocks != 0: cache_num_blocks_used = [ math.ceil(sl / cache_config.block_size) for sl in self.real_seq_lens ] cache_total_num_blocks_used = sum(cache_num_blocks_used) num_cache_blocks = cache_config.num_gpu_blocks cache_total_num_free_blocks = \ num_cache_blocks - cache_total_num_blocks_used cache_computed_utilization = \ cache_total_num_blocks_used / num_cache_blocks max_blocks_per_seq = math.ceil(seq_len / cache_config.block_size) batch_block_utilization = cache_total_num_blocks_used / ( batch_size_padded * max_blocks_per_seq) counters['cache_num_blocks_used'] = cache_total_num_blocks_used counters['cache_num_free_blocks'] = cache_total_num_free_blocks counters['cache_computed_utilization'] = cache_computed_utilization counters[ f'{phase}_batch_block_utilization'] = batch_block_utilization if not self.logged_once: counters['const_cache_num_blocks'] = cache_config.num_gpu_blocks counters[ 'const_gpu_memory_utilization'] = \ cache_config.gpu_memory_utilization counters['const_block_size'] = cache_config.block_size self.logged_once = True return counters def unwrap_model(model): if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): return unwrap_model(model._orig_mod) else: model = list(vars(model)['_modules'].values())[0] modules = list(vars(model)['_modules'].values()) return modules class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): """ GPU model runner with sampling step. """ _model_input_cls: Type[ModelInputForHPUWithSamplingMetadata] = ( ModelInputForHPUWithSamplingMetadata) def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, Any], ) -> ModelInputForHPUWithSamplingMetadata: return ( ModelInputForHPUWithSamplingMetadata.from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, )) @torch.inference_mode() def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForHPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including metadata for the sampling step. The API assumes seq_group_metadata_list is sorted by prefill -> decode. The result tensors and data structure also batches input in prefill -> decode order. For example, - input_tokens[:num_prefill_tokens] contains prefill tokens. - input_tokens[num_prefill_tokens:] contains decode tokens. If cuda graph is required, this API automatically pads inputs. """ with self.profiler.record_event('internal', 'prepare_input_tensors'): assert seq_group_metadata_list is not None if self.profiler.enabled: self.profiler_counter_helper.capture_seq_group_metadata_stats( seq_group_metadata_list=seq_group_metadata_list) model_input, sampling_metadata = self.prepare_input_tensors( seq_group_metadata_list) assert model_input.attn_metadata is not None is_prompt = model_input.attn_metadata.is_prompt return dataclasses.replace(model_input, sampling_metadata=sampling_metadata, is_prompt=is_prompt, virtual_engine=virtual_engine) def finish_measurements(self): from neural_compressor.torch.quantization import finalize_calibration finalize_calibration(self.model.model) def _num_blocks(self, attn_metadata): if attn_metadata.block_list is None: return 0 return attn_metadata.block_list.numel() def _phase(self, attn_metadata): phase_type: PhaseType is_prompt = attn_metadata.is_prompt is_prefix_prefill = is_prompt and attn_metadata.block_list is not None if is_prompt and is_prefix_prefill: phase_type = PhaseType.PREFIX_PREFILL elif is_prompt and not is_prefix_prefill: phase_type = PhaseType.PREFILL elif not is_prompt: phase_type = PhaseType.DECODE else: raise ValueError("Unrecognized pass type, likely due to malformed " "attention metadata") return phase_type def _check_config(self, batch_size, seq_len, attn_metadata, warmup_mode): is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching cfg: Optional[tuple] = None assert cfg is None, "Configs changed between 2D and 3D" if is_prefix_caching: phase = self._phase(attn_metadata) num_blocks = self._num_blocks(attn_metadata) cfg = (batch_size, seq_len, num_blocks, phase) else: phase = 'prompt' if attn_metadata.is_prompt else 'decode' cfg = (batch_size, seq_len, phase) seen = cfg in self.seen_configs self.seen_configs.add(cfg) if not seen and not warmup_mode: logger.warning("Configuration: %s was not warmed-up!", (phase.value, batch_size, seq_len, num_blocks) if is_prefix_caching else (phase, batch_size, seq_len)) def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int], is_prompt: bool): ''' This is a helper function to create the mask for lora computations. Lora Mask is needed to ensure we match the correct lora weights for the for the request. For Prompt phase we have lora_mask with shape (batch_size * seq_len, max_loras * max_rank) lora_logits_mask with shape (batch_size, max_loras * max_rank) For Decode phase we have both lora_mask and lora_logits_mask with shape (batch_size, max_loras * max_rank) ''' lora_mask: torch.Tensor = None lora_logits_mask: torch.Tensor = None lora_index = 0 if self.lora_config: if is_prompt: lora_mask = torch.zeros( input_tokens.shape[0] * input_tokens.shape[1], (self.lora_config.max_loras) *\ self.lora_config.max_lora_rank, dtype=self.lora_config.lora_dtype) lora_logits_mask = torch.zeros( input_tokens.shape[0], (self.lora_config.max_loras) * self.lora_config.max_lora_rank, dtype=self.lora_config.lora_dtype) ones = torch.ones(input_tokens.shape[1], self.lora_config.max_lora_rank, dtype=self.lora_config.lora_dtype) logit_ones = torch.ones(1, self.lora_config.max_lora_rank, dtype=self.lora_config.lora_dtype) for i in range(len(lora_ids)): if lora_ids[i] == 0: continue lora_index = self.lora_manager._adapter_manager.\ lora_index_to_id.index(lora_ids[i]) start_row = i * input_tokens.shape[1] end_row = start_row + input_tokens.shape[1] start_col = lora_index * self.lora_config.max_lora_rank end_col = start_col + self.lora_config.max_lora_rank lora_mask[start_row:end_row, start_col:end_col] = ones lora_logits_mask[i, start_col:end_col] = logit_ones lora_mask = lora_mask.to('hpu') lora_logits_mask = lora_logits_mask.to('hpu') else: lora_mask = torch.zeros(input_tokens.shape[0], (self.lora_config.max_loras) * self.lora_config.max_lora_rank, dtype=self.lora_config.lora_dtype) ones = torch.ones(1, self.lora_config.max_lora_rank, dtype=self.lora_config.lora_dtype) for i in range(len(lora_ids)): if lora_ids[i] == 0: continue lora_index = self.lora_manager._adapter_manager.\ lora_index_to_id.index(lora_ids[i]) start_pos = lora_index * self.lora_config.max_lora_rank end_pos = start_pos + self.lora_config.max_lora_rank lora_mask[i, start_pos:end_pos] = ones lora_mask = lora_mask.to('hpu') lora_logits_mask = lora_mask return lora_mask, lora_logits_mask def _get_seq_ids(self, model_input): return ([ sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups ]) def _pad_to_max_num_seqs(self, tensor, value): padding_needed = self.max_num_seqs - tensor.size(0) if padding_needed: padding = torch.full((padding_needed, *tensor.shape[1:]), value, device=tensor.device, dtype=tensor.dtype) tensor = torch.cat([tensor, padding]) return tensor @torch.inference_mode() def execute_model( self, model_input: ModelInputForHPUWithSamplingMetadata, kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, warmup_mode=False, seqs=None, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode assert not (use_delayed_sampling and num_steps != 1), \ 'Delayed sampling is not compatible with MSS!' assert model_input.input_tokens is not None if use_delayed_sampling and not model_input.is_prompt and \ self.is_driver_worker: num_cached = len(self.cached_step_outputs) assert num_cached > 0 cur_seq_ids = self._get_seq_ids(model_input) cur_seq_id_pos = { sid: idx for idx, sid in enumerate(cur_seq_ids) if sid >= 0 } htorch.core.mark_step() for i in range(num_cached): prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i]) target_indices = [ cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids ] padding = self.cached_step_outputs[i].size(0) - len( target_indices) target_indices.extend([-1] * padding) target_indices = torch.tensor( target_indices, device=model_input.input_tokens.device, dtype=model_input.input_tokens.dtype) model_input.input_tokens.index_copy_( 0, target_indices, self.cached_step_outputs[i]) htorch.core.mark_step() if not model_input.is_first_multi_step: if not model_input.is_last_step: # not first or last multi-step return [] # last multi-step output = self._decode_sampler_outputs( model_input) if self.is_driver_worker else [] torch.hpu.synchronize() if model_input.is_first_multi_step: # first multi-step if self.lora_config: assert model_input.lora_requests is not None assert model_input.lora_mapping is not None self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) # Rank!=0 workers has is_prompt==None if use_delayed_sampling and not model_input.is_prompt and \ model_input.input_tokens.size(1) == 1: if self.is_driver_worker: model_kwargs_broadcast_data = { "input_tokens": model_input.input_tokens } broadcast_tensor_dict(model_kwargs_broadcast_data, src=0) input_tokens = model_input.input_tokens else: model_kwargs_broadcast_data = broadcast_tensor_dict(src=0) input_tokens = model_kwargs_broadcast_data["input_tokens"] else: input_tokens = model_input.input_tokens input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata sampling_metadata = model_input.sampling_metadata real_batch_size = model_input.real_batch_size batch_size_padded = model_input.batch_size_padded assert input_tokens is not None assert input_positions is not None assert sampling_metadata is not None assert attn_metadata is not None is_prompt = attn_metadata.is_prompt assert is_prompt is not None batch_size = input_tokens.size(0) seq_len = self._seq_len(attn_metadata) use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) self._check_config(batch_size, seq_len, attn_metadata, warmup_mode) lora_mask: torch.Tensor = None lora_logits_mask: torch.Tensor = None if self.lora_config: assert model_input.lora_ids is not None lora_mask, lora_logits_mask = self.create_lora_mask( input_tokens, model_input.lora_ids, attn_metadata.is_prompt) execute_model_kwargs = { "input_ids": input_tokens, "positions": input_positions, "attn_metadata": self.trim_attn_metadata(attn_metadata), "intermediate_tensors": intermediate_tensors, "lora_mask": lora_mask, "virtual_engine": model_input.virtual_engine, **(model_input.multi_modal_kwargs or {}), } if htorch.utils.internal.is_lazy(): execute_model_kwargs.update( {"bypass_hpu_graphs": not use_graphs}) htorch.core.mark_step() if self.is_driver_worker: model_event_name = ("model_" f"{'prompt' if is_prompt else 'decode'}_" f"bs{batch_size}_" f"seq{seq_len}_" f"graphs{'T' if use_graphs else 'F'}") else: model_event_name = 'model_executable' if num_steps > 1 or use_delayed_sampling: # in case of multi-step scheduling # we only want to pythonize in the last step sampling_metadata.skip_sampler_cpu_output = True self.model.sampler.include_gpu_probs_tensor = True cache_orig_output_tokens_len: List[Dict] = [] def try_revert_dummy_output_tokens(): if len(cache_orig_output_tokens_len) > 0: # Reuse the original output token ids length for i, seq_group_metadata in enumerate( seq_group_metadata_list): for j, data in seq_group_metadata.seq_data.items(): orig_output_tokens_len = \ cache_orig_output_tokens_len[i][j] data.output_token_ids = \ data.output_token_ids[:orig_output_tokens_len] for i in range(num_steps): if i != 0 and not self.is_driver_worker: broadcast_data = broadcast_tensor_dict(src=0) if 'early_exit' in broadcast_data and broadcast_data[ 'early_exit']: return [output] if num_steps == 1 else [] execute_model_kwargs.update({ "input_ids": broadcast_data["input_ids"], "positions": broadcast_data["positions"], "attn_metadata": self.trim_attn_metadata( broadcast_data["attn_metadata"]) }) with self.profiler.record_event('internal', model_event_name): hidden_states = self.model.forward( **execute_model_kwargs, selected_token_indices=sampling_metadata. selected_token_indices) if self.lora_config: LoraMask.setLoraMask( lora_logits_mask.index_select( 0, sampling_metadata.selected_token_indices)) # Compute the logits. with self.profiler.record_event( 'internal', ('compute_logits_' f'{"prompt" if is_prompt else "decode"}_bs' f'{batch_size}_' f'seq{seq_len}')): if num_steps == 1: sampling_metadata.selected_token_indices = None logits = self.model.compute_logits(hidden_states, sampling_metadata) htorch.core.mark_step() # Only perform sampling in the driver worker. if not self.is_driver_worker: continue if use_delayed_sampling: fake_output = self._delayed_sampler_outputs(model_input) with self.profiler.record_event( 'internal', ('sample_' f'{"prompt" if is_prompt else "decode"}_' f'bs{batch_size}_' f'seq{seq_len}')): output = self.model.sample( logits=logits, sampling_metadata=sampling_metadata, ) if num_steps > 1: output = output.sampled_token_ids self.cached_step_outputs.append(output) if use_delayed_sampling and self.is_driver_worker: self._patch_prev_output() output = self._pad_to_max_num_seqs( output.sampled_token_ids, DUMMY_TOKEN_ID) self.cached_step_outputs.append(output) self.cached_step_inputs.append(model_input) htorch.core.mark_step() if model_input.async_callback is not None: model_input.async_callback() if i < num_steps - 1: if i == 0: if model_input.async_callback is not None: ctx = model_input.async_callback.keywords[ # type: ignore "ctx"] seq_group_metadata_list = \ ctx.seq_group_metadata_list elif seqs is not None: seq_group_metadata_list = seqs else: raise RuntimeError( "seq_group_metadata_list is uninitialized") for i, seq_group_metadata in enumerate( seq_group_metadata_list): # Skip empty steps seq_group_metadata.state.current_step += ( num_steps - 2) # Cache the original output token ids cache_orig_output_tokens_len.append({}) for j, data in seq_group_metadata.seq_data.items(): cache_orig_output_tokens_len[i][j] = \ len(data.output_token_ids) for seq_group_metadata in seq_group_metadata_list: for data in seq_group_metadata.seq_data.values(): max_output_len = sampling_metadata.seq_groups[ 0].sampling_params.max_tokens if len(data.output_token_ids) < max_output_len - 1: # add a place holder for prepare_decode # arbitrary value, this could be any token dummy_token = (540, ) data.output_token_ids += (dummy_token) else: broadcast_tensor_dict({'early_exit': True}, src=0) if num_steps == 1: return [output] else: try_revert_dummy_output_tokens() return [] result = self._prepare_decode(seq_group_metadata_list, output=output) execute_model_kwargs.update({ "input_ids": result.input_tokens, "positions": result.input_positions, "attn_metadata": self.trim_attn_metadata(result.attn_metadata) }) model_kwargs_broadcast_data = { "input_ids": result.input_tokens, "positions": result.input_positions, "attn_metadata": vars(result.attn_metadata) } broadcast_tensor_dict(model_kwargs_broadcast_data, src=0) else: try_revert_dummy_output_tokens() if self.is_driver_worker and self.profiler.enabled: # Stop recording 'execute_model' event self.profiler.end() event_end = self.profiler.get_timestamp_us() counters = self.profiler_counter_helper.get_counter_dict( cache_config=self.cache_config, duration=event_end - self.event_start, seq_len=seq_len, batch_size_padded=batch_size_padded, real_batch_size=real_batch_size, is_prompt=is_prompt) self.profiler.record_counter(self.event_start, counters) if num_steps == 1: if self.return_hidden_states: # we only need to pass hidden states of most recent token assert model_input.sampling_metadata is not None if model_input.is_prompt: output.prefill_hidden_states = hidden_states output.hidden_states = hidden_states if use_delayed_sampling: if self.is_driver_worker: return [fake_output] else: return [] return [output] if self.is_driver_worker else [] else: return [] return output if type(output) is list else [output] def _delayed_sampler_outputs(self, model_input): next_token_ids = [[DUMMY_TOKEN_ID]] * len( model_input.sampling_metadata.seq_groups) sampler_output = self._make_decode_output( next_token_ids, model_input.sampling_metadata.seq_groups) return sampler_output def _decode_sampler_outputs(self, model_input): use_async_out_proc = model_input.async_callback is not None sampler_outputs = [] num_outputs = len(self.cached_step_outputs) for i in range(num_outputs): next_token_ids = self.cached_step_outputs.pop(0) next_token_ids = next_token_ids.cpu().tolist() sampler_output = self._make_decode_output( next_token_ids, model_input.sampling_metadata.seq_groups) sampler_outputs.append(sampler_output) if i < num_outputs - 1 and use_async_out_proc: assert model_input.async_callback is not None ctx = model_input.async_callback.keywords[ # type: ignore "ctx"] ctx.append_output( outputs=[sampler_output], seq_group_metadata_list=ctx.seq_group_metadata_list, scheduler_outputs=ctx.scheduler_outputs, is_async=False, is_last_step=False, is_first_step_output=False) model_input.async_callback() if use_async_out_proc: return [sampler_outputs[-1]] else: return sampler_outputs def _make_decode_output( self, next_token_ids: List[List[int]], seq_groups: List[SequenceGroupToSample], ) -> SamplerOutput: zero_logprob = Logprob(0.0) sampler_outputs = [] batch_idx = 0 for seq_group in seq_groups: seq_ids = seq_group.seq_ids seq_outputs = [] for seq_id in seq_ids: next_token_id = next_token_ids[batch_idx][0] seq_outputs.append( SequenceOutput(seq_id, next_token_id, {next_token_id: zero_logprob})) batch_idx += 1 sampler_outputs.append( CompletionSequenceGroupOutput(seq_outputs, None)) return SamplerOutput(sampler_outputs) def shutdown_inc(self): can_finalize_inc = False from contextlib import suppress with suppress(AttributeError): can_finalize_inc = (self.model_config.quantization == 'inc') and \ (self.model.model is not None) and \ self.inc_initialized_successfully and \ not getattr(self, "_is_inc_finalized", False) if can_finalize_inc: from neural_compressor.torch.quantization import ( finalize_calibration) finalize_calibration(self.model.model) self._is_inc_finalized = True def __del__(self): self.shutdown_inc() def _patch_prev_output(self): assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \ f'''Inputs and outputs are out of sync! {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}''' if len(self.cached_step_inputs) == 0: return model_input = self.cached_step_inputs.pop(0) delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze( -1).tolist() ctx = model_input.async_callback.keywords["ctx"] # type: ignore # If there's no output to patch with, which is usually the case when # we're starting a new request after all requests are completed. if len(ctx.output_queue) == 0: return assert len( ctx.output_queue) == 1, 'There should be exactly 1 output waiting!' output_data = ctx.output_queue[0] assert len(output_data.outputs) == 1 for fake_out, real_out in zip(output_data.outputs[0], delayed_output): fake_out.samples[0].output_token = real_out for sg, real_out in zip(output_data.seq_group_metadata_list, delayed_output): assert len(sg.seq_data) == 1 seq_data = list(sg.seq_data.values())[0] # This is a hack. Assigning output_token_ids triggers # a cache recomputation and we only need to update the last token seq_data.output_token_ids_array[-1] = real_out seq_data._cached_all_token_ids[-1] = real_out