From 9013e24f7b09a19405c6856b88c004afd4e3fc57 Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Thu, 24 Oct 2024 01:07:48 +0800 Subject: [PATCH] [torch.compile] Adding torch compile annotations to some models (#9614) --- vllm/model_executor/models/baichuan.py | 2 ++ vllm/model_executor/models/bloom.py | 2 ++ vllm/model_executor/models/commandr.py | 2 ++ vllm/model_executor/models/exaone.py | 2 ++ vllm/model_executor/models/gemma.py | 2 ++ vllm/model_executor/models/gpt2.py | 2 ++ 6 files changed, 12 insertions(+) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 767230aeacc3..f2cfdf8ffd30 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -26,6 +26,7 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -250,6 +251,7 @@ class BaiChuanDecoderLayer(nn.Module): return hidden_states, residual +@support_torch_compile class BaiChuanModel(nn.Module): def __init__(self, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index b2c9e221690b..77ab7de6165f 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -24,6 +24,7 @@ from torch import nn from transformers import BloomConfig from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -218,6 +219,7 @@ class BloomBlock(nn.Module): return output +@support_torch_compile class BloomModel(nn.Module): def __init__( diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 578cd2f04861..348e6d20f329 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -28,6 +28,7 @@ from torch import nn from transformers import CohereConfig from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul @@ -250,6 +251,7 @@ class CohereDecoderLayer(nn.Module): return hidden_states, residual +@support_torch_compile class CohereModel(nn.Module): def __init__( diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index dfb8fe55d2fb..4126ceb7117d 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -29,6 +29,7 @@ import torch from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -311,6 +312,7 @@ class ExaoneDecoderLayer(nn.Module): return hidden_states, residual +@support_torch_compile class ExaoneModel(nn.Module): def __init__( diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 91e556db70a0..436bd45d53f3 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -22,6 +22,7 @@ from torch import nn from transformers import GemmaConfig from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -239,6 +240,7 @@ class GemmaDecoderLayer(nn.Module): return hidden_states, residual +@support_torch_compile class GemmaModel(nn.Module): def __init__( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 975502340e5f..3330d8402136 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -24,6 +24,7 @@ from torch import nn from transformers import GPT2Config from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_world_size) @@ -182,6 +183,7 @@ class GPT2Block(nn.Module): return hidden_states +@support_torch_compile class GPT2Model(nn.Module): def __init__(