from transformers.models.mllama import configuration_mllama as mllama_hf_config


class MllamaTextConfig(mllama_hf_config.MllamaTextConfig):
    '''
    Use this class to override is_encoder_decoder:
    - transformers regards mllama as is_encoder_decoder=False
    - vllm needs is_encoder_decoder=True to enable cross-attention
    '''

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.is_encoder_decoder = True


class MllamaConfig(mllama_hf_config.MllamaConfig):

    def __init__(
        self,
        text_config=None,
        **kwargs,
    ):
        if isinstance(text_config, dict):
            text_config = MllamaTextConfig(**text_config)
        super().__init__(text_config=text_config, **kwargs)
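

# A minimal sketch of how the override behaves, assuming a transformers
# version that ships the mllama configuration module imported above.
# Passing text_config as a plain dict stands in for the text_config
# section of a checkpoint's config.json; the empty dict here just picks
# up the HF defaults for illustration.
if __name__ == "__main__":
    config = MllamaConfig(text_config={})

    # transformers defaults mllama to is_encoder_decoder=False; the
    # MllamaTextConfig subclass above forces it to True so that vLLM
    # treats the model as encoder-decoder and enables cross-attention.
    assert config.text_config.is_encoder_decoder is True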