[Hardware][Intel-Gaudi] Regional compilation support (#13213)
commit b91660ddb8
parent 76c89fcadd
@@ -39,7 +39,10 @@ from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalKwargs)
@@ -311,10 +314,38 @@ class HpuModelAdapter:
         self.block_size = vllm_config.cache_config.block_size
         self.dtype = vllm_config.model_config.dtype
         enforce_eager = vllm_config.model_config.enforce_eager
+
         if not htorch.utils.internal.is_lazy() and not enforce_eager:
-            self.model = torch.compile(self.model,
-                                       backend='hpu_backend',
-                                       dynamic=False)
+            if os.getenv('VLLM_REGIONAL_COMPILATION',
+                         'true').lower() == 'true':
+                self.regional_compilation_layers_list = [
+                    RMSNorm, VocabParallelEmbedding
+                ]
+                self._regional_compilation(self.model)
+            else:
+                self.model = torch.compile(self.model,
+                                           backend='hpu_backend',
+                                           dynamic=False)
+
+    def _regional_compilation(self,
+                              module,
+                              parent_module=None,
+                              module_name=None):
+        if isinstance(module, torch.nn.ModuleList):
+            for children_name, children_module in module.named_children():
+                self._compile_region(module, children_name, children_module)
+        elif any(
+                isinstance(module, layer)
+                for layer in self.regional_compilation_layers_list):
+            self._compile_region(parent_module, module_name, module)
+        else:
+            for children_name, children_module in module.named_children():
+                self._regional_compilation(children_module, module,
+                                           children_name)
+
+    def _compile_region(self, model, name, module):
+        module = torch.compile(module, backend='hpu_backend', dynamic=False)
+        setattr(model, name, module)
 
     def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
                        dtype):
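
Instead of wrapping the whole model in a single torch.compile call, the new path walks the module tree and compiles each ModuleList child (i.e. each decoder layer) plus every RMSNorm and VocabParallelEmbedding instance as its own region; everything else stays eager and is only traversed. Below is a minimal standalone sketch of that walk, assuming a plain CPU setup: backend='eager' replaces 'hpu_backend' (which needs the Habana PyTorch bridge), and ToyModel, ToyBlock and REGIONAL_LAYERS are invented for the example.

# Standalone sketch of the regional-compilation traversal above. Assumptions:
# no Gaudi device, so backend='eager' stands in for 'hpu_backend', and
# nn.LayerNorm / nn.Embedding stand in for RMSNorm / VocabParallelEmbedding.
import torch
import torch.nn as nn

REGIONAL_LAYERS = (nn.LayerNorm, nn.Embedding)


def compile_region(parent, name, module):
    # Compile one region and swap it back into its parent, as _compile_region does.
    compiled = torch.compile(module, backend='eager', dynamic=False)
    setattr(parent, name, compiled)


def regional_compilation(module, parent_module=None, module_name=None):
    if isinstance(module, nn.ModuleList):
        # Every entry of a ModuleList (e.g. the decoder layers) becomes a region.
        for child_name, child in module.named_children():
            compile_region(module, child_name, child)
    elif isinstance(module, REGIONAL_LAYERS):
        compile_region(parent_module, module_name, module)
    else:
        # Otherwise keep descending until a ModuleList or a listed layer is hit.
        for child_name, child in module.named_children():
            regional_compilation(child, module, child_name)


class ToyBlock(nn.Module):

    def __init__(self, hidden):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x):
        return torch.relu(self.proj(x))


class ToyModel(nn.Module):

    def __init__(self, vocab=32, hidden=16, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.blocks = nn.ModuleList(ToyBlock(hidden) for _ in range(num_layers))
        self.norm = nn.LayerNorm(hidden)

    def forward(self, ids):
        x = self.embed(ids)
        for block in self.blocks:
            x = block(x)
        return self.norm(x)


if __name__ == '__main__':
    model = ToyModel()
    regional_compilation(model)
    # embed, norm and each block are now separate compiled regions; the
    # ToyModel wrapper itself stays eager.
    print(model(torch.randint(0, 32, (1, 4))).shape)

The regional path is the default whenever torch.compile mode is active (non-lazy HPU, enforce_eager off); exporting VLLM_REGIONAL_COMPILATION=false falls back to the previous behaviour of compiling the whole model as one graph.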
@@ -1575,9 +1606,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                list(sorted(self.bucketing_global_state.decode_buckets)))
 
         if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
-            cache_size_limit = len(
-                self.bucketing_global_state.prompt_buckets) + len(
-                    self.bucketing_global_state.decode_buckets) + 1
+            cache_size_limit = 1 + 3 * (
+                len(self.bucketing_global_state.prompt_buckets) +
+                len(self.bucketing_global_state.decode_buckets))
             torch._dynamo.config.cache_size_limit = max(
                 cache_size_limit, torch._dynamo.config.cache_size_limit)
             # Multiply by 8 to follow the original default ratio between
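
The Dynamo cache budget is raised in step: with several compiled regions per forward pass, one warmed-up prompt/decode bucket can occupy more than one cache entry, so the limit moves from one entry per bucket (plus one) to three per bucket (plus one). A quick comparison of the two formulas, using made-up bucket counts:

import torch._dynamo

# Hypothetical bucket counts, purely to compare the old and new limits.
num_prompt_buckets = 24
num_decode_buckets = 40

old_limit = num_prompt_buckets + num_decode_buckets + 1        # 65
new_limit = 1 + 3 * (num_prompt_buckets + num_decode_buckets)  # 193

# As in the patch, the configured limit is only ever raised, never lowered.
torch._dynamo.config.cache_size_limit = max(
    new_limit, torch._dynamo.config.cache_size_limit)
print(old_limit, new_limit, torch._dynamo.config.cache_size_limit)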