[Hardware][Intel-Gaudi] Regional compilation support (#13213)

Kacper Pietkun 2025-02-28 09:51:49 +01:00 committed by GitHub
parent 76c89fcadd
commit b91660ddb8

@@ -39,7 +39,10 @@ from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalKwargs)
@@ -311,10 +314,38 @@ class HpuModelAdapter:
         self.block_size = vllm_config.cache_config.block_size
         self.dtype = vllm_config.model_config.dtype
         enforce_eager = vllm_config.model_config.enforce_eager
         if not htorch.utils.internal.is_lazy() and not enforce_eager:
-            self.model = torch.compile(self.model,
-                                       backend='hpu_backend',
-                                       dynamic=False)
+            if os.getenv('VLLM_REGIONAL_COMPILATION',
+                         'true').lower() == 'true':
+                self.regional_compilation_layers_list = [
+                    RMSNorm, VocabParallelEmbedding
+                ]
+                self._regional_compilation(self.model)
+            else:
+                self.model = torch.compile(self.model,
+                                           backend='hpu_backend',
+                                           dynamic=False)
+
+    def _regional_compilation(self,
+                              module,
+                              parent_module=None,
+                              module_name=None):
+        if isinstance(module, torch.nn.ModuleList):
+            for children_name, children_module in module.named_children():
+                self._compile_region(module, children_name, children_module)
+        elif any(
+                isinstance(module, layer)
+                for layer in self.regional_compilation_layers_list):
+            self._compile_region(parent_module, module_name, module)
+        else:
+            for children_name, children_module in module.named_children():
+                self._regional_compilation(children_module, module,
+                                           children_name)
+
+    def _compile_region(self, model, name, module):
+        module = torch.compile(module, backend='hpu_backend', dynamic=False)
+        setattr(model, name, module)
 
     def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
                        dtype):
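
Note on the hunk above: with regional compilation enabled (VLLM_REGIONAL_COMPILATION defaults to 'true'), the adapter no longer wraps the whole model in one torch.compile call. _regional_compilation walks the module tree and compiles each child of a torch.nn.ModuleList (typically the decoder blocks) plus every RMSNorm and VocabParallelEmbedding instance on its own, swapping the compiled wrapper back into its parent via setattr; everything else is only recursed into. Setting VLLM_REGIONAL_COMPILATION=false restores whole-model compilation. The sketch below applies the same traversal to a toy model outside vLLM; ToyModel, the nn.LayerNorm/nn.Embedding stand-ins for RMSNorm/VocabParallelEmbedding, and the 'inductor' backend are illustrative assumptions, not part of the commit (on Gaudi the code above uses backend='hpu_backend').

# Minimal standalone sketch of the regional-compilation traversal (assumed
# names; not part of the commit).
import torch
import torch.nn as nn

# Stand-ins for vLLM's RMSNorm / VocabParallelEmbedding leaf layers.
COMPILE_LEAF_TYPES = (nn.LayerNorm, nn.Embedding)

def compile_region(parent, name, module):
    # Compile one submodule and swap it back into its parent in place.
    setattr(parent, name, torch.compile(module, backend='inductor', dynamic=False))

def regional_compilation(module, parent=None, name=None):
    if isinstance(module, nn.ModuleList):
        # Compile every entry of the list (e.g. each decoder block) separately.
        for child_name, child in module.named_children():
            compile_region(module, child_name, child)
    elif isinstance(module, COMPILE_LEAF_TYPES):
        # Compile selected leaf layers individually.
        compile_region(parent, name, module)
    else:
        # Otherwise keep recursing into the children.
        for child_name, child in module.named_children():
            regional_compilation(child, module, child_name)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(1000, 64)
        self.blocks = nn.ModuleList(nn.Linear(64, 64) for _ in range(4))
        self.norm = nn.LayerNorm(64)

model = ToyModel()
regional_compilation(model)  # blocks[0..3], embed and norm are now compiled regions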
@@ -1575,9 +1606,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                     list(sorted(self.bucketing_global_state.decode_buckets)))
         if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
-            cache_size_limit = len(
-                self.bucketing_global_state.prompt_buckets) + len(
-                    self.bucketing_global_state.decode_buckets) + 1
+            cache_size_limit = 1 + 3 * (
+                len(self.bucketing_global_state.prompt_buckets) +
+                len(self.bucketing_global_state.decode_buckets))
             torch._dynamo.config.cache_size_limit = max(
                 cache_size_limit, torch._dynamo.config.cache_size_limit)
             # Multiply by 8 to follow the original default ratio between
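
The hunk above also raises the Dynamo graph-cache budget from one graph per bucket (plus one) to three per bucket (plus one), which leaves headroom for the additional graphs that compiling several regions per model can produce across buckets. A quick arithmetic check of the two formulas, using made-up bucket counts (illustrative numbers only, not taken from the commit):

# Illustrative arithmetic only; 24 and 36 are assumed bucket counts.
prompt_buckets, decode_buckets = 24, 36
old_limit = prompt_buckets + decode_buckets + 1        # previous formula
new_limit = 1 + 3 * (prompt_buckets + decode_buckets)  # formula after this commit
assert (old_limit, new_limit) == (61, 181)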