[torch.compile] support encoder based models (#10613)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Authored by youkaichao on 2024-11-24 21:24:33 -08:00, committed by GitHub
parent 7ea3cd7c3e
commit 571841b7fc
2 changed files with 17 additions and 10 deletions


@@ -62,6 +62,16 @@ test_settings = [
         method="encode",
         fullgraph=True,
     ),
+    # encoder-based embedding model (BERT)
+    TestSetting(
+        model="BAAI/bge-base-en-v1.5",
+        model_args=["--task", "embedding"],
+        pp_size=1,
+        tp_size=1,
+        attn_backend="XFORMERS",
+        method="encode",
+        fullgraph=True,
+    ),
     # vision language model
     TestSetting(
         model="microsoft/Phi-3.5-vision-instruct",

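The new test setting mirrors what a user does when serving this encoder-only model for embeddings. A rough offline sketch of the same configuration (assuming this version's LLM constructor accepts a task argument, that llm.encode returns embedding outputs, and that VLLM_ATTENTION_BACKEND selects the backend; the snippet is illustrative and not part of the commit):

import os

from vllm import LLM

# Select the same attention backend as the new test case (assumed env var).
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

# Encoder-based embedding model exercised by the new test case.
llm = LLM(model="BAAI/bge-base-en-v1.5", task="embedding")

# method="encode" in the test corresponds to one embedding per prompt.
outputs = llm.encode(["torch.compile now covers encoder-based models."])
print(len(outputs[0].outputs.embedding))  # embedding width, 768 for bge-base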

@@ -5,6 +5,7 @@ from torch import nn
 from transformers import BertConfig
 
 from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -92,14 +93,14 @@ class BertPooler(nn.Module):
         return pooled_output
 
 
+@support_torch_compile
 class BertEncoder(nn.Module):
 
-    def __init__(self,
-                 config: BertConfig,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
         self.layer = nn.ModuleList([
             BertLayer(config=config,
                       cache_config=cache_config,
@@ -336,12 +337,8 @@ class BertModel(nn.Module):
                  add_pooling_layer: bool = False):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
         self.embeddings = embedding_class(config)
-        self.encoder = BertEncoder(config,
-                                   cache_config,
-                                   quant_config,
+        self.encoder = BertEncoder(vllm_config=vllm_config,
                                    prefix=f"{prefix}.encoder")
         self.pooler = BertPooler(config) if add_pooling_layer else None
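The constructor change above is what makes the decorator applicable: support_torch_compile wraps the class and expects it to be built from a VllmConfig plus a prefix, so BertEncoder now unpacks hf_config, cache_config, and quant_config itself instead of receiving them as separate arguments. A sketch of the same pattern on a hypothetical module (ToyEncoder is illustrative only, not code from this commit):

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class ToyEncoder(nn.Module):
    """Hypothetical encoder using the same constructor contract as the
    refactored BertEncoder: all configuration comes from VllmConfig."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.layers = nn.ModuleList([
            nn.Linear(config.hidden_size, config.hidden_size)
            for _ in range(config.num_hidden_layers)
        ])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # When compilation is enabled in vllm_config, the decorator wraps
        # and compiles this forward pass.
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states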