diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index d6eaf84e40f6b..4ac547ae326da 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -39,7 +39,10 @@ from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalKwargs)
@@ -311,10 +314,38 @@ class HpuModelAdapter:
         self.block_size = vllm_config.cache_config.block_size
         self.dtype = vllm_config.model_config.dtype
         enforce_eager = vllm_config.model_config.enforce_eager
+
         if not htorch.utils.internal.is_lazy() and not enforce_eager:
-            self.model = torch.compile(self.model,
-                                       backend='hpu_backend',
-                                       dynamic=False)
+            if os.getenv('VLLM_REGIONAL_COMPILATION',
+                         'true').lower() == 'true':
+                self.regional_compilation_layers_list = [
+                    RMSNorm, VocabParallelEmbedding
+                ]
+                self._regional_compilation(self.model)
+            else:
+                self.model = torch.compile(self.model,
+                                           backend='hpu_backend',
+                                           dynamic=False)
+
+    def _regional_compilation(self,
+                              module,
+                              parent_module=None,
+                              module_name=None):
+        if isinstance(module, torch.nn.ModuleList):
+            for children_name, children_module in module.named_children():
+                self._compile_region(module, children_name, children_module)
+        elif any(
+                isinstance(module, layer)
+                for layer in self.regional_compilation_layers_list):
+            self._compile_region(parent_module, module_name, module)
+        else:
+            for children_name, children_module in module.named_children():
+                self._regional_compilation(children_module, module,
+                                           children_name)
+
+    def _compile_region(self, model, name, module):
+        module = torch.compile(module, backend='hpu_backend', dynamic=False)
+        setattr(model, name, module)
 
     def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
                        dtype):
@@ -1575,9 +1606,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                     list(sorted(self.bucketing_global_state.decode_buckets)))
 
         if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
-            cache_size_limit = len(
-                self.bucketing_global_state.prompt_buckets) + len(
-                    self.bucketing_global_state.decode_buckets) + 1
+            cache_size_limit = 1 + 3 * (
+                len(self.bucketing_global_state.prompt_buckets) +
+                len(self.bucketing_global_state.decode_buckets))
             torch._dynamo.config.cache_size_limit = max(
                 cache_size_limit, torch._dynamo.config.cache_size_limit)
             # Multiply by 8 to follow the original default ratio between
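
The sketch below is not part of the diff; it is a minimal standalone illustration of the regional-compilation pattern the patch introduces: instead of wrapping the whole model in one `torch.compile` call, the model tree is walked and each `ModuleList` child (plus any layer type on an allow-list) is compiled individually, so identical decoder blocks can reuse one compiled region across bucket shapes. `ToyBlock`, `ToyModel`, the `inductor` backend, and the `nn.LayerNorm` allow-list are hypothetical stand-ins for vLLM's real modules, `RMSNorm`/`VocabParallelEmbedding`, and `hpu_backend`.

```python
import torch
import torch.nn as nn

# Hypothetical allow-list; the diff uses [RMSNorm, VocabParallelEmbedding].
REGIONAL_COMPILATION_LAYERS = (nn.LayerNorm,)


def compile_region(parent: nn.Module, name: str, module: nn.Module) -> None:
    # Replace the child on its parent with a compiled wrapper.
    setattr(parent, name,
            torch.compile(module, backend="inductor", dynamic=False))


def regional_compilation(module: nn.Module,
                         parent: nn.Module = None,
                         name: str = None) -> None:
    if isinstance(module, nn.ModuleList):
        # Compile every block in the list individually; identical blocks
        # then share compiled code instead of being retraced as one graph.
        for child_name, child in module.named_children():
            compile_region(module, child_name, child)
    elif isinstance(module, REGIONAL_COMPILATION_LAYERS):
        compile_region(parent, name, module)
    else:
        for child_name, child in module.named_children():
            regional_compilation(child, module, child_name)


class ToyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))


class ToyModel(nn.Module):
    def __init__(self, dim: int = 16, depth: int = 4):
        super().__init__()
        self.embed = nn.Linear(dim, dim)
        self.blocks = nn.ModuleList(ToyBlock(dim) for _ in range(depth))
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
        return self.norm(x)


if __name__ == "__main__":
    model = ToyModel()
    regional_compilation(model)       # compile per region, not whole model
    out = model(torch.randn(2, 16))   # forward runs through compiled regions
    print(out.shape)
```

The revised `cache_size_limit` in the last hunk follows the same reasoning: with regional compilation there are several compiled regions per bucket rather than one graph per bucket, so the Dynamo cache budget is scaled up (here to roughly three entries per prompt/decode bucket) rather than one per bucket plus one.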