mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-22 23:04:34 +08:00
[torch.compile] support encoder based models (#10613)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
7ea3cd7c3e
commit
571841b7fc
@ -62,6 +62,16 @@ test_settings = [
|
||||
method="encode",
|
||||
fullgraph=True,
|
||||
),
|
||||
# encoder-based embedding model (BERT)
|
||||
TestSetting(
|
||||
model="BAAI/bge-base-en-v1.5",
|
||||
model_args=["--task", "embedding"],
|
||||
pp_size=1,
|
||||
tp_size=1,
|
||||
attn_backend="XFORMERS",
|
||||
method="encode",
|
||||
fullgraph=True,
|
||||
),
|
||||
# vision language model
|
||||
TestSetting(
|
||||
model="microsoft/Phi-3.5-vision-instruct",
|
||||
|
||||
@ -5,6 +5,7 @@ from torch import nn
|
||||
from transformers import BertConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
@ -92,14 +93,14 @@ class BertPooler(nn.Module):
|
||||
return pooled_output
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class BertEncoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: BertConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.layer = nn.ModuleList([
|
||||
BertLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
@ -336,12 +337,8 @@ class BertModel(nn.Module):
|
||||
add_pooling_layer: bool = False):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.embeddings = embedding_class(config)
|
||||
self.encoder = BertEncoder(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
self.encoder = BertEncoder(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.encoder")
|
||||
self.pooler = BertPooler(config) if add_pooling_layer else None
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user